golib/auth/csvauth/csvauth.go

428 lines
10 KiB
Go

package csvauth
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/pbkdf2"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
"encoding/csv"
"errors"
"fmt"
"hash"
"io"
"iter"
"maps"
"os"
"slices"
"strconv"
"strings"
"sync"
"golang.org/x/crypto/bcrypt"
)
// Sentinel errors returned by Auth and Credential methods; compare them with errors.Is.
var (
	// ErrNotFound reports that no credential is cached under the requested name or purpose.
	ErrNotFound = errors.New("not found")
	// ErrUnauthorized reports that a secret did not match the stored credential.
	ErrUnauthorized = errors.New("unauthorized")
	// ErrUnknownAlgorithm reports a credential whose algorithm is not supported.
	ErrUnknownAlgorithm = errors.New("unknown algorithm")
)

// Defaults and fixed sizes used when deriving new credentials.
const (
	// defaultIters is the PBKDF2 iteration count (the original recommendation was 2000).
	defaultIters = 1000
	// defaultSize is the PBKDF2 derived-key length in bytes (128-bit).
	defaultSize = 16
	// defaultHash is the PBKDF2 hash used when none is given.
	defaultHash = "SHA-256"
	// defaultBcryptCost is the bcrypt cost used when none is given.
	defaultBcryptCost = 12
	// gcmNonceSize is the AES-GCM nonce length in bytes (the standard 96-bit nonce).
	gcmNonceSize = 12
)
// NamedReadCloser is an io.ReadCloser that also reports a human-readable name,
// such as a file path or URL, for use in error messages and debugging output
// about file-like sources (for example HTTP response bodies).
type NamedReadCloser interface {
	Name() string
	io.ReadCloser
}

// readNamer attaches a fixed name to an arbitrary io.ReadCloser.
type readNamer struct {
	io.ReadCloser
	name string
}

// Name returns the label given to NewNamedReadCloser, satisfying NamedReadCloser.
func (rn *readNamer) Name() string { return rn.name }

// NewNamedReadCloser wraps r so that it also reports name via Name.
// Read and Close are delegated to r unchanged.
func NewNamedReadCloser(r io.ReadCloser, name string) NamedReadCloser {
	wrapped := readNamer{name: name, ReadCloser: r}
	return &wrapped
}
// Auth holds the AES-128 key used for aes-128-gcm credentials together with
// two credential caches: login credentials keyed by Name and service-account
// credentials keyed by Purpose. mux is the lock taken around map access.
type Auth struct {
	aes128key       [16]byte
	credentials     map[Name]Credential
	serviceAccounts map[Purpose]Credential
	mux             sync.Mutex
}

