fix(auth/csvauth): turn the old CLI-only warnings and errors into returned errors

This commit is contained in:
AJ ONeal 2026-03-03 01:19:43 -07:00
parent 8842791e34
commit d415a8c743
No known key found for this signature in database
3 changed files with 41 additions and 48 deletions

View File

@ -387,7 +387,11 @@ func handleSet(args []string, aesKey []byte, csvFile csvauth.NamedReadCloser) {
records = append(records, record) records = append(records, record)
} }
for _, u := range slices.Sorted(auth.CredentialKeys()) { for _, u := range slices.Sorted(auth.CredentialKeys()) {
c, _ := auth.LoadCredential(u) c, err := auth.LoadCredential(u)
if err != nil {
fmt.Fprintf(os.Stderr, "Sanity fail while loading %s: %v", u, err)
continue
}
record := c.ToRecord() record := c.ToRecord()
records = append(records, record) records = append(records, record)
} }

View File

@ -3,8 +3,8 @@ package csvauth
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"os"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@ -12,6 +12,8 @@ import (
"github.com/therootcompany/golib/auth" "github.com/therootcompany/golib/auth"
) )
var ErrDecodeFields = errors.New("could not decode credential")
type BasicAuthVerifier interface { type BasicAuthVerifier interface {
Verify(string, string) error Verify(string, string) error
} }
@ -110,13 +112,13 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
paramList = strings.ReplaceAll(paramList, ",", " ") paramList = strings.ReplaceAll(paramList, ",", " ")
credential.Params = strings.Split(paramList, " ") credential.Params = strings.Split(paramList, " ")
if len(credential.Params) == 0 { if len(credential.Params) == 0 {
fmt.Fprintf(os.Stderr, "no algorithm parameters for %q\n", name) return Credential{}, fmt.Errorf("%w: no algorithm parameters for %q", ErrDecodeFields, name)
} }
switch credential.Params[0] { switch credential.Params[0] {
case "aes-128-gcm": case "aes-128-gcm":
if len(credential.Params) > 1 { if len(credential.Params) > 1 {
return credential, fmt.Errorf("invalid plain parameters %#v", credential.Params) return credential, fmt.Errorf("%w: invalid plain parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`))
} }
salt, err := base64.RawURLEncoding.DecodeString(saltBase64) salt, err := base64.RawURLEncoding.DecodeString(saltBase64)
@ -143,12 +145,12 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
credential.Salt, err = base64.RawURLEncoding.DecodeString(saltBase64) credential.Salt, err = base64.RawURLEncoding.DecodeString(saltBase64)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "could not decode salt %q for %q\n", saltBase64, name) return credential, fmt.Errorf("%w: bad salt for %q: %q", ErrDecodeFields, name, saltBase64)
} }
credential.Derived, err = base64.RawURLEncoding.DecodeString(derived) credential.Derived, err = base64.RawURLEncoding.DecodeString(derived)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "could not decode derived data %q for %q\n", derived, name) return credential, fmt.Errorf("%w: bad derived data for %q: %q", ErrDecodeFields, name, derived)
} }
iters, err := strconv.Atoi(credential.Params[1]) iters, err := strconv.Atoi(credential.Params[1])
@ -156,7 +158,7 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
return credential, err return credential, err
} }
if iters <= 0 { if iters <= 0 {
return credential, fmt.Errorf("invalid iterations %s", credential.Params[1]) return credential, fmt.Errorf("%w: invalid iterations for %q: %q", ErrDecodeFields, name, credential.Params[1])
} }
size, err := strconv.Atoi(credential.Params[2]) size, err := strconv.Atoi(credential.Params[2])
@ -164,20 +166,20 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
return credential, err return credential, err
} }
if size < 8 || size > 32 { if size < 8 || size > 32 {
return credential, fmt.Errorf("invalid size %s", credential.Params[2]) return credential, fmt.Errorf("%w: invalid size for %q: %q", ErrDecodeFields, name, credential.Params[2])
} }
if !slices.Contains([]string{"SHA-256", "SHA-1"}, credential.Params[3]) { if !slices.Contains([]string{"SHA-256", "SHA-1"}, credential.Params[3]) {
return credential, fmt.Errorf("invalid hash %s", credential.Params[3]) return credential, fmt.Errorf("%w: invalid hash for %q: %q", ErrDecodeFields, name, credential.Params[3])
} }
case "bcrypt": case "bcrypt":
if len(credential.Params) > 1 { if len(credential.Params) > 1 {
return credential, fmt.Errorf("invalid bcrypt parameters %#v", credential.Params) return credential, fmt.Errorf("%w: invalid bcrypt parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`))
} }
credential.Derived = []byte(derived) credential.Derived = []byte(derived)
default: default:
return credential, fmt.Errorf("invalid algorithm %s", credential.Params[0]) return credential, fmt.Errorf("%w: invalid algorithm for %q: %q", ErrDecodeFields, name, credential.Params[0])
} }
return credential, nil return credential, nil

