mirror of
				https://github.com/therootcompany/golib.git
				synced 2025-10-31 05:02:52 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			428 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			428 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package csvauth
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/aes"
 | |
| 	"crypto/cipher"
 | |
| 	"crypto/pbkdf2"
 | |
| 	"crypto/rand"
 | |
| 	"crypto/sha1"
 | |
| 	"crypto/sha256"
 | |
| 	"encoding/csv"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"hash"
 | |
| 	"io"
 | |
| 	"iter"
 | |
| 	"maps"
 | |
| 	"os"
 | |
| 	"slices"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 
 | |
| 	"golang.org/x/crypto/bcrypt"
 | |
| )
 | |
| 
 | |
| var ErrNotFound = errors.New("not found")
 | |
| var ErrUnauthorized = errors.New("unauthorized")
 | |
| var ErrUnknownAlgorithm = errors.New("unknown algorithm")
 | |
| 
 | |
| const (
 | |
| 	defaultIters      = 1000 // original 2000 recommendation
 | |
| 	defaultSize       = 16   // 128-bit
 | |
| 	defaultHash       = "SHA-256"
 | |
| 	defaultBcryptCost = 12
 | |
| 	gcmNonceSize      = 12 // RFC spec
 | |
| )
 | |
| 
 | |
| // NamedReadCloser provides Name() for debugging of file-like ReadClosers, such as http responses
 | |
| type NamedReadCloser interface {
 | |
| 	io.ReadCloser
 | |
| 	Name() string
 | |
| }
 | |
| 
 | |
| type readNamer struct {
 | |
| 	io.ReadCloser
 | |
| 	name string
 | |
| }
 | |
| 
 | |
| // Name returns the name given to the wrapped ReadCloser to f8ulfill NamedReadCloser
 | |
| func (r *readNamer) Name() string {
 | |
| 	return r.name
 | |
| }
 | |
| 
 | |
| // NewNamedReadCloser wraps a ReadCloser with a name which can be referenced when debugging
 | |
