fix(auth/csvauth): turn the old CLI-only warnings and errors into returned errors

This commit is contained in:
AJ ONeal 2026-03-03 01:19:43 -07:00
parent 8842791e34
commit d415a8c743
No known key found for this signature in database
3 changed files with 41 additions and 48 deletions

View File

@ -387,7 +387,11 @@ func handleSet(args []string, aesKey []byte, csvFile csvauth.NamedReadCloser) {
records = append(records, record)
}
for _, u := range slices.Sorted(auth.CredentialKeys()) {
c, _ := auth.LoadCredential(u)
c, err := auth.LoadCredential(u)
if err != nil {
fmt.Fprintf(os.Stderr, "Sanity fail while loading %s: %v", u, err)
continue
}
record := c.ToRecord()
records = append(records, record)
}

View File

@ -3,8 +3,8 @@ package csvauth
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"os"
"slices"
"strconv"
"strings"
@ -12,6 +12,8 @@ import (
"github.com/therootcompany/golib/auth"
)
var ErrDecodeFields = errors.New("could not decode credential")
type BasicAuthVerifier interface {
Verify(string, string) error
}
@ -110,13 +112,13 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
paramList = strings.ReplaceAll(paramList, ",", " ")
credential.Params = strings.Split(paramList, " ")
if len(credential.Params) == 0 {
fmt.Fprintf(os.Stderr, "no algorithm parameters for %q\n", name)
return Credential{}, fmt.Errorf("%w: no algorithm parameters for %q", ErrDecodeFields, name)
}
switch credential.Params[0] {
case "aes-128-gcm":
if len(credential.Params) > 1 {
return credential, fmt.Errorf("invalid plain parameters %#v", credential.Params)
return credential, fmt.Errorf("%w: invalid plain parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`))
}
salt, err := base64.RawURLEncoding.DecodeString(saltBase64)
@ -143,12 +145,12 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
credential.Salt, err = base64.RawURLEncoding.DecodeString(saltBase64)
if err != nil {
fmt.Fprintf(os.Stderr, "could not decode salt %q for %q\n", saltBase64, name)
return credential, fmt.Errorf("%w: bad salt for %q: %q", ErrDecodeFields, name, saltBase64)
}
credential.Derived, err = base64.RawURLEncoding.DecodeString(derived)
if err != nil {
fmt.Fprintf(os.Stderr, "could not decode derived data %q for %q\n", derived, name)
return credential, fmt.Errorf("%w: bad derived data for %q: %q", ErrDecodeFields, name, derived)
}
iters, err := strconv.Atoi(credential.Params[1])
@ -156,7 +158,7 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
return credential, err
}
if iters <= 0 {
return credential, fmt.Errorf("invalid iterations %s", credential.Params[1])
return credential, fmt.Errorf("%w: invalid iterations for %q: %q", ErrDecodeFields, name, credential.Params[1])
}
size, err := strconv.Atoi(credential.Params[2])
@ -164,20 +166,20 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
return credential, err
}
if size < 8 || size > 32 {
return credential, fmt.Errorf("invalid size %s", credential.Params[2])
return credential, fmt.Errorf("%w: invalid size for %q: %q", ErrDecodeFields, name, credential.Params[2])
}
if !slices.Contains([]string{"SHA-256", "SHA-1"}, credential.Params[3]) {
return credential, fmt.Errorf("invalid hash %s", credential.Params[3])
return credential, fmt.Errorf("%w: invalid hash for %q: %q", ErrDecodeFields, name, credential.Params[3])
}
case "bcrypt":
if len(credential.Params) > 1 {
return credential, fmt.Errorf("invalid bcrypt parameters %#v", credential.Params)
return credential, fmt.Errorf("%w: invalid bcrypt parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`))
}
credential.Derived = []byte(derived)
default:
return credential, fmt.Errorf("invalid algorithm %s", credential.Params[0])
return credential, fmt.Errorf("%w: invalid algorithm for %q: %q", ErrDecodeFields, name, credential.Params[0])
}
return credential, nil