View File

@ -146,25 +146,25 @@ func (a *Auth) LoadCSV(f NamedReadCloser, comma rune) error {
} }
if _, ok := a.credentials[name]; ok { if _, ok := a.credentials[name]; ok {
fmt.Fprintf(os.Stderr, "overwriting plain cache of previous value for %s: %s\n", credential.Purpose, credential.Name) fmt.Fprintf(os.Stderr, "warn: overwriting plain cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
} }
a.credentials[name] = credential a.credentials[name] = credential
nameID := a.nameCacheID(name) nameID := a.nameCacheID(name)
if _, ok := a.hashedCredentials[nameID]; ok { if _, ok := a.hashedCredentials[nameID]; ok {
fmt.Fprintf(os.Stderr, "overwriting hashed cache of previous value for %s: %s\n", credential.Purpose, credential.Name) fmt.Fprintf(os.Stderr, "warn: overwriting hashed cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
} }
a.hashedCredentials[nameID] = credential a.hashedCredentials[nameID] = credential
if credential.Purpose == PurposeToken { if credential.Purpose == PurposeToken {
if _, ok := a.tokens[credential.hashID]; ok { if _, ok := a.tokens[credential.hashID]; ok {
fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name) fmt.Fprintf(os.Stderr, "warn: overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
} }
a.tokens[credential.hashID] = credential a.tokens[credential.hashID] = credential
} }
default: default:
if _, ok := a.serviceAccounts[credential.Purpose]; ok { if _, ok := a.serviceAccounts[credential.Purpose]; ok {
fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name) fmt.Fprintf(os.Stderr, "warn: overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
} }
a.serviceAccounts[credential.Purpose] = credential a.serviceAccounts[credential.Purpose] = credential
} }
@ -193,8 +193,7 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
switch c.Params[0] { switch c.Params[0] {
case "plain": case "plain":
if len(params) != 1 { if len(params) != 1 {
fmt.Fprintf(os.Stderr, "invalid plain algorithm format: %q\n", strings.Join(params, " ")) panic(fmt.Errorf("invalid plain algorithm format: %q", strings.Join(params, " ")))
os.Exit(1)
} }
c.plain = secretValue(secret) c.plain = secretValue(secret)
@ -203,8 +202,7 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
c.Derived = h[:] c.Derived = h[:]
case "aes-128-gcm": case "aes-128-gcm":
if len(params) != 1 { if len(params) != 1 {
fmt.Fprintf(os.Stderr, "invalid aes-128-gcm algorithm format: %q\n", strings.Join(params, " ")) panic(fmt.Errorf("invalid aes-128-gcm algorithm format: %q", strings.Join(params, " ")))
os.Exit(1)
} }
c.Params = []string{"aes-128-gcm"} c.Params = []string{"aes-128-gcm"}
@ -220,21 +218,18 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
c.plain = secretValue(secret) c.plain = secretValue(secret)
c.Derived, err = gcmEncrypt(a.aes128key, salt, secret) c.Derived, err = gcmEncrypt(a.aes128key, salt, secret)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "could not aes-128-gcm encrypt secret: %v\n", err) panic(fmt.Errorf("could not aes-128-gcm encrypt secret: %w", err))
os.Exit(1)
} }
case "pbkdf2": case "pbkdf2":
if len(params) > 4 { if len(params) > 4 {
fmt.Fprintf(os.Stderr, "invalid pbkdf2 algorithm format: %q\n", strings.Join(params, " ")) panic(fmt.Errorf("invalid pbkdf2 algorithm format: %q", strings.Join(params, " ")))
os.Exit(1)
} }
iters := defaultIters iters := defaultIters
if len(params) > 1 { if len(params) > 1 {
var err error var err error
iters, err = strconv.Atoi(params[1]) iters, err = strconv.Atoi(params[1])
if err != nil || iters <= 0 { if err != nil || iters <= 0 {
fmt.Fprintf(os.Stderr, "invalid iterations %q in %q\n", params[1], strings.Join(params, " ")) panic(fmt.Errorf("invalid iterations %q in %q", params[1], strings.Join(params, " ")))
os.Exit(1)
} }
} }
size := defaultSize size := defaultSize
@ -242,15 +237,13 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
var err error var err error
size, err = strconv.Atoi(params[2]) size, err = strconv.Atoi(params[2])
if err != nil || size < 8 || size > 32 { if err != nil || size < 8 || size > 32 {
fmt.Fprintf(os.Stderr, "invalid size %q in %q\n", params[2], strings.Join(params, " ")) panic(fmt.Errorf("invalid size %q in %q", params[2], strings.Join(params, " ")))
os.Exit(1)
} }
} }
hashName := defaultHash hashName := defaultHash
if len(params) > 3 { if len(params) > 3 {
if !slices.Contains([]string{"SHA-256", "SHA-1"}, params[3]) { if !slices.Contains([]string{"SHA-256", "SHA-1"}, params[3]) {
fmt.Fprintf(os.Stderr, "invalid hash %q in %q\n", params[3], strings.Join(params, " ")) panic(fmt.Errorf("invalid hash %q in %q", params[3], strings.Join(params, " ")))
os.Exit(1)
} }
hashName = params[3] hashName = params[3]
} }
@ -264,45 +257,39 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
hashNameUpper := strings.ToUpper(hashName) hashNameUpper := strings.ToUpper(hashName)
switch hashNameUpper { switch hashNameUpper {
case "SHA-1", "SHA1": case "SHA-1", "SHA1":
hashName = "SHA-1" // hashName = "SHA-1"
hasher = sha1.New hasher = sha1.New
case "SHA-256", "SHA256": case "SHA-256", "SHA256":
hashName = "SHA-256" // hashName = "SHA-256"
hasher = sha256.New hasher = sha256.New
default: default:
fmt.Fprintf(os.Stderr, "invalid hash %q (expected SHA-1 or SHA-256)\n", hashName) panic(fmt.Errorf("invalid hash %q (expected SHA-1 or SHA-256)", hashName))
os.Exit(1)
} }
var err error var err error
c.Derived, err = pbkdf2.Key(hasher, secret, saltBytes, iters, size) c.Derived, err = pbkdf2.Key(hasher, secret, saltBytes, iters, size)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "invalid pbkdf2 parameters: %v\n", err) panic(fmt.Errorf("invalid pbkdf2 parameters: %v", err))
os.Exit(1)
} }
case "bcrypt": case "bcrypt":
if len(params) > 2 { if len(params) > 2 {
fmt.Fprintf(os.Stderr, "invalid bcrypt algorithm format: %q\n", strings.Join(params, " ")) panic(fmt.Errorf("invalid bcrypt algorithm format: %q", strings.Join(params, " ")))
os.Exit(1)
} }
cost := defaultBcryptCost cost := defaultBcryptCost
if len(params) > 1 { if len(params) > 1 {
var err error var err error
cost, err = strconv.Atoi(params[1]) cost, err = strconv.Atoi(params[1])
if err != nil || cost < 4 || cost > 31 { if err != nil || cost < 4 || cost > 31 {
fmt.Fprintf(os.Stderr, "invalid bcrypt cost %q in %q\n", params[1], strings.Join(params, " ")) panic(fmt.Errorf("invalid bcrypt cost %q in %q", params[1], strings.Join(params, " ")))
os.Exit(1)
} }
} }
c.Params = []string{"bcrypt"} // cost is included in the digest c.Params = []string{"bcrypt"} // cost is included in the digest
derived, err := bcrypt.GenerateFromPassword([]byte(secret), cost) derived, err := bcrypt.GenerateFromPassword([]byte(secret), cost)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error generating bcrypt hash: %v\n", err) panic(fmt.Errorf("failed bcrypt hash: %w", err))
os.Exit(1)
} }
c.Derived = derived c.Derived = derived
default: default:
fmt.Fprintf(os.Stderr, "invalid algorithm %q\n", params[0]) panic(fmt.Errorf("invalid algorithm: %q", params[0]))
os.Exit(1)
} }
return c return c
@ -311,13 +298,13 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
func gcmEncrypt(aes128key [16]byte, gcmNonce [12]byte, secret string) ([]byte, error) { func gcmEncrypt(aes128key [16]byte, gcmNonce [12]byte, secret string) ([]byte, error) {
block, err := aes.NewCipher(aes128key[:]) block, err := aes.NewCipher(aes128key[:])
if err != nil { if err != nil {
return nil, fmt.Errorf("new aes (encrypt) cipher failed: %v", err) return nil, fmt.Errorf("new aes (encrypt) cipher failed: %w", err)
} }
// nonceSize := len(gcmNonce) // should always be 12 // nonceSize := len(gcmNonce) // should always be 12
gcm, err := cipher.NewGCM(block) gcm, err := cipher.NewGCM(block)
if err != nil { if err != nil {
return nil, fmt.Errorf("new gcm (encrypt) failed: %v", err) return nil, fmt.Errorf("new gcm (encrypt) failed: %w", err)
} }
plaintext := []byte(secret) plaintext := []byte(secret)
@ -342,17 +329,17 @@ func (a *Auth) maybeDecryptCredential(c Credential) (secretValue, error) {
func (a *Auth) gcmDecrypt(aes128key [16]byte, gcmNonce [12]byte, derived []byte) (string, error) { func (a *Auth) gcmDecrypt(aes128key [16]byte, gcmNonce [12]byte, derived []byte) (string, error) {
block, err := aes.NewCipher(aes128key[:]) block, err := aes.NewCipher(aes128key[:])
if err != nil { if err != nil {
return "", fmt.Errorf("new aes (decrypt) cipher failed: %v", err) return "", fmt.Errorf("new aes (decrypt) cipher failed: %w", err)
} }
gcm, err := cipher.NewGCM(block) gcm, err := cipher.NewGCM(block)
if err != nil { if err != nil {
return "", fmt.Errorf("new gcm failed: %v", err) return "", fmt.Errorf("new gcm failed: %w", err)
} }
plaintext, err := gcm.Open(nil, gcmNonce[:], derived, nil) plaintext, err := gcm.Open(nil, gcmNonce[:], derived, nil)
if err != nil { if err != nil {
return "", fmt.Errorf("gcm open (decryption) failed: %v", err) return "", fmt.Errorf("gcm open (decryption) failed: %w", err)
} }
return string(plaintext), nil return string(plaintext), nil