| func NewNamedReadCloser(r io.ReadCloser, name string) NamedReadCloser {
 | |
| 	return &readNamer{
 | |
| 		ReadCloser: r,
 | |
| 		name:       name,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Auth holds user the encryption key and both login and service account credentials
 | |
| type Auth struct {
 | |
| 	aes128key       [16]byte
 | |
| 	credentials     map[Name]Credential
 | |
| 	serviceAccounts map[Purpose]Credential
 | |
| 	mux             sync.Mutex
 | |
| }
 | |
| 
 | |
| // New initializes an Auth with an encryption key
 | |
| func New(aes128key []byte) *Auth {
 | |
| 	var aes128Arr [16]byte
 | |
| 	copy(aes128Arr[:], aes128key)
 | |
| 
 | |
| 	return &Auth{
 | |
| 		aes128key:       aes128Arr,
 | |
| 		credentials:     map[Name]Credential{},
 | |
| 		serviceAccounts: map[Purpose]Credential{},
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Load reads a credentials CSV from the given NamedReadCloser (e.g. file, wrapped http request)
 | |
| func (a *Auth) LoadCSV(f NamedReadCloser, comma rune) error {
 | |
| 	csvr := csv.NewReader(f)
 | |
| 	csvr.Comma = comma
 | |
| 	csvr.Comment = '#'
 | |
| 	csvr.FieldsPerRecord = -1 // ignore short rows
 | |
| 	_, _ = csvr.Read()        // strip header row
 | |
| 	for {
 | |
| 		record, err := csvr.Read()
 | |
| 		if err == io.EOF {
 | |
| 			break
 | |
| 		}
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		if len(record) == 0 {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		if len(record) == 1 {
 | |
| 			if len(record[0]) == 0 {
 | |
| 				continue
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if len(record) < 5 {
 | |
| 			return fmt.Errorf("invalid %q format: %#v (%d)", f.Name(), record, len(record))
 | |
| 		}
 | |
| 
 | |
| 		credential, err := FromRecord(record)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		if len(credential.Purpose) == 0 || credential.Purpose == DefaultPurpose {
 | |
| 			if _, ok := a.credentials[credential.Name]; ok {
 | |
| 				fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
 | |
| 			}
 | |
| 			a.credentials[credential.Name] = credential
 | |
| 		} else {
 | |
| 			if _, ok := a.serviceAccounts[credential.Purpose]; ok {
 | |
| 				fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
 | |
| 			}
 | |
| 			a.serviceAccounts[credential.Purpose] = credential
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // NewCredential derives the hashed, encrypted, or raw value from the given secret and sets additional required and provided parameters
 | |
| func (a *Auth) NewCredential(purpose, name, secret string, params []string, roles []string, extra string) *Credential {
 | |
| 	c := &Credential{
 | |
| 		Purpose: purpose,
 | |
| 		Name:    name,
 | |
| 		//plain: secret,
 | |
| 		Params: params,
 | |
| 		//Salt: ...
 | |
| 		//Derived: ...
 | |
| 		Roles: roles,
 | |
| 		Extra: extra,
 | |
| 	}
 | |
| 
 | |
| 	switch c.Params[0] {
 | |
| 	case "plain":
 | |
| 		if len(params) != 1 {
 | |
| 			fmt.Fprintf(os.Stderr, "invalid plain algorithm format: %q\n", strings.Join(params, " "))
 | |
| 			os.Exit(1)
 | |
| 		}
 | |
| 		c.plain = secret
 | |
| 
 | |
| 		c.Params = []string{"plain"}
 | |
| 		h := sha256.Sum256([]byte(secret))
 | |
| 		c.Derived = h[:]
 | |
| 	case "aes-128-gcm":
 | |
| 		if len(params) != 1 {
 | |
| 			fmt.Fprintf(os.Stderr, "invalid aes-128-gcm algorithm format: %q\n", strings.Join(params, " "))
 | |
| 			os.Exit(1)
 | |
| 		}
 | |
| 
 | |
| 		c.Params = []string{"aes-128-gcm"}
 | |
| 		nonce := make([]byte, gcmNonceSize)
 | |
| 		if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
 | |
| 			panic(err)
 | |
| 		}
 | |
| 		c.Salt = nonce
 | |
| 
 | |
| 		var err error
 | |
| 		var salt [12]byte
 | |
| 		copy(salt[:], c.Salt)
 | |
| 		c.plain = secret
 | |
| 		c.Derived, err = gcmEncrypt(a.aes128key, salt, secret)
 | |
| 		if err != nil {
 | |
| 			fmt.Fprintf(os.Stderr, "could not aes-128-gcm encrypt secret: %v\n", err)
 | |
| 			os.Exit(1)
 | |
| 		}
 | |
| 	case "pbkdf2":
 | |
| 		if len(params) > 4 {
 | |
| 			fmt.Fprintf(os.Stderr, "invalid pbkdf2 algorithm format: %q\n", strings.Join(params, " "))
 | |
| 			os.Exit(1)
 | |
| 		}
 | |
| 		iters := defaultIters
 | |
| 		if len(params) > 1 {
 | |
| 			var err error
 | |
| 			iters, err = strconv.Atoi(params[1])
 | |
| 			if err != nil || iters <= 0 {
 | |
| 				fmt.Fprintf(os.Stderr, "invalid iterations %q in %q\n", params[1], strings.Join(params, " "))
 | |
| 				os.Exit(1)
 | |
| 			}
 | |
| 		}
 | |
| 		size := defaultSize
 | |
| 		if len(params) > 2 {
 | |
| 			var err error
 | |
| 			size, err = strconv.Atoi(params[2])
 | |
| 			if err != nil || size < 8 || size > 32 {
 | |
| 				fmt.Fprintf(os.Stderr, "invalid size %q in %q\n", params[2], strings.Join(params, " "))
 | |
| 				os.Exit(1)
 | |
| 			}
 | |
| 		}
 | |
| 		hashName := defaultHash
 | |
| 		if len(params) > 3 {
 | |
| 			if !slices.Contains([]string{"SHA-256", "SHA-1"}, params[3]) {
 | |
| 				fmt.Fprintf(os.Stderr, "invalid hash %q in %q\n", params[3], strings.Join(params, " "))
 | |
| 				os.Exit(1)
 | |
| 			}
 | |
| 			hashName = params[3]
 | |
| 		}
 | |
| 		c.Params = []string{"pbkdf2", strconv.Itoa(iters), strconv.Itoa(size), hashName}
 | |
| 		saltBytes := make([]byte, 16)
 | |
| 		if _, err := io.ReadFull(rand.Reader, saltBytes); err != nil {
 | |
| 			panic(err)
 | |
| 		}
 | |
| 		c.Salt = saltBytes
 | |
| 		var hasher func() hash.Hash
 | |
| 		hashNameUpper := strings.ToUpper(hashName)
 | |
| 		switch hashNameUpper {
 | |
| 		case "SHA-1", "SHA1":
 | |
| 			hashName = "SHA-1"
 | |
| 			hasher = sha1.New
 | |
| 		case "SHA-256", "SHA256":
 | |
| 			hashName = "SHA-256"
 | |
| 			hasher = sha256.New
 | |
| 		default:
 | |
| 			fmt.Fprintf(os.Stderr, "invalid hash %q (expected SHA-1 or SHA-256)\n", hashName)
 | |
| 			os.Exit(1)
 | |
| 		}
 | |
| 		var err error
 | |
| 		c.Derived, err = pbkdf2.Key(hasher, secret, saltBytes, iters, size)
 | |
| 		if err != nil {
 | |
| 			fmt.Fprintf(os.Stderr, "invalid pbkdf2 parameters: %v\n", err)
 | |
| 			os.Exit(1)
 | |
| 		}
 | |
| 	case "bcrypt":
 | |
| 		if len(params) > 2 {
 | |
| 			fmt.Fprintf(os.Stderr, "invalid bcrypt algorithm format: %q\n", strings.Join(params, " "))
 | |
| 			os.Exit(1)
 | |
| 		}
 | |
| 		cost := defaultBcryptCost
 | |
| 		if len(params) > 1 {
 | |
| 			var err error
 | |
| 			cost, err = strconv.Atoi(params[1])
 | |
| 			if err != nil || cost < 4 || cost > 31 {
 | |
| 				fmt.Fprintf(os.Stderr, "invalid bcrypt cost %q in %q\n", params[1], strings.Join(params, " "))
 | |
| 				os.Exit(1)
 | |
| 			}
 | |
| 		}
 | |
| 		c.Params = []string{"bcrypt"} // cost is included in the digest
 | |
| 		derived, err := bcrypt.GenerateFromPassword([]byte(secret), cost)
 | |
| 		if err != nil {
 | |
| 			fmt.Fprintf(os.Stderr, "Error generating bcrypt hash: %v\n", err)
 | |
| 			os.Exit(1)
 | |
| 		}
 | |
| 		c.Derived = derived
 | |
| 	default:
 | |
| 		fmt.Fprintf(os.Stderr, "invalid algorithm %q\n", params[0])
 | |
| 		os.Exit(1)
 | |
| 	}
 | |
| 
 | |
| 	return c
 | |
| }
 | |
| 
 | |
| func gcmEncrypt(aes128key [16]byte, gcmNonce [12]byte, secret string) ([]byte, error) {
 | |
| 	block, err := aes.NewCipher(aes128key[:])
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("new aes (encrypt) cipher failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// nonceSize := len(gcmNonce) // should always be 12
 | |
| 	gcm, err := cipher.NewGCM(block)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("new gcm (encrypt) failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	plaintext := []byte(secret)
 | |
| 	ciphertext := gcm.Seal(nil, gcmNonce[:], plaintext, nil)
 | |
| 	return ciphertext, nil
 | |
| }
 | |
| 
 | |
| // CredentialKeys returns the names that serve as IDs for each of the login credentials
 | |
| func (a *Auth) CredentialKeys() iter.Seq[Name] {
 | |
| 	a.mux.Lock()
 | |
| 	defer a.mux.Unlock()
 | |
| 	return maps.Keys(a.credentials)
 | |
| }
 | |
| 
 | |
| func (a *Auth) LoadCredential(name Name) (Credential, error) {
 | |
| 	a.mux.Lock()
 | |
| 	c, ok := a.credentials[name]
 | |
| 	a.mux.Unlock()
 | |
| 	if !ok {
 | |
| 		return c, ErrNotFound
 | |
| 	}
 | |
| 
 | |
| 	var err error
 | |
| 	if c.plain, err = a.maybeDecryptCredential(c); err != nil {
 | |
| 		return c, err
 | |
| 	}
 | |
| 
 | |
| 	return c, nil
 | |
| }
 | |
| 
 | |
| func (a *Auth) CacheCredential(c Credential) error {
 | |
| 	a.mux.Lock()
 | |
| 	a.credentials[c.Name] = c
 | |
| 	a.mux.Unlock()
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // CredentialKeys returns the names that serve as IDs for each of the login credentials
 | |
| func (a *Auth) ServiceAccountKeys() iter.Seq[Purpose] {
 | |
| 	a.mux.Lock()
 | |
| 	defer a.mux.Unlock()
 | |
| 	return maps.Keys(a.serviceAccounts)
 | |
| }
 | |
| 
 | |
| func (a *Auth) LoadServiceAccount(purpose Purpose) (Credential, error) {
 | |
| 	a.mux.Lock()
 | |
| 	c, ok := a.serviceAccounts[purpose]
 | |
| 	a.mux.Unlock()
 | |
| 	if !ok {
 | |
| 		return c, ErrNotFound
 | |
| 	}
 | |
| 
 | |
| 	var err error
 | |
| 	if c.plain, err = a.maybeDecryptCredential(c); err != nil {
 | |
| 		return c, err
 | |
| 	}
 | |
| 
 | |
| 	return c, nil
 | |
| }
 | |
| 
 | |
| func (a *Auth) maybeDecryptCredential(c Credential) (string, error) {
 | |
| 	switch c.Params[0] {
 | |
| 	case "aes-128-gcm":
 | |
| 		var salt [12]byte
 | |
| 		copy(salt[:], c.Salt)
 | |
| 		return a.gcmDecrypt(a.aes128key, salt, c.Derived)
 | |
| 	default:
 | |
| 		break
 | |
| 	}
 | |
| 
 | |
| 	return c.plain, nil
 | |
| }
 | |
| 
 | |
| func (a *Auth) gcmDecrypt(aes128key [16]byte, gcmNonce [12]byte, derived []byte) (string, error) {
 | |
| 	block, err := aes.NewCipher(aes128key[:])
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("new aes (decrypt) cipher failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	gcm, err := cipher.NewGCM(block)
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("new gcm failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	plaintext, err := gcm.Open(nil, gcmNonce[:], derived, nil)
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("gcm open (decryption) failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return string(plaintext), nil
 | |
| }
 | |
| 
 | |
| func (a *Auth) CacheServiceAccount(c Credential) error {
 | |
| 	a.mux.Lock()
 | |
| 	defer a.mux.Unlock()
 | |
| 	a.serviceAccounts[c.Purpose] = c
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Verify checks Basic Auth credentials
 | |
| func (a *Auth) Verify(name, secret string) error {
 | |
| 	a.mux.Lock()
 | |
| 	defer a.mux.Unlock()
 | |
| 	c, ok := a.credentials[name]
 | |
| 	if !ok {
 | |
| 		return ErrNotFound
 | |
| 	}
 | |
| 	return c.Verify(name, secret)
 | |
| }
 | |
| 
 | |
| // Verify checks Basic Auth credentials
 | |
| func (c Credential) Verify(name, secret string) error {
 | |
| 	known := c.Derived
 | |
| 	var derived []byte
 | |
| 	switch c.Params[0] {
 | |
| 	case "aes-128-gcm":
 | |
| 		knownHash := sha256.Sum256([]byte(c.plain))
 | |
| 		known = knownHash[:]
 | |
| 
 | |
| 		h := sha256.Sum256([]byte(secret))
 | |
| 		derived = h[:]
 | |
| 	case "plain":
 | |
| 		h := sha256.Sum256([]byte(secret))
 | |
| 		derived = h[:]
 | |
| 	case "pbkdf2":
 | |
| 		// these are checked on load
 | |
| 		iters, _ := strconv.Atoi(c.Params[1])
 | |
| 		size, _ := strconv.Atoi(c.Params[2])
 | |
| 		var hasher func() hash.Hash
 | |
| 		switch c.Params[3] {
 | |
| 		case "SHA-1":
 | |
| 			hasher = sha1.New
 | |
| 		case "SHA-256":
 | |
| 			hasher = sha256.New
 | |
| 		default:
 | |
| 			panic(fmt.Errorf("invalid hash %q", c.Params[3]))
 | |
| 		}
 | |
| 		derived, _ = pbkdf2.Key(hasher, secret, c.Salt, iters, size)
 | |
| 	case "bcrypt":
 | |
| 		err := bcrypt.CompareHashAndPassword(c.Derived, []byte(secret))
 | |
| 		if err == nil {
 | |
| 			return nil
 | |
| 		}
 | |
| 		return ErrUnauthorized
 | |
| 	default:
 | |
| 		return ErrUnknownAlgorithm
 | |
| 	}
 | |
| 
 | |
| 	if bytes.Equal(known, derived) {
 | |
| 		return nil
 | |
| 	}
 | |
| 	return ErrUnauthorized
 | |
| }
 |