mirror of
https://github.com/therootcompany/golib.git
synced 2025-10-07 01:28:19 +00:00
428 lines
10 KiB
Go
428 lines
10 KiB
Go
package csvauth
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/pbkdf2"
|
|
"crypto/rand"
|
|
"crypto/sha1"
|
|
"crypto/sha256"
|
|
"encoding/csv"
|
|
"errors"
|
|
"fmt"
|
|
"hash"
|
|
"io"
|
|
"iter"
|
|
"maps"
|
|
"os"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// Sentinel errors returned by Auth and Credential methods.
var (
	// ErrNotFound is returned when nothing is cached under the requested
	// name or purpose.
	ErrNotFound = errors.New("not found")

	// ErrUnauthorized is returned when a secret does not match the stored
	// credential.
	ErrUnauthorized = errors.New("unauthorized")

	// ErrUnknownAlgorithm is returned when a credential names an algorithm
	// that Verify does not support.
	ErrUnknownAlgorithm = errors.New("unknown algorithm")
)
|
|
|
|
// Defaults and fixed sizes used when deriving new credentials.
const (
	// defaultIters is the PBKDF2 iteration count used when none is given.
	// It matches the minimum recommended by RFC 2898 (2000), which is far
	// below current guidance.
	defaultIters = 1000

	// defaultSize is the PBKDF2 derived-key length in bytes (128 bits).
	defaultSize = 16

	// defaultHash is the PBKDF2 hash function used when none is given.
	defaultHash = "SHA-256"

	// defaultBcryptCost is the bcrypt work factor used when none is given.
	defaultBcryptCost = 12

	// gcmNonceSize is the AES-GCM nonce length in bytes: the standard
	// 96-bit nonce that cipher.NewGCM expects.
	gcmNonceSize = 12
)
|
|
|
|
// NamedReadCloser is an io.ReadCloser that can also report a name (a file
// path, URL, ...) for use in error messages and debugging. *os.File already
// satisfies it; other readers such as HTTP response bodies can be adapted
// with NewNamedReadCloser.
type NamedReadCloser interface {
	io.ReadCloser

	// Name identifies the underlying source.
	Name() string
}

// readNamer attaches a fixed label to an arbitrary io.ReadCloser.
type readNamer struct {
	// ReadCloser receives every Read and Close call unchanged.
	io.ReadCloser

	// name is reported verbatim by Name.
	name string
}

// Name returns the label given to NewNamedReadCloser, which lets readNamer
// satisfy NamedReadCloser.
func (rn *readNamer) Name() string { return rn.name }

