mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-13 12:27:59 +00:00
fix(auth/csvauth): turn the old CLI-only warnings and errors into returned errors
This commit is contained in:
parent
8842791e34
commit
d415a8c743
@ -387,7 +387,11 @@ func handleSet(args []string, aesKey []byte, csvFile csvauth.NamedReadCloser) {
|
|||||||
records = append(records, record)
|
records = append(records, record)
|
||||||
}
|
}
|
||||||
for _, u := range slices.Sorted(auth.CredentialKeys()) {
|
for _, u := range slices.Sorted(auth.CredentialKeys()) {
|
||||||
c, _ := auth.LoadCredential(u)
|
c, err := auth.LoadCredential(u)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Sanity fail while loading %s: %v", u, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
record := c.ToRecord()
|
record := c.ToRecord()
|
||||||
records = append(records, record)
|
records = append(records, record)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,8 +3,8 @@ package csvauth
|
|||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -12,6 +12,8 @@ import (
|
|||||||
"github.com/therootcompany/golib/auth"
|
"github.com/therootcompany/golib/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrDecodeFields = errors.New("could not decode credential")
|
||||||
|
|
||||||
type BasicAuthVerifier interface {
|
type BasicAuthVerifier interface {
|
||||||
Verify(string, string) error
|
Verify(string, string) error
|
||||||
}
|
}
|
||||||
@ -110,13 +112,13 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
|
|||||||
paramList = strings.ReplaceAll(paramList, ",", " ")
|
paramList = strings.ReplaceAll(paramList, ",", " ")
|
||||||
credential.Params = strings.Split(paramList, " ")
|
credential.Params = strings.Split(paramList, " ")
|
||||||
if len(credential.Params) == 0 {
|
if len(credential.Params) == 0 {
|
||||||
fmt.Fprintf(os.Stderr, "no algorithm parameters for %q\n", name)
|
return Credential{}, fmt.Errorf("%w: no algorithm parameters for %q", ErrDecodeFields, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch credential.Params[0] {
|
switch credential.Params[0] {
|
||||||
case "aes-128-gcm":
|
case "aes-128-gcm":
|
||||||
if len(credential.Params) > 1 {
|
if len(credential.Params) > 1 {
|
||||||
return credential, fmt.Errorf("invalid plain parameters %#v", credential.Params)
|
return credential, fmt.Errorf("%w: invalid plain parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`))
|
||||||
}
|
}
|
||||||
|
|
||||||
salt, err := base64.RawURLEncoding.DecodeString(saltBase64)
|
salt, err := base64.RawURLEncoding.DecodeString(saltBase64)
|
||||||
@ -143,12 +145,12 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
|
|||||||
|
|
||||||
credential.Salt, err = base64.RawURLEncoding.DecodeString(saltBase64)
|
credential.Salt, err = base64.RawURLEncoding.DecodeString(saltBase64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "could not decode salt %q for %q\n", saltBase64, name)
|
return credential, fmt.Errorf("%w: bad salt for %q: %q", ErrDecodeFields, name, saltBase64)
|
||||||
}
|
}
|
||||||
|
|
||||||
credential.Derived, err = base64.RawURLEncoding.DecodeString(derived)
|
credential.Derived, err = base64.RawURLEncoding.DecodeString(derived)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "could not decode derived data %q for %q\n", derived, name)
|
return credential, fmt.Errorf("%w: bad derived data for %q: %q", ErrDecodeFields, name, derived)
|
||||||
}
|
}
|
||||||
|
|
||||||
iters, err := strconv.Atoi(credential.Params[1])
|
iters, err := strconv.Atoi(credential.Params[1])
|
||||||
@ -156,7 +158,7 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
|
|||||||
return credential, err
|
return credential, err
|
||||||
}
|
}
|
||||||
if iters <= 0 {
|
if iters <= 0 {
|
||||||
return credential, fmt.Errorf("invalid iterations %s", credential.Params[1])
|
return credential, fmt.Errorf("%w: invalid iterations for %q: %q", ErrDecodeFields, name, credential.Params[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
size, err := strconv.Atoi(credential.Params[2])
|
size, err := strconv.Atoi(credential.Params[2])
|
||||||
@ -164,20 +166,20 @@ func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra s
|
|||||||
return credential, err
|
return credential, err
|
||||||
}
|
}
|
||||||
if size < 8 || size > 32 {
|
if size < 8 || size > 32 {
|
||||||
return credential, fmt.Errorf("invalid size %s", credential.Params[2])
|
return credential, fmt.Errorf("%w: invalid size for %q: %q", ErrDecodeFields, name, credential.Params[2])
|
||||||
}
|
}
|
||||||
|
|
||||||
if !slices.Contains([]string{"SHA-256", "SHA-1"}, credential.Params[3]) {
|
if !slices.Contains([]string{"SHA-256", "SHA-1"}, credential.Params[3]) {
|
||||||
return credential, fmt.Errorf("invalid hash %s", credential.Params[3])
|
return credential, fmt.Errorf("%w: invalid hash for %q: %q", ErrDecodeFields, name, credential.Params[3])
|
||||||
}
|
}
|
||||||
case "bcrypt":
|
case "bcrypt":
|
||||||
if len(credential.Params) > 1 {
|
if len(credential.Params) > 1 {
|
||||||
return credential, fmt.Errorf("invalid bcrypt parameters %#v", credential.Params)
|
return credential, fmt.Errorf("%w: invalid bcrypt parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`))
|
||||||
}
|
}
|
||||||
|
|
||||||
credential.Derived = []byte(derived)
|
credential.Derived = []byte(derived)
|
||||||
default:
|
default:
|
||||||
return credential, fmt.Errorf("invalid algorithm %s", credential.Params[0])
|
return credential, fmt.Errorf("%w: invalid algorithm for %q: %q", ErrDecodeFields, name, credential.Params[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
return credential, nil
|
return credential, nil
|
||||||
|
|||||||
@ -146,25 +146,25 @@ func (a *Auth) LoadCSV(f NamedReadCloser, comma rune) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := a.credentials[name]; ok {
|
if _, ok := a.credentials[name]; ok {
|
||||||
fmt.Fprintf(os.Stderr, "overwriting plain cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
fmt.Fprintf(os.Stderr, "warn: overwriting plain cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
||||||
}
|
}
|
||||||
a.credentials[name] = credential
|
a.credentials[name] = credential
|
||||||
|
|
||||||
nameID := a.nameCacheID(name)
|
nameID := a.nameCacheID(name)
|
||||||
if _, ok := a.hashedCredentials[nameID]; ok {
|
if _, ok := a.hashedCredentials[nameID]; ok {
|
||||||
fmt.Fprintf(os.Stderr, "overwriting hashed cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
fmt.Fprintf(os.Stderr, "warn: overwriting hashed cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
||||||
}
|
}
|
||||||
a.hashedCredentials[nameID] = credential
|
a.hashedCredentials[nameID] = credential
|
||||||
|
|
||||||
if credential.Purpose == PurposeToken {
|
if credential.Purpose == PurposeToken {
|
||||||
if _, ok := a.tokens[credential.hashID]; ok {
|
if _, ok := a.tokens[credential.hashID]; ok {
|
||||||
fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
fmt.Fprintf(os.Stderr, "warn: overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
||||||
}
|
}
|
||||||
a.tokens[credential.hashID] = credential
|
a.tokens[credential.hashID] = credential
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if _, ok := a.serviceAccounts[credential.Purpose]; ok {
|
if _, ok := a.serviceAccounts[credential.Purpose]; ok {
|
||||||
fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
fmt.Fprintf(os.Stderr, "warn: overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
||||||
}
|
}
|
||||||
a.serviceAccounts[credential.Purpose] = credential
|
a.serviceAccounts[credential.Purpose] = credential
|
||||||
}
|
}
|
||||||
@ -193,8 +193,7 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
|
|||||||
switch c.Params[0] {
|
switch c.Params[0] {
|
||||||
case "plain":
|
case "plain":
|
||||||
if len(params) != 1 {
|
if len(params) != 1 {
|
||||||
fmt.Fprintf(os.Stderr, "invalid plain algorithm format: %q\n", strings.Join(params, " "))
|
panic(fmt.Errorf("invalid plain algorithm format: %q", strings.Join(params, " ")))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
c.plain = secretValue(secret)
|
c.plain = secretValue(secret)
|
||||||
|
|
||||||
@ -203,8 +202,7 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
|
|||||||
c.Derived = h[:]
|
c.Derived = h[:]
|
||||||
case "aes-128-gcm":
|
case "aes-128-gcm":
|
||||||
if len(params) != 1 {
|
if len(params) != 1 {
|
||||||
fmt.Fprintf(os.Stderr, "invalid aes-128-gcm algorithm format: %q\n", strings.Join(params, " "))
|
panic(fmt.Errorf("invalid aes-128-gcm algorithm format: %q", strings.Join(params, " ")))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Params = []string{"aes-128-gcm"}
|
c.Params = []string{"aes-128-gcm"}
|
||||||
@ -220,21 +218,18 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
|
|||||||
c.plain = secretValue(secret)
|
c.plain = secretValue(secret)
|
||||||
c.Derived, err = gcmEncrypt(a.aes128key, salt, secret)
|
c.Derived, err = gcmEncrypt(a.aes128key, salt, secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "could not aes-128-gcm encrypt secret: %v\n", err)
|
panic(fmt.Errorf("could not aes-128-gcm encrypt secret: %w", err))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
case "pbkdf2":
|
case "pbkdf2":
|
||||||
if len(params) > 4 {
|
if len(params) > 4 {
|
||||||
fmt.Fprintf(os.Stderr, "invalid pbkdf2 algorithm format: %q\n", strings.Join(params, " "))
|
panic(fmt.Errorf("invalid pbkdf2 algorithm format: %q", strings.Join(params, " ")))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
iters := defaultIters
|
iters := defaultIters
|
||||||
if len(params) > 1 {
|
if len(params) > 1 {
|
||||||
var err error
|
var err error
|
||||||
iters, err = strconv.Atoi(params[1])
|
iters, err = strconv.Atoi(params[1])
|
||||||
if err != nil || iters <= 0 {
|
if err != nil || iters <= 0 {
|
||||||
fmt.Fprintf(os.Stderr, "invalid iterations %q in %q\n", params[1], strings.Join(params, " "))
|
panic(fmt.Errorf("invalid iterations %q in %q", params[1], strings.Join(params, " ")))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
size := defaultSize
|
size := defaultSize
|
||||||
@ -242,15 +237,13 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
|
|||||||
var err error
|
var err error
|
||||||
size, err = strconv.Atoi(params[2])
|
size, err = strconv.Atoi(params[2])
|
||||||
if err != nil || size < 8 || size > 32 {
|
if err != nil || size < 8 || size > 32 {
|
||||||
fmt.Fprintf(os.Stderr, "invalid size %q in %q\n", params[2], strings.Join(params, " "))
|
panic(fmt.Errorf("invalid size %q in %q", params[2], strings.Join(params, " ")))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
hashName := defaultHash
|
hashName := defaultHash
|
||||||
if len(params) > 3 {
|
if len(params) > 3 {
|
||||||
if !slices.Contains([]string{"SHA-256", "SHA-1"}, params[3]) {
|
if !slices.Contains([]string{"SHA-256", "SHA-1"}, params[3]) {
|
||||||
fmt.Fprintf(os.Stderr, "invalid hash %q in %q\n", params[3], strings.Join(params, " "))
|
panic(fmt.Errorf("invalid hash %q in %q", params[3], strings.Join(params, " ")))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
hashName = params[3]
|
hashName = params[3]
|
||||||
}
|
}
|
||||||
@ -264,45 +257,39 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
|
|||||||
hashNameUpper := strings.ToUpper(hashName)
|
hashNameUpper := strings.ToUpper(hashName)
|
||||||
switch hashNameUpper {
|
switch hashNameUpper {
|
||||||
case "SHA-1", "SHA1":
|
case "SHA-1", "SHA1":
|
||||||
hashName = "SHA-1"
|
// hashName = "SHA-1"
|
||||||
hasher = sha1.New
|
hasher = sha1.New
|
||||||
case "SHA-256", "SHA256":
|
case "SHA-256", "SHA256":
|
||||||
hashName = "SHA-256"
|
// hashName = "SHA-256"
|
||||||
hasher = sha256.New
|
hasher = sha256.New
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(os.Stderr, "invalid hash %q (expected SHA-1 or SHA-256)\n", hashName)
|
panic(fmt.Errorf("invalid hash %q (expected SHA-1 or SHA-256)", hashName))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
c.Derived, err = pbkdf2.Key(hasher, secret, saltBytes, iters, size)
|
c.Derived, err = pbkdf2.Key(hasher, secret, saltBytes, iters, size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "invalid pbkdf2 parameters: %v\n", err)
|
panic(fmt.Errorf("invalid pbkdf2 parameters: %v", err))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
case "bcrypt":
|
case "bcrypt":
|
||||||
if len(params) > 2 {
|
if len(params) > 2 {
|
||||||
fmt.Fprintf(os.Stderr, "invalid bcrypt algorithm format: %q\n", strings.Join(params, " "))
|
panic(fmt.Errorf("invalid bcrypt algorithm format: %q", strings.Join(params, " ")))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
cost := defaultBcryptCost
|
cost := defaultBcryptCost
|
||||||
if len(params) > 1 {
|
if len(params) > 1 {
|
||||||
var err error
|
var err error
|
||||||
cost, err = strconv.Atoi(params[1])
|
cost, err = strconv.Atoi(params[1])
|
||||||
if err != nil || cost < 4 || cost > 31 {
|
if err != nil || cost < 4 || cost > 31 {
|
||||||
fmt.Fprintf(os.Stderr, "invalid bcrypt cost %q in %q\n", params[1], strings.Join(params, " "))
|
panic(fmt.Errorf("invalid bcrypt cost %q in %q", params[1], strings.Join(params, " ")))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Params = []string{"bcrypt"} // cost is included in the digest
|
c.Params = []string{"bcrypt"} // cost is included in the digest
|
||||||
derived, err := bcrypt.GenerateFromPassword([]byte(secret), cost)
|
derived, err := bcrypt.GenerateFromPassword([]byte(secret), cost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Error generating bcrypt hash: %v\n", err)
|
panic(fmt.Errorf("failed bcrypt hash: %w", err))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
c.Derived = derived
|
c.Derived = derived
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(os.Stderr, "invalid algorithm %q\n", params[0])
|
panic(fmt.Errorf("invalid algorithm: %q", params[0]))
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
@ -311,13 +298,13 @@ func (a *Auth) NewCredential(purpose, name, secret string, params []string, role
|
|||||||
func gcmEncrypt(aes128key [16]byte, gcmNonce [12]byte, secret string) ([]byte, error) {
|
func gcmEncrypt(aes128key [16]byte, gcmNonce [12]byte, secret string) ([]byte, error) {
|
||||||
block, err := aes.NewCipher(aes128key[:])
|
block, err := aes.NewCipher(aes128key[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("new aes (encrypt) cipher failed: %v", err)
|
return nil, fmt.Errorf("new aes (encrypt) cipher failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// nonceSize := len(gcmNonce) // should always be 12
|
// nonceSize := len(gcmNonce) // should always be 12
|
||||||
gcm, err := cipher.NewGCM(block)
|
gcm, err := cipher.NewGCM(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("new gcm (encrypt) failed: %v", err)
|
return nil, fmt.Errorf("new gcm (encrypt) failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
plaintext := []byte(secret)
|
plaintext := []byte(secret)
|
||||||
@ -342,17 +329,17 @@ func (a *Auth) maybeDecryptCredential(c Credential) (secretValue, error) {
|
|||||||
func (a *Auth) gcmDecrypt(aes128key [16]byte, gcmNonce [12]byte, derived []byte) (string, error) {
|
func (a *Auth) gcmDecrypt(aes128key [16]byte, gcmNonce [12]byte, derived []byte) (string, error) {
|
||||||
block, err := aes.NewCipher(aes128key[:])
|
block, err := aes.NewCipher(aes128key[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("new aes (decrypt) cipher failed: %v", err)
|
return "", fmt.Errorf("new aes (decrypt) cipher failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
gcm, err := cipher.NewGCM(block)
|
gcm, err := cipher.NewGCM(block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("new gcm failed: %v", err)
|
return "", fmt.Errorf("new gcm failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
plaintext, err := gcm.Open(nil, gcmNonce[:], derived, nil)
|
plaintext, err := gcm.Open(nil, gcmNonce[:], derived, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("gcm open (decryption) failed: %v", err)
|
return "", fmt.Errorf("gcm open (decryption) failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(plaintext), nil
|
return string(plaintext), nil
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user