// New returns an Auth with empty credential caches that encrypts and decrypts
// aes-128-gcm secrets with aes128key.
//
// Only the first 16 bytes of aes128key are used and a shorter key is
// zero-padded to 16 bytes. NOTE(review): neither case is reported to the
// caller, so pass exactly 16 bytes.
func New(aes128key []byte) *Auth {
	a := &Auth{
		credentials:     make(map[Name]Credential),
		serviceAccounts: make(map[Purpose]Credential),
	}
	copy(a.aes128key[:], aes128key)
	return a
}
// LoadCSV reads credential records from f, a CSV stream whose fields are
// separated by comma, and caches them in a.
//
// Parsing rules:
//   - The first row is a header and is discarded.
//   - Lines starting with '#' are comments.
//   - Blank rows are skipped.
//   - Every other row must have at least 5 fields and is parsed with FromRecord.
//
// Rows whose Purpose is empty or DefaultPurpose become login credentials keyed
// by Name. All others become service accounts keyed by Purpose. A later row
// replaces an earlier one with the same key, and a warning is printed to stderr.
//
// Loading stops at the first malformed row, and rows before it stay cached.
// f is not closed.
func (a *Auth) LoadCSV(f NamedReadCloser, comma rune) error {
	csvr := csv.NewReader(f)
	csvr.Comma = comma
	csvr.Comment = '#'
	csvr.FieldsPerRecord = -1 // allow rows of any length; short rows are rejected below

	// Discard the header row. An empty input needs no special handling here:
	// the first Read in the loop below reports io.EOF as well.
	_, _ = csvr.Read()

	for {
		record, err := csvr.Read()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return fmt.Errorf("reading %q: %w", f.Name(), err)
		}
		// Skip blank rows (no fields, or a single empty field).
		if len(record) == 0 || (len(record) == 1 && len(record[0]) == 0) {
			continue
		}
		if len(record) < 5 {
			return fmt.Errorf("invalid %q format: %#v (%d)", f.Name(), record, len(record))
		}
		credential, err := FromRecord(record)
		if err != nil {
			return fmt.Errorf("parsing record in %q: %w", f.Name(), err)
		}

		// The maps are shared with the other accessors, which all take mux;
		// writing them unlocked here would race with concurrent readers.
		a.mux.Lock()
		if len(credential.Purpose) == 0 || credential.Purpose == DefaultPurpose {
			if _, ok := a.credentials[credential.Name]; ok {
				fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
			}
			a.credentials[credential.Name] = credential
		} else {
			if _, ok := a.serviceAccounts[credential.Purpose]; ok {
				fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
			}
			a.serviceAccounts[credential.Purpose] = credential
		}
		a.mux.Unlock()
	}
	return nil
}
// NewCredential derives the stored form of secret for the algorithm named in
// params[0]. It returns a Credential carrying that form along with purpose,
// name, roles and extra. Params is replaced by a normalized copy that records
// every parameter needed to verify the secret later.
//
// Supported params:
//
//	plain                        Derived = SHA-256(secret); the secret is also kept in memory
//	aes-128-gcm                  Derived = AES-GCM(secret) under the Auth key; Salt = random 12-byte nonce
//	pbkdf2 [iters [size [hash]]] defaults 1000, 16, SHA-256; size 8..32 bytes; hash SHA-256 or SHA-1; Salt = random 16 bytes
//	bcrypt [cost]                default 12, range 4..31; the cost is embedded in Derived
//
// Invalid params are fatal: a message is written to stderr and the process
// exits with status 1. A failure to read from crypto/rand panics.
func (a *Auth) NewCredential(purpose, name, secret string, params []string, roles []string, extra string) *Credential {
	if len(params) == 0 {
		// Guard the params[0] dispatch below, which would otherwise panic.
		fmt.Fprintln(os.Stderr, "missing algorithm: params is empty")
		os.Exit(1)
	}
	c := &Credential{
		Purpose: purpose,
		Name:    name,
		Params:  params,
		Roles:   roles,
		Extra:   extra,
	}
	switch params[0] {
	case "plain":
		if len(params) != 1 {
			fmt.Fprintf(os.Stderr, "invalid plain algorithm format: %q\n", strings.Join(params, " "))
			os.Exit(1)
		}
		c.plain = secret
		c.Params = []string{"plain"}
		h := sha256.Sum256([]byte(secret))
		c.Derived = h[:]
	case "aes-128-gcm":
		if len(params) != 1 {
			fmt.Fprintf(os.Stderr, "invalid aes-128-gcm algorithm format: %q\n", strings.Join(params, " "))
			os.Exit(1)
		}
		c.Params = []string{"aes-128-gcm"}
		// A fresh random nonce per credential; GCM must never reuse a nonce under one key.
		var nonce [gcmNonceSize]byte
		if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
			panic(err)
		}
		c.Salt = nonce[:]
		c.plain = secret
		var err error
		c.Derived, err = gcmEncrypt(a.aes128key, nonce, secret)
		if err != nil {
			fmt.Fprintf(os.Stderr, "could not aes-128-gcm encrypt secret: %v\n", err)
			os.Exit(1)
		}
	case "pbkdf2":
		if len(params) > 4 {
			fmt.Fprintf(os.Stderr, "invalid pbkdf2 algorithm format: %q\n", strings.Join(params, " "))
			os.Exit(1)
		}
		iters := defaultIters
		if len(params) > 1 {
			var err error
			iters, err = strconv.Atoi(params[1])
			if err != nil || iters <= 0 {
				fmt.Fprintf(os.Stderr, "invalid iterations %q in %q\n", params[1], strings.Join(params, " "))
				os.Exit(1)
			}
		}
		size := defaultSize
		if len(params) > 2 {
			var err error
			size, err = strconv.Atoi(params[2])
			if err != nil || size < 8 || size > 32 {
				fmt.Fprintf(os.Stderr, "invalid size %q in %q\n", params[2], strings.Join(params, " "))
				os.Exit(1)
			}
		}
		hashName := defaultHash
		if len(params) > 3 {
			if !slices.Contains([]string{"SHA-256", "SHA-1"}, params[3]) {
				fmt.Fprintf(os.Stderr, "invalid hash %q in %q\n", params[3], strings.Join(params, " "))
				os.Exit(1)
			}
			hashName = params[3]
		}
		// hashName is now exactly "SHA-256" or "SHA-1".
		var hasher func() hash.Hash
		switch hashName {
		case "SHA-1":
			hasher = sha1.New
		default: // "SHA-256"
			hasher = sha256.New
		}
		c.Params = []string{"pbkdf2", strconv.Itoa(iters), strconv.Itoa(size), hashName}
		salt := make([]byte, 16)
		if _, err := io.ReadFull(rand.Reader, salt); err != nil {
			panic(err)
		}
		c.Salt = salt
		var err error
		c.Derived, err = pbkdf2.Key(hasher, secret, salt, iters, size)
		if err != nil {
			fmt.Fprintf(os.Stderr, "invalid pbkdf2 parameters: %v\n", err)
			os.Exit(1)
		}
	case "bcrypt":
		if len(params) > 2 {
			fmt.Fprintf(os.Stderr, "invalid bcrypt algorithm format: %q\n", strings.Join(params, " "))
			os.Exit(1)
		}
		cost := defaultBcryptCost
		if len(params) > 1 {
			var err error
			cost, err = strconv.Atoi(params[1])
			if err != nil || cost < 4 || cost > 31 {
				fmt.Fprintf(os.Stderr, "invalid bcrypt cost %q in %q\n", params[1], strings.Join(params, " "))
				os.Exit(1)
			}
		}
		c.Params = []string{"bcrypt"} // cost is included in the digest
		derived, err := bcrypt.GenerateFromPassword([]byte(secret), cost)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error generating bcrypt hash: %v\n", err)
			os.Exit(1)
		}
		c.Derived = derived
	default:
		fmt.Fprintf(os.Stderr, "invalid algorithm %q\n", params[0])
		os.Exit(1)
	}
	return c
}
// gcmEncrypt seals secret with AES-GCM under aes128key, using gcmNonce as the
// nonce and no additional authenticated data. The result is the ciphertext
// followed by the authentication tag, as returned by cipher.AEAD.Seal.
//
// A nonce must never be reused with the same key; NewCredential draws a fresh
// random nonce for each credential. Errors are wrapped with %w so callers can
// inspect the underlying cause.
func gcmEncrypt(aes128key [16]byte, gcmNonce [12]byte, secret string) ([]byte, error) {
	block, err := aes.NewCipher(aes128key[:])
	if err != nil {
		return nil, fmt.Errorf("new aes (encrypt) cipher failed: %w", err)
	}
	// cipher.NewGCM uses the standard 12-byte nonce, matching gcmNonce's size.
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("new gcm (encrypt) failed: %w", err)
	}
	return gcm.Seal(nil, gcmNonce[:], []byte(secret), nil), nil
}
// CredentialKeys returns the names that serve as IDs for each of the login
// credentials, in unspecified order.
//
// The names are copied while the lock is held, so ranging over the returned
// iterator never touches the live map. It does not reflect later changes.
func (a *Auth) CredentialKeys() iter.Seq[Name] {
	a.mux.Lock()
	names := slices.Collect(maps.Keys(a.credentials))
	a.mux.Unlock()
	return slices.Values(names)
}
// LoadCredential returns a copy of the login credential cached under name, or
// ErrNotFound if there is none.
//
// For aes-128-gcm credentials the secret is decrypted with the Auth's key and
// stored in the copy's plaintext field. Other algorithms keep their stored
// plaintext value. If decryption fails, that error is returned alongside the
// credential.
func (a *Auth) LoadCredential(name Name) (Credential, error) {
	a.mux.Lock()
	cred, found := a.credentials[name]
	a.mux.Unlock()

	if !found {
		return cred, ErrNotFound
	}
	plain, err := a.maybeDecryptCredential(cred)
	cred.plain = plain
	return cred, err
}
// CacheCredential stores c as the login credential for c.Name, replacing any
// credential previously cached under that name. It always returns nil.
func (a *Auth) CacheCredential(c Credential) error {
	a.mux.Lock()
	// Deferred like CacheServiceAccount, so the lock is released even if the
	// map write panics (e.g. on a zero-value Auth whose maps are nil).
	defer a.mux.Unlock()
	a.credentials[c.Name] = c
	return nil
}
// ServiceAccountKeys returns the purposes that serve as IDs for each of the
// service-account credentials, in unspecified order.
//
// The purposes are copied while the lock is held, so ranging over the returned
// iterator never touches the live map. It does not reflect later changes.
func (a *Auth) ServiceAccountKeys() iter.Seq[Purpose] {
	a.mux.Lock()
	purposes := slices.Collect(maps.Keys(a.serviceAccounts))
	a.mux.Unlock()
	return slices.Values(purposes)
}
// LoadServiceAccount returns a copy of the service-account credential cached
// for purpose, or ErrNotFound if there is none.
//
// For aes-128-gcm credentials the secret is decrypted with the Auth's key and
// stored in the copy's plaintext field. If decryption fails, that error is
// returned alongside the credential.
func (a *Auth) LoadServiceAccount(purpose Purpose) (Credential, error) {
	a.mux.Lock()
	account, exists := a.serviceAccounts[purpose]
	a.mux.Unlock()

	if exists {
		var decErr error
		account.plain, decErr = a.maybeDecryptCredential(account)
		return account, decErr
	}
	return account, ErrNotFound
}
// maybeDecryptCredential returns the plaintext secret for c.
//
// For aes-128-gcm credentials it decrypts c.Derived with the Auth's key, using
// c.Salt as the nonce. For every other algorithm, or when c.Params is empty
// (which previously panicked on the index), it returns c.plain unchanged.
func (a *Auth) maybeDecryptCredential(c Credential) (string, error) {
	if len(c.Params) == 0 || c.Params[0] != "aes-128-gcm" {
		return c.plain, nil
	}
	// A Salt shorter than 12 bytes is zero-padded; GCM authentication then
	// fails rather than yielding wrong plaintext.
	var nonce [12]byte
	copy(nonce[:], c.Salt)
	return a.gcmDecrypt(a.aes128key, nonce, c.Derived)
}
// gcmDecrypt opens derived (ciphertext followed by the GCM tag, as produced by
// gcmEncrypt) with AES-GCM under aes128key and gcmNonce, with no additional
// authenticated data.
//
// It returns an error if the key or nonce is wrong or the data was modified.
// Errors are wrapped with %w. The receiver is not used.
func (a *Auth) gcmDecrypt(aes128key [16]byte, gcmNonce [12]byte, derived []byte) (string, error) {
	block, err := aes.NewCipher(aes128key[:])
	if err != nil {
		return "", fmt.Errorf("new aes (decrypt) cipher failed: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", fmt.Errorf("new gcm (decrypt) failed: %w", err)
	}
	plaintext, err := gcm.Open(nil, gcmNonce[:], derived, nil)
	if err != nil {
		return "", fmt.Errorf("gcm open (decryption) failed: %w", err)
	}
	return string(plaintext), nil
}
// CacheServiceAccount stores c as the service-account credential for
// c.Purpose, replacing any credential previously cached for that purpose.
// It always returns nil.
func (a *Auth) CacheServiceAccount(c Credential) error {
	a.mux.Lock()
	defer a.mux.Unlock()

	key := c.Purpose
	a.serviceAccounts[key] = c
	return nil
}
// Verify checks a Basic Auth name/secret pair against the cached login
// credentials. It returns:
//   - ErrNotFound when name is unknown.
//   - ErrUnauthorized when the secret does not match.
//   - An error if an encrypted credential cannot be decrypted.
//
// The lock is held only for the map lookup, so slow hash checks (bcrypt,
// pbkdf2) for one caller do not block all other callers.
func (a *Auth) Verify(name, secret string) error {
	a.mux.Lock()
	c, ok := a.credentials[name]
	a.mux.Unlock()
	if !ok {
		return ErrNotFound
	}
	// Credential.Verify compares aes-128-gcm secrets against c.plain, which is
	// only available after decrypting with this Auth's key. Without this step
	// a credential whose plaintext was never populated (e.g. loaded from CSV)
	// would reject the real secret and accept an empty one.
	plain, err := a.maybeDecryptCredential(c)
	if err != nil {
		return err
	}
	c.plain = plain // c is a copy of the cached value; the cache is not modified
	return c.Verify(name, secret)
}
// Verify reports whether secret matches the credential. It returns nil on a
// match and ErrUnauthorized otherwise. The name argument is not consulted.
//
// The comparison depends on c.Params[0]:
//   - "plain": SHA-256(secret) is compared with c.Derived.
//   - "aes-128-gcm": SHA-256(secret) is compared with SHA-256(c.plain), so the
//     plaintext must already be populated (decrypted). An unpopulated plain
//     matches only the empty secret.
//   - "pbkdf2": the key is re-derived from secret with c.Salt and the
//     iterations, key size and hash in c.Params[1:4].
//   - "bcrypt": bcrypt.CompareHashAndPassword against c.Derived.
//
// A missing or unsupported algorithm (including an unsupported pbkdf2 hash)
// yields an error wrapping ErrUnknownAlgorithm. Malformed pbkdf2 parameters
// yield a descriptive error instead of a panic.
func (c Credential) Verify(name, secret string) error {
	if len(c.Params) == 0 {
		return ErrUnknownAlgorithm
	}
	known := c.Derived
	var derived []byte
	switch c.Params[0] {
	case "aes-128-gcm":
		knownHash := sha256.Sum256([]byte(c.plain))
		known = knownHash[:]
		h := sha256.Sum256([]byte(secret))
		derived = h[:]
	case "plain":
		h := sha256.Sum256([]byte(secret))
		derived = h[:]
	case "pbkdf2":
		// Cached credentials can come from anywhere (e.g. CacheCredential), so
		// validate here rather than index or parse blindly.
		if len(c.Params) < 4 {
			return fmt.Errorf("invalid pbkdf2 params %q", strings.Join(c.Params, " "))
		}
		iters, err := strconv.Atoi(c.Params[1])
		if err != nil || iters <= 0 {
			return fmt.Errorf("invalid pbkdf2 iterations %q", c.Params[1])
		}
		size, err := strconv.Atoi(c.Params[2])
		if err != nil || size <= 0 {
			return fmt.Errorf("invalid pbkdf2 size %q", c.Params[2])
		}
		var hasher func() hash.Hash
		switch c.Params[3] {
		case "SHA-1":
			hasher = sha1.New
		case "SHA-256":
			hasher = sha256.New
		default:
			return fmt.Errorf("%w: pbkdf2 hash %q", ErrUnknownAlgorithm, c.Params[3])
		}
		// Don't ignore the error: a nil derived key could otherwise compare
		// equal to an empty stored digest.
		derived, err = pbkdf2.Key(hasher, secret, c.Salt, iters, size)
		if err != nil {
			return fmt.Errorf("pbkdf2 key derivation: %w", err)
		}
	case "bcrypt":
		if err := bcrypt.CompareHashAndPassword(c.Derived, []byte(secret)); err != nil {
			return ErrUnauthorized
		}
		return nil
	default:
		return ErrUnknownAlgorithm
	}
	// NOTE(review): bytes.Equal is not constant-time. Both operands are digests
	// or derived keys, not raw secrets.
	if bytes.Equal(known, derived) {
		return nil
	}
	return ErrUnauthorized
}