// NewNamedReadCloser adapts r to NamedReadCloser: Read and Close are
// delegated to r, and Name reports name.
func NewNamedReadCloser(r io.ReadCloser, name string) NamedReadCloser {
	rn := new(readNamer)
	rn.ReadCloser = r
	rn.name = name
	return rn
}
|
|
|
|
// Auth holds user the encryption key and both login and service account credentials
|
|
type Auth struct {
|
|
aes128key [16]byte
|
|
credentials map[Name]Credential
|
|
serviceAccounts map[Purpose]Credential
|
|
mux sync.Mutex
|
|
}
|
|
|
|
// New initializes an Auth with an encryption key
|
|
func New(aes128key []byte) *Auth {
|
|
var aes128Arr [16]byte
|
|
copy(aes128Arr[:], aes128key)
|
|
|
|
return &Auth{
|
|
aes128key: aes128Arr,
|
|
credentials: map[Name]Credential{},
|
|
serviceAccounts: map[Purpose]Credential{},
|
|
}
|
|
}
|
|
|
|
// Load reads a credentials CSV from the given NamedReadCloser (e.g. file, wrapped http request)
|
|
func (a *Auth) LoadCSV(f NamedReadCloser, comma rune) error {
|
|
csvr := csv.NewReader(f)
|
|
csvr.Comma = comma
|
|
csvr.Comment = '#'
|
|
csvr.FieldsPerRecord = -1 // ignore short rows
|
|
_, _ = csvr.Read() // strip header row
|
|
for {
|
|
record, err := csvr.Read()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(record) == 0 {
|
|
continue
|
|
}
|
|
|
|
if len(record) == 1 {
|
|
if len(record[0]) == 0 {
|
|
continue
|
|
}
|
|
}
|
|
|
|
if len(record) < 5 {
|
|
return fmt.Errorf("invalid %q format: %#v (%d)", f.Name(), record, len(record))
|
|
}
|
|
|
|
credential, err := FromRecord(record)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(credential.Purpose) == 0 || credential.Purpose == DefaultPurpose {
|
|
if _, ok := a.credentials[credential.Name]; ok {
|
|
fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
|
}
|
|
a.credentials[credential.Name] = credential
|
|
} else {
|
|
if _, ok := a.serviceAccounts[credential.Purpose]; ok {
|
|
fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
|
}
|
|
a.serviceAccounts[credential.Purpose] = credential
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// NewCredential derives the hashed, encrypted, or raw value from the given secret and sets additional required and provided parameters
|
|
func (a *Auth) NewCredential(purpose, name, secret string, params []string, roles []string, extra string) *Credential {
|
|
c := &Credential{
|
|
Purpose: purpose,
|
|
Name: name,
|
|
//plain: secret,
|
|
Params: params,
|
|
//Salt: ...
|
|
//Derived: ...
|
|
Roles: roles,
|
|
Extra: extra,
|
|
}
|
|
|
|
switch c.Params[0] {
|
|
case "plain":
|
|
if len(params) != 1 {
|
|
fmt.Fprintf(os.Stderr, "invalid plain algorithm format: %q\n", strings.Join(params, " "))
|
|
os.Exit(1)
|
|
}
|
|
c.plain = secret
|
|
|
|
c.Params = []string{"plain"}
|
|
h := sha256.Sum256([]byte(secret))
|
|
c.Derived = h[:]
|
|
case "aes-128-gcm":
|
|
if len(params) != 1 {
|
|
fmt.Fprintf(os.Stderr, "invalid aes-128-gcm algorithm format: %q\n", strings.Join(params, " "))
|
|
os.Exit(1)
|
|
}
|
|
|
|
c.Params = []string{"aes-128-gcm"}
|
|
nonce := make([]byte, gcmNonceSize)
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
panic(err)
|
|
}
|
|
c.Salt = nonce
|
|
|
|
var err error
|
|
var salt [12]byte
|
|
copy(salt[:], c.Salt)
|
|
c.plain = secret
|
|
c.Derived, err = gcmEncrypt(a.aes128key, salt, secret)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "could not aes-128-gcm encrypt secret: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
case "pbkdf2":
|
|
if len(params) > 4 {
|
|
fmt.Fprintf(os.Stderr, "invalid pbkdf2 algorithm format: %q\n", strings.Join(params, " "))
|
|
os.Exit(1)
|
|
}
|
|
iters := defaultIters
|
|
if len(params) > 1 {
|
|
var err error
|
|
iters, err = strconv.Atoi(params[1])
|
|
if err != nil || iters <= 0 {
|
|
fmt.Fprintf(os.Stderr, "invalid iterations %q in %q\n", params[1], strings.Join(params, " "))
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
size := defaultSize
|
|
if len(params) > 2 {
|
|
var err error
|
|
size, err = strconv.Atoi(params[2])
|
|
if err != nil || size < 8 || size > 32 {
|
|
fmt.Fprintf(os.Stderr, "invalid size %q in %q\n", params[2], strings.Join(params, " "))
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
hashName := defaultHash
|
|
if len(params) > 3 {
|
|
if !slices.Contains([]string{"SHA-256", "SHA-1"}, params[3]) {
|
|
fmt.Fprintf(os.Stderr, "invalid hash %q in %q\n", params[3], strings.Join(params, " "))
|
|
os.Exit(1)
|
|
}
|
|
hashName = params[3]
|
|
}
|
|
c.Params = []string{"pbkdf2", strconv.Itoa(iters), strconv.Itoa(size), hashName}
|
|
saltBytes := make([]byte, 16)
|
|
if _, err := io.ReadFull(rand.Reader, saltBytes); err != nil {
|
|
panic(err)
|
|
}
|
|
c.Salt = saltBytes
|
|
var hasher func() hash.Hash
|
|
hashNameUpper := strings.ToUpper(hashName)
|
|
switch hashNameUpper {
|
|
case "SHA-1", "SHA1":
|
|
hashName = "SHA-1"
|
|
hasher = sha1.New
|
|
case "SHA-256", "SHA256":
|
|
hashName = "SHA-256"
|
|
hasher = sha256.New
|
|
default:
|
|
fmt.Fprintf(os.Stderr, "invalid hash %q (expected SHA-1 or SHA-256)\n", hashName)
|
|
os.Exit(1)
|
|
}
|
|
var err error
|
|
c.Derived, err = pbkdf2.Key(hasher, secret, saltBytes, iters, size)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "invalid pbkdf2 parameters: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
case "bcrypt":
|
|
if len(params) > 2 {
|
|
fmt.Fprintf(os.Stderr, "invalid bcrypt algorithm format: %q\n", strings.Join(params, " "))
|
|
os.Exit(1)
|
|
}
|
|
cost := defaultBcryptCost
|
|
if len(params) > 1 {
|
|
var err error
|
|
cost, err = strconv.Atoi(params[1])
|
|
if err != nil || cost < 4 || cost > 31 {
|
|
fmt.Fprintf(os.Stderr, "invalid bcrypt cost %q in %q\n", params[1], strings.Join(params, " "))
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
c.Params = []string{"bcrypt"} // cost is included in the digest
|
|
derived, err := bcrypt.GenerateFromPassword([]byte(secret), cost)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error generating bcrypt hash: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
c.Derived = derived
|
|
default:
|
|
fmt.Fprintf(os.Stderr, "invalid algorithm %q\n", params[0])
|
|
os.Exit(1)
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
// gcmEncrypt seals secret with AES-128-GCM under aes128key and gcmNonce,
// with no additional authenticated data. It returns the ciphertext followed
// by the 16-byte authentication tag, as produced by cipher.AEAD.Seal.
//
// GCM must never reuse a nonce with the same key, so callers should pass a
// fresh random nonce for every secret. Setup errors are wrapped with %w.
func gcmEncrypt(aes128key [16]byte, gcmNonce [12]byte, secret string) ([]byte, error) {
	block, err := aes.NewCipher(aes128key[:])
	if err != nil {
		return nil, fmt.Errorf("new aes (encrypt) cipher failed: %w", err)
	}

	// cipher.NewGCM uses the standard 12-byte nonce, matching gcmNonce.
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("new gcm (encrypt) failed: %w", err)
	}

	return gcm.Seal(nil, gcmNonce[:], []byte(secret), nil), nil
}
|
|
|
|
// CredentialKeys returns the names that serve as IDs for each of the login credentials
|
|
func (a *Auth) CredentialKeys() iter.Seq[Name] {
|
|
a.mux.Lock()
|
|
defer a.mux.Unlock()
|
|
return maps.Keys(a.credentials)
|
|
}
|
|
|
|
func (a *Auth) LoadCredential(name Name) (Credential, error) {
|
|
a.mux.Lock()
|
|
c, ok := a.credentials[name]
|
|
a.mux.Unlock()
|
|
if !ok {
|
|
return c, ErrNotFound
|
|
}
|
|
|
|
var err error
|
|
if c.plain, err = a.maybeDecryptCredential(c); err != nil {
|
|
return c, err
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func (a *Auth) CacheCredential(c Credential) error {
|
|
a.mux.Lock()
|
|
a.credentials[c.Name] = c
|
|
a.mux.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// CredentialKeys returns the names that serve as IDs for each of the login credentials
|
|
func (a *Auth) ServiceAccountKeys() iter.Seq[Purpose] {
|
|
a.mux.Lock()
|
|
defer a.mux.Unlock()
|
|
return maps.Keys(a.serviceAccounts)
|
|
}
|
|
|
|
func (a *Auth) LoadServiceAccount(purpose Purpose) (Credential, error) {
|
|
a.mux.Lock()
|
|
c, ok := a.serviceAccounts[purpose]
|
|
a.mux.Unlock()
|
|
if !ok {
|
|
return c, ErrNotFound
|
|
}
|
|
|
|
var err error
|
|
if c.plain, err = a.maybeDecryptCredential(c); err != nil {
|
|
return c, err
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func (a *Auth) maybeDecryptCredential(c Credential) (string, error) {
|
|
switch c.Params[0] {
|
|
case "aes-128-gcm":
|
|
var salt [12]byte
|
|
copy(salt[:], c.Salt)
|
|
return a.gcmDecrypt(a.aes128key, salt, c.Derived)
|
|
default:
|
|
break
|
|
}
|
|
|
|
return c.plain, nil
|
|
}
|
|
|
|
func (a *Auth) gcmDecrypt(aes128key [16]byte, gcmNonce [12]byte, derived []byte) (string, error) {
|
|
block, err := aes.NewCipher(aes128key[:])
|
|
if err != nil {
|
|
return "", fmt.Errorf("new aes (decrypt) cipher failed: %v", err)
|
|
}
|
|
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return "", fmt.Errorf("new gcm failed: %v", err)
|
|
}
|
|
|
|
plaintext, err := gcm.Open(nil, gcmNonce[:], derived, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("gcm open (decryption) failed: %v", err)
|
|
}
|
|
|
|
return string(plaintext), nil
|
|
}
|
|
|
|
func (a *Auth) CacheServiceAccount(c Credential) error {
|
|
a.mux.Lock()
|
|
defer a.mux.Unlock()
|
|
a.serviceAccounts[c.Purpose] = c
|
|
return nil
|
|
}
|
|
|
|
// Verify checks Basic Auth credentials
|
|
func (a *Auth) Verify(name, secret string) error {
|
|
a.mux.Lock()
|
|
defer a.mux.Unlock()
|
|
c, ok := a.credentials[name]
|
|
if !ok {
|
|
return ErrNotFound
|
|
}
|
|
return c.Verify(name, secret)
|
|
}
|
|
|
|
// Verify checks Basic Auth credentials
|
|
func (c Credential) Verify(name, secret string) error {
|
|
known := c.Derived
|
|
var derived []byte
|
|
switch c.Params[0] {
|
|
case "aes-128-gcm":
|
|
knownHash := sha256.Sum256([]byte(c.plain))
|
|
known = knownHash[:]
|
|
|
|
h := sha256.Sum256([]byte(secret))
|
|
derived = h[:]
|
|
case "plain":
|
|
h := sha256.Sum256([]byte(secret))
|
|
derived = h[:]
|
|
case "pbkdf2":
|
|
// these are checked on load
|
|
iters, _ := strconv.Atoi(c.Params[1])
|
|
size, _ := strconv.Atoi(c.Params[2])
|
|
var hasher func() hash.Hash
|
|
switch c.Params[3] {
|
|
case "SHA-1":
|
|
hasher = sha1.New
|
|
case "SHA-256":
|
|
hasher = sha256.New
|
|
default:
|
|
panic(fmt.Errorf("invalid hash %q", c.Params[3]))
|
|
}
|
|
derived, _ = pbkdf2.Key(hasher, secret, c.Salt, iters, size)
|
|
case "bcrypt":
|
|
err := bcrypt.CompareHashAndPassword(c.Derived, []byte(secret))
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
return ErrUnauthorized
|
|
default:
|
|
return ErrUnknownAlgorithm
|
|
}
|
|
|
|
if bytes.Equal(known, derived) {
|
|
return nil
|
|
}
|
|
return ErrUnauthorized
|
|
}
|