From d415a8c7432d383a64b509e302dd3bf2610fe087 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Tue, 3 Mar 2026 01:19:43 -0700 Subject: [PATCH] fix(auth/csvauth): turn the old CLI-only warnings and errors into returned errors --- auth/csvauth/cmd/csvauth/main.go | 6 +++- auth/csvauth/credential.go | 22 ++++++------ auth/csvauth/csvauth.go | 61 +++++++++++++------------------- 3 files changed, 41 insertions(+), 48 deletions(-) diff --git a/auth/csvauth/cmd/csvauth/main.go b/auth/csvauth/cmd/csvauth/main.go index 888169a..344d821 100644 --- a/auth/csvauth/cmd/csvauth/main.go +++ b/auth/csvauth/cmd/csvauth/main.go @@ -387,7 +387,11 @@ func handleSet(args []string, aesKey []byte, csvFile csvauth.NamedReadCloser) { records = append(records, record) } for _, u := range slices.Sorted(auth.CredentialKeys()) { - c, _ := auth.LoadCredential(u) + c, err := auth.LoadCredential(u) + if err != nil { + fmt.Fprintf(os.Stderr, "Sanity fail while loading %s: %v", u, err) + continue + } record := c.ToRecord() records = append(records, record) } diff --git a/auth/csvauth/credential.go b/auth/csvauth/credential.go index f42f18e..2026913 100644 --- a/auth/csvauth/credential.go +++ b/auth/csvauth/credential.go @@ -3,8 +3,8 @@ package csvauth import ( "crypto/sha256" "encoding/base64" + "errors" "fmt" - "os" "slices" "strconv" "strings" @@ -12,6 +12,8 @@ import ( "github.com/therootcompany/golib/auth" ) +var ErrDecodeFields = errors.New("could not decode credential") + type BasicAuthVerifier interface { Verify(string, string) error } @@ -110,13 +112,13 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s paramList = strings.ReplaceAll(paramList, ",", " ") credential.Params = strings.Split(paramList, " ") if len(credential.Params) == 0 { - fmt.Fprintf(os.Stderr, "no algorithm parameters for %q\n", name) + return Credential{}, fmt.Errorf("%w: no algorithm parameters for %q", ErrDecodeFields, name) } switch credential.Params[0] { case "aes-128-gcm": if len(credential.Params) > 1 { - return credential, fmt.Errorf("invalid plain parameters %#v", credential.Params) + return credential, fmt.Errorf("%w: invalid plain parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`)) } salt, err := base64.RawURLEncoding.DecodeString(saltBase64) @@ -143,12 +145,12 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s credential.Salt, err = base64.RawURLEncoding.DecodeString(saltBase64) if err != nil { - fmt.Fprintf(os.Stderr, "could not decode salt %q for %q\n", saltBase64, name) + return credential, fmt.Errorf("%w: bad salt for %q: %q", ErrDecodeFields, name, saltBase64) } credential.Derived, err = base64.RawURLEncoding.DecodeString(derived) if err != nil { - fmt.Fprintf(os.Stderr, "could not decode derived data %q for %q\n", derived, name) + return credential, fmt.Errorf("%w: bad derived data for %q: %q", ErrDecodeFields, name, derived) } iters, err := strconv.Atoi(credential.Params[1]) @@ -156,7 +158,7 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s return credential, err } if iters <= 0 { - return credential, fmt.Errorf("invalid iterations %s", credential.Params[1]) + return credential, fmt.Errorf("%w: invalid iterations for %q: %q", ErrDecodeFields, name, credential.Params[1]) } size, err := strconv.Atoi(credential.Params[2]) @@ -164,20 +166,20 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s return credential, err } if size < 8 || size > 32 { - return credential, fmt.Errorf("invalid size %s", credential.Params[2]) + return credential, fmt.Errorf("%w: invalid size for %q: %q", ErrDecodeFields, name, credential.Params[2]) } if !slices.Contains([]string{"SHA-256", "SHA-1"}, credential.Params[3]) { - return credential, fmt.Errorf("invalid hash %s", credential.Params[3]) + return credential, fmt.Errorf("%w: invalid hash for %q: %q", ErrDecodeFields, name, credential.Params[3]) } case "bcrypt": if len(credential.Params) > 1 { - return credential, fmt.Errorf("invalid bcrypt parameters %#v", credential.Params) + return credential, fmt.Errorf("%w: invalid bcrypt parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`)) } credential.Derived = []byte(derived) default: - return credential, fmt.Errorf("invalid algorithm %s", credential.Params[0]) + return credential, fmt.Errorf("%w: invalid algorithm for %q: %q", ErrDecodeFields, name, credential.Params[0]) } return credential, nil diff --git a/auth/csvauth/csvauth.go b/auth/csvauth/csvauth.go index bfeedb8..46d5a1f 100644 --- a/auth/csvauth/csvauth.go +++ b/auth/csvauth/csvauth.go @@ -146,25 +146,25 @@ func (a *Auth) LoadCSV(f NamedReadCloser, comma rune) error { } if _, ok := a.credentials[name]; ok { - fmt.Fprintf(os.Stderr, "overwriting plain cache of previous value for %s: %s\n", credential.Purpose, credential.Name) + fmt.Fprintf(os.Stderr, "warn: overwriting plain cache of previous value for %s: %s\n", credential.Purpose, credential.Name) } a.credentials[name] = credential nameID := a.nameCacheID(name) if _, ok := a.hashedCredentials[nameID]; ok { - fmt.Fprintf(os.Stderr, "overwriting hashed cache of previous value for %s: %s\n", credential.Purpose, credential.Name) + fmt.Fprintf(os.Stderr, "warn: overwriting hashed cache of previous value for %s: %s\n", credential.Purpose, credential.Name) } a.hashedCredentials[nameID] = credential if credential.Purpose == PurposeToken { if _, ok := a.tokens[credential.hashID]; ok { - fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name) + fmt.Fprintf(os.Stderr, "warn: overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name) } a.tokens[credential.hashID] = credential } default: if _, ok := a.serviceAccounts[credential.Purpose]; ok { - fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name) + fmt.Fprintf(os.Stderr, "warn: overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name) } a.serviceAccounts[credential.Purpose] = credential } @@ -193,8 +193,7 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role switch c.Params[0] { case "plain": if len(params) != 1 { - fmt.Fprintf(os.Stderr, "invalid plain algorithm format: %q\n", strings.Join(params, " ")) - os.Exit(1) + panic(fmt.Errorf("invalid plain algorithm format: %q", strings.Join(params, " "))) } c.plain = secretValue(secret) @@ -203,8 +202,7 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role c.Derived = h[:] case "aes-128-gcm": if len(params) != 1 { - fmt.Fprintf(os.Stderr, "invalid aes-128-gcm algorithm format: %q\n", strings.Join(params, " ")) - os.Exit(1) + panic(fmt.Errorf("invalid aes-128-gcm algorithm format: %q", strings.Join(params, " "))) } c.Params = []string{"aes-128-gcm"} @@ -220,21 +218,18 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role c.plain = secretValue(secret) c.Derived, err = gcmEncrypt(a.aes128key, salt, secret) if err != nil { - fmt.Fprintf(os.Stderr, "could not aes-128-gcm encrypt secret: %v\n", err) - os.Exit(1) + panic(fmt.Errorf("could not aes-128-gcm encrypt secret: %w", err)) } case "pbkdf2": if len(params) > 4 { - fmt.Fprintf(os.Stderr, "invalid pbkdf2 algorithm format: %q\n", strings.Join(params, " ")) - os.Exit(1) + panic(fmt.Errorf("invalid pbkdf2 algorithm format: %q", strings.Join(params, " "))) } iters := defaultIters if len(params) > 1 { var err error iters, err = strconv.Atoi(params[1]) if err != nil || iters <= 0 { - fmt.Fprintf(os.Stderr, "invalid iterations %q in %q\n", params[1], strings.Join(params, " ")) - os.Exit(1) + panic(fmt.Errorf("invalid iterations %q in %q", params[1], strings.Join(params, " "))) } } size := defaultSize @@ -242,15 +237,13 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role var err error size, err = strconv.Atoi(params[2]) if err != nil || size < 8 || size > 32 { - fmt.Fprintf(os.Stderr, "invalid size %q in %q\n", params[2], strings.Join(params, " ")) - os.Exit(1) + panic(fmt.Errorf("invalid size %q in %q", params[2], strings.Join(params, " "))) } } hashName := defaultHash if len(params) > 3 { if !slices.Contains([]string{"SHA-256", "SHA-1"}, params[3]) { - fmt.Fprintf(os.Stderr, "invalid hash %q in %q\n", params[3], strings.Join(params, " ")) - os.Exit(1) + panic(fmt.Errorf("invalid hash %q in %q", params[3], strings.Join(params, " "))) } hashName = params[3] } @@ -264,45 +257,39 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role hashNameUpper := strings.ToUpper(hashName) switch hashNameUpper { case "SHA-1", "SHA1": - hashName = "SHA-1" + // hashName = "SHA-1" hasher = sha1.New case "SHA-256", "SHA256": - hashName = "SHA-256" + // hashName = "SHA-256" hasher = sha256.New default: - fmt.Fprintf(os.Stderr, "invalid hash %q (expected SHA-1 or SHA-256)\n", hashName) - os.Exit(1) + panic(fmt.Errorf("invalid hash %q (expected SHA-1 or SHA-256)", hashName)) } var err error c.Derived, err = pbkdf2.Key(hasher, secret, saltBytes, iters, size) if err != nil { - fmt.Fprintf(os.Stderr, "invalid pbkdf2 parameters: %v\n", err) - os.Exit(1) + panic(fmt.Errorf("invalid pbkdf2 parameters: %v", err)) } case "bcrypt": if len(params) > 2 { - fmt.Fprintf(os.Stderr, "invalid bcrypt algorithm format: %q\n", strings.Join(params, " ")) - os.Exit(1) + panic(fmt.Errorf("invalid bcrypt algorithm format: %q", strings.Join(params, " "))) } cost := defaultBcryptCost if len(params) > 1 { var err error cost, err = strconv.Atoi(params[1]) if err != nil || cost < 4 || cost > 31 { - fmt.Fprintf(os.Stderr, "invalid bcrypt cost %q in %q\n", params[1], strings.Join(params, " ")) - os.Exit(1) + panic(fmt.Errorf("invalid bcrypt cost %q in %q", params[1], strings.Join(params, " "))) } } c.Params = []string{"bcrypt"} // cost is included in the digest derived, err := bcrypt.GenerateFromPassword([]byte(secret), cost) if err != nil { - fmt.Fprintf(os.Stderr, "Error generating bcrypt hash: %v\n", err) - os.Exit(1) + panic(fmt.Errorf("failed bcrypt hash: %w", err)) } c.Derived = derived default: - fmt.Fprintf(os.Stderr, "invalid algorithm %q\n", params[0]) - os.Exit(1) + panic(fmt.Errorf("invalid algorithm: %q", params[0])) } return c @@ -311,13 +298,13 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role func gcmEncrypt(aes128key [16]byte, gcmNonce [12]byte, secret string) ([]byte, error) { block, err := aes.NewCipher(aes128key[:]) if err != nil { - return nil, fmt.Errorf("new aes (encrypt) cipher failed: %v", err) + return nil, fmt.Errorf("new aes (encrypt) cipher failed: %w", err) } // nonceSize := len(gcmNonce) // should always be 12 gcm, err := cipher.NewGCM(block) if err != nil { - return nil, fmt.Errorf("new gcm (encrypt) failed: %v", err) + return nil, fmt.Errorf("new gcm (encrypt) failed: %w", err) } plaintext := []byte(secret) @@ -342,17 +329,17 @@ func (a *Auth) maybeDecryptCredential(c Credential) (secretValue, error) { func (a *Auth) gcmDecrypt(aes128key [16]byte, gcmNonce [12]byte, derived []byte) (string, error) { block, err := aes.NewCipher(aes128key[:]) if err != nil { - return "", fmt.Errorf("new aes (decrypt) cipher failed: %v", err) + return "", fmt.Errorf("new aes (decrypt) cipher failed: %w", err) } gcm, err := cipher.NewGCM(block) if err != nil { - return "", fmt.Errorf("new gcm failed: %v", err) + return "", fmt.Errorf("new gcm failed: %w", err) } plaintext, err := gcm.Open(nil, gcmNonce[:], derived, nil) if err != nil { - return "", fmt.Errorf("gcm open (decryption) failed: %v", err) + return "", fmt.Errorf("gcm open (decryption) failed: %w", err) } return string(plaintext), nil