View File

@ -146,25 +146,25 @@ func (a *Auth) LoadCSV(f NamedReadCloser, comma rune) error {
}
if _, ok := a.credentials[name]; ok {
fmt.Fprintf(os.Stderr, "overwriting plain cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
fmt.Fprintf(os.Stderr, "warn: overwriting plain cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
}
a.credentials[name] = credential
nameID := a.nameCacheID(name)
if _, ok := a.hashedCredentials[nameID]; ok {
fmt.Fprintf(os.Stderr, "overwriting hashed cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
fmt.Fprintf(os.Stderr, "warn: overwriting hashed cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
}
a.hashedCredentials[nameID] = credential
if credential.Purpose == PurposeToken {
if _, ok := a.tokens[credential.hashID]; ok {
fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
fmt.Fprintf(os.Stderr, "warn: overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
}
a.tokens[credential.hashID] = credential
}
default:
if _, ok := a.serviceAccounts[credential.Purpose]; ok {
fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
fmt.Fprintf(os.Stderr, "warn: overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
}
a.serviceAccounts[credential.Purpose] = credential
}
@ -193,8 +193,7 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
switch c.Params[0] {
case "plain":
if len(params) != 1 {
fmt.Fprintf(os.Stderr, "invalid plain algorithm format: %q\n", strings.Join(params, " "))
os.Exit(1)
panic(fmt.Errorf("invalid plain algorithm format: %q", strings.Join(params, " ")))
}
c.plain = secretValue(secret)
@ -203,8 +202,7 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
c.Derived = h[:]
case "aes-128-gcm":
if len(params) != 1 {
fmt.Fprintf(os.Stderr, "invalid aes-128-gcm algorithm format: %q\n", strings.Join(params, " "))
os.Exit(1)
panic(fmt.Errorf("invalid aes-128-gcm algorithm format: %q", strings.Join(params, " ")))
}
c.Params = []string{"aes-128-gcm"}
@ -220,21 +218,18 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
c.plain = secretValue(secret)
c.Derived, err = gcmEncrypt(a.aes128key, salt, secret)
if err != nil {
fmt.Fprintf(os.Stderr, "could not aes-128-gcm encrypt secret: %v\n", err)
os.Exit(1)
panic(fmt.Errorf("could not aes-128-gcm encrypt secret: %w", err))
}
case "pbkdf2":
if len(params) > 4 {
fmt.Fprintf(os.Stderr, "invalid pbkdf2 algorithm format: %q\n", strings.Join(params, " "))
os.Exit(1)
panic(fmt.Errorf("invalid pbkdf2 algorithm format: %q", strings.Join(params, " ")))
}
iters := defaultIters
if len(params) > 1 {
var err error
iters, err = strconv.Atoi(params[1])
if err != nil || iters <= 0 {
fmt.Fprintf(os.Stderr, "invalid iterations %q in %q\n", params[1], strings.Join(params, " "))
os.Exit(1)
panic(fmt.Errorf("invalid iterations %q in %q", params[1], strings.Join(params, " ")))
}
}
size := defaultSize
@ -242,15 +237,13 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
var err error
size, err = strconv.Atoi(params[2])
if err != nil || size < 8 || size > 32 {
fmt.Fprintf(os.Stderr, "invalid size %q in %q\n", params[2], strings.Join(params, " "))
os.Exit(1)
panic(fmt.Errorf("invalid size %q in %q", params[2], strings.Join(params, " ")))
}
}
hashName := defaultHash
if len(params) > 3 {
if !slices.Contains([]string{"SHA-256", "SHA-1"}, params[3]) {
fmt.Fprintf(os.Stderr, "invalid hash %q in %q\n", params[3], strings.Join(params, " "))
os.Exit(1)
panic(fmt.Errorf("invalid hash %q in %q", params[3], strings.Join(params, " ")))
}
hashName = params[3]
}
@ -264,45 +257,39 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
hashNameUpper := strings.ToUpper(hashName)
switch hashNameUpper {
case "SHA-1", "SHA1":
hashName = "SHA-1"
// hashName = "SHA-1"
hasher = sha1.New
case "SHA-256", "SHA256":
hashName = "SHA-256"
// hashName = "SHA-256"
hasher = sha256.New
default:
fmt.Fprintf(os.Stderr, "invalid hash %q (expected SHA-1 or SHA-256)\n", hashName)
os.Exit(1)
panic(fmt.Errorf("invalid hash %q (expected SHA-1 or SHA-256)", hashName))
}
var err error
c.Derived, err = pbkdf2.Key(hasher, secret, saltBytes, iters, size)
if err != nil {
fmt.Fprintf(os.Stderr, "invalid pbkdf2 parameters: %v\n", err)
os.Exit(1)
panic(fmt.Errorf("invalid pbkdf2 parameters: %v", err))
}
case "bcrypt":
if len(params) > 2 {
fmt.Fprintf(os.Stderr, "invalid bcrypt algorithm format: %q\n", strings.Join(params, " "))
os.Exit(1)
panic(fmt.Errorf("invalid bcrypt algorithm format: %q", strings.Join(params, " ")))
}
cost := defaultBcryptCost
if len(params) > 1 {
var err error
cost, err = strconv.Atoi(params[1])
if err != nil || cost < 4 || cost > 31 {
fmt.Fprintf(os.Stderr, "invalid bcrypt cost %q in %q\n", params[1], strings.Join(params, " "))
os.Exit(1)
panic(fmt.Errorf("invalid bcrypt cost %q in %q", params[1], strings.Join(params, " ")))
}
}
c.Params = []string{"bcrypt"} // cost is included in the digest
derived, err := bcrypt.GenerateFromPassword([]byte(secret), cost)
if err != nil {
fmt.Fprintf(os.Stderr, "Error generating bcrypt hash: %v\n", err)
os.Exit(1)
panic(fmt.Errorf("failed bcrypt hash: %w", err))
}
c.Derived = derived
default:
fmt.Fprintf(os.Stderr, "invalid algorithm %q\n", params[0])
os.Exit(1)
panic(fmt.Errorf("invalid algorithm: %q", params[0]))
}
return c
@ -311,13 +298,13 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
func gcmEncrypt(aes128key [16]byte, gcmNonce [12]byte, secret string) ([]byte, error) {
block, err := aes.NewCipher(aes128key[:])
if err != nil {
return nil, fmt.Errorf("new aes (encrypt) cipher failed: %v", err)
return nil, fmt.Errorf("new aes (encrypt) cipher failed: %w", err)
}
// nonceSize := len(gcmNonce) // should always be 12
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("new gcm (encrypt) failed: %v", err)
return nil, fmt.Errorf("new gcm (encrypt) failed: %w", err)
}
plaintext := []byte(secret)
@ -342,17 +329,17 @@ func (a *Auth) maybeDecryptCredential(c Credential) (secretValue, error) {
func (a *Auth) gcmDecrypt(aes128key [16]byte, gcmNonce [12]byte, derived []byte) (string, error) {
block, err := aes.NewCipher(aes128key[:])
if err != nil {
return "", fmt.Errorf("new aes (decrypt) cipher failed: %v", err)
return "", fmt.Errorf("new aes (decrypt) cipher failed: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("new gcm failed: %v", err)
return "", fmt.Errorf("new gcm failed: %w", err)
}
plaintext, err := gcm.Open(nil, gcmNonce[:], derived, nil)
if err != nil {
return "", fmt.Errorf("gcm open (decryption) failed: %v", err)
return "", fmt.Errorf("gcm open (decryption) failed: %w", err)
}
return string(plaintext), nil