golib/auth/csvauth/credential.go
AJ ONeal 02fef67e53
fix(auth/csvauth): ID() returns Name only, not Name~hashID for tokens
Principal identity is the subject (who), not the credential instance
(which token). The hashID suffix was an internal cache fingerprint that
leaked into the public ID. Callers that need to distinguish individual
token instances must use a separate mechanism.

TSV serialization in ToRecord() still writes Name~hashID when hashID is
set so the credential file round-trips correctly.
2026-04-13 22:57:21 -06:00

233 lines
5.6 KiB
Go

package csvauth
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"slices"
"strconv"
"strings"
"github.com/therootcompany/golib/auth"
)
// ErrDecodeFields is the sentinel error for credential fields that cannot be
// decoded. Many of the errors returned by FromFields wrap it, so callers can
// match them with errors.Is.
var ErrDecodeFields = errors.New("could not decode credential")
// BasicAuthVerifier is implemented by types that can check a pair of strings
// and report the outcome as an error.
//
// NOTE(review): the two arguments are presumably a username and a password,
// with a nil error meaning the pair was accepted — confirm against the
// implementations.
type BasicAuthVerifier interface {
Verify(string, string) error
}
const (
// DefaultPurpose is the former name of PurposeDefault and has the same value.
//
// Deprecated: use PurposeDefault instead.
DefaultPurpose = "login"
// PurposeDefault is the purpose substituted by FromFields and ToRecord when
// a credential's purpose is empty.
PurposeDefault = "login"
// PurposeToken marks token credentials. FromFields splits a trailing hashID
// off the name of credentials with this purpose.
PurposeToken = "token"
// hashIDSep separates the name from the hashID in the serialized name column.
hashIDSep = "~"
)
// Purpose is an alias for string holding a credential's purpose, such as
// PurposeDefault or PurposeToken.
type Purpose = string
// Name is an alias for string holding a credential's name.
type Name = string
// secretValue wraps a plaintext secret so that formatting it never reveals
// the underlying string. The real value is only reachable through an
// explicit string(...) conversion.
type secretValue string

// String implements fmt.Stringer and always yields the redaction marker.
func (secretValue) String() string {
	return "[redacted]"
}

// GoString implements fmt.GoStringer so that %#v prints a quoted redaction
// marker instead of the secret.
func (secretValue) GoString() string {
	return `"[redacted]"`
}

// MarshalText returns the same redaction marker as String.
//
// NOTE(review): encoding.TextMarshaler requires the signature
// MarshalText() ([]byte, error); with a plain string result the encoding
// packages will not call this method.
func (sv secretValue) MarshalText() string {
	return sv.String()
}
// Credential represents a row in the CSV file.
type Credential struct {
// Purpose is the credential's purpose; FromFields stores PurposeDefault
// when the column is empty.
Purpose Purpose
// Name is the value returned by ID. For token credentials FromFields
// stores the name without its hashID suffix.
Name Name
// plain holds the plaintext secret; FromFields sets it for the "plain"
// algorithm. It prints as "[redacted]".
plain secretValue
// Params holds the algorithm name at index 0 followed by that
// algorithm's parameters.
Params []string
// Salt is the decoded salt column (set by FromFields for aes-128-gcm
// and pbkdf2).
Salt []byte
// Derived depends on the algorithm: the decoded derived column for
// aes-128-gcm and pbkdf2, the SHA-256 digest of the plaintext for plain,
// and the raw column bytes for bcrypt.
Derived []byte
// Roles is the list returned by Permissions.
Roles []string
// Extra carries the optional seventh column unchanged.
Extra string
// hashID is the suffix after hashIDSep in a token credential's
// serialized name; ToRecord writes it back when it is non-empty.
hashID string
}
// ID returns the credential's Name only. The hashID of a token credential is
// not part of the returned value.
func (c *Credential) ID() string {
return c.Name
}
// Permissions returns the credential's Roles. The returned slice is the
// Roles field itself, not a copy.
func (c *Credential) Permissions() []string {
return c.Roles
}
// Secret returns the unredacted plaintext secret held in the plain field.
// Within this file that field is populated only by FromFields for the
// "plain" algorithm, so the result is empty for other credentials parsed here.
func (c Credential) Secret() string {
return string(c.plain)
}
// FromRecord builds a Credential from one row of the credentials file.
//
// The record must contain at least five columns, in this order: purpose,
// name, parameter list, base64 salt, derived value. An optional sixth column
// holds the role list and an optional seventh holds the extra data; columns
// beyond the seventh are ignored. The columns are handed to FromFields
// unchanged.
//
// A record with fewer than five columns yields an error wrapping
// ErrDecodeFields instead of panicking on an out-of-range index. All other
// errors come from FromFields.
func FromRecord(record []string) (Credential, error) {
	const minFields = 5
	if len(record) < minFields {
		return Credential{}, fmt.Errorf("%w: expected at least %d fields, got %d", ErrDecodeFields, minFields, len(record))
	}

	var roleList, extra string
	purpose, name, paramList, salt64, derived := record[0], record[1], record[2], record[3], record[4]
	if len(record) >= 6 {
		roleList = record[5]
	}
	if len(record) >= 7 {
		extra = record[6]
	}

	return FromFields(purpose, name, paramList, salt64, derived, roleList, extra)
}
// FromFields builds a Credential from the still-encoded columns of a
// credential row.
//
// Parameters:
//   - purpose: the credential purpose; empty means PurposeDefault.
//   - name: the credential name. For PurposeToken credentials everything
//     after the last hashIDSep is split off and kept as the hashID.
//   - paramList: the algorithm name followed by its parameters, separated by
//     spaces and/or commas: "plain", "bcrypt", "aes-128-gcm", or
//     "pbkdf2 <iterations> <size> <hash>".
//   - saltBase64: unpadded URL-safe base64 salt (aes-128-gcm and pbkdf2).
//   - derived: the stored secret material. It is unpadded URL-safe base64
//     for aes-128-gcm and pbkdf2, the plaintext for plain, and is taken
//     as-is for bcrypt.
//   - roleList: roles separated by spaces and/or commas; may be empty.
//   - extra: stored verbatim in Credential.Extra.
//
// Every error returned wraps ErrDecodeFields; base64 and strconv failures
// are wrapped as well so they remain reachable through errors.As. When an
// error is returned the Credential may be partially populated and must not
// be used.
func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra string) (Credential, error) {
	// Lists may use commas, spaces, or a mix of both. Splitting with
	// FieldsFunc drops the empty items that repeated or trailing separators
	// would otherwise produce (e.g. "admin, user").
	isListSep := func(r rune) bool { return r == ',' || r == ' ' }

	var credential Credential

	if len(purpose) == 0 {
		purpose = PurposeDefault
	}
	credential.Purpose = purpose

	credential.Name = name
	if credential.Purpose == PurposeToken {
		// Token names are serialized as "Name~hashID". The hashID can only
		// be verified if the algorithm is plain or aes.
		subject, hashID, _ := splitLast(name, hashIDSep)
		credential.Name = subject
		credential.hashID = hashID
	}

	if len(roleList) > 0 {
		credential.Roles = strings.FieldsFunc(roleList, isListSep)
	}
	credential.Extra = extra

	credential.Params = strings.FieldsFunc(paramList, isListSep)
	if len(credential.Params) == 0 {
		return Credential{}, fmt.Errorf("%w: no algorithm parameters for %q", ErrDecodeFields, name)
	}

	switch credential.Params[0] {
	case "aes-128-gcm":
		if len(credential.Params) > 1 {
			return credential, fmt.Errorf("%w: invalid aes-128-gcm parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`))
		}

		salt, err := base64.RawURLEncoding.DecodeString(saltBase64)
		if err != nil {
			return credential, fmt.Errorf("%w: bad salt for %q: %w", ErrDecodeFields, name, err)
		}
		credential.Salt = salt

		bytes, err := base64.RawURLEncoding.DecodeString(derived)
		if err != nil {
			return credential, fmt.Errorf("%w: bad derived data for %q: %w", ErrDecodeFields, name, err)
		}
		credential.Derived = bytes

	case "plain":
		if len(credential.Params) > 1 {
			return credential, fmt.Errorf("%w: invalid plain parameters for %q: %q", ErrDecodeFields, name, credential.Params)
		}

		credential.plain = secretValue(derived)
		// Derived holds a digest of the plaintext, not the plaintext itself.
		h := sha256.Sum256([]byte(derived))
		credential.Derived = h[:]

	case "pbkdf2":
		var err error
		credential.Salt, err = base64.RawURLEncoding.DecodeString(saltBase64)
		if err != nil {
			return credential, fmt.Errorf("%w: bad salt for %q: %q", ErrDecodeFields, name, saltBase64)
		}
		credential.Derived, err = base64.RawURLEncoding.DecodeString(derived)
		if err != nil {
			return credential, fmt.Errorf("%w: bad derived data for %q: %q", ErrDecodeFields, name, derived)
		}

		// Expected layout: pbkdf2 <iterations> <size> <hash>. Check the
		// length before indexing so a truncated row is an error, not a panic.
		if len(credential.Params) < 4 {
			return credential, fmt.Errorf("%w: invalid pbkdf2 parameters for %q: %q", ErrDecodeFields, name, credential.Params)
		}

		iters, err := strconv.Atoi(credential.Params[1])
		if err != nil {
			return credential, fmt.Errorf("%w: invalid iterations for %q: %w", ErrDecodeFields, name, err)
		}
		if iters <= 0 {
			return credential, fmt.Errorf("%w: invalid iterations for %q: %q", ErrDecodeFields, name, credential.Params[1])
		}

		size, err := strconv.Atoi(credential.Params[2])
		if err != nil {
			return credential, fmt.Errorf("%w: invalid size for %q: %w", ErrDecodeFields, name, err)
		}
		if size < 8 || size > 32 {
			return credential, fmt.Errorf("%w: invalid size for %q: %q", ErrDecodeFields, name, credential.Params[2])
		}

		if !slices.Contains([]string{"SHA-256", "SHA-1"}, credential.Params[3]) {
			return credential, fmt.Errorf("%w: invalid hash for %q: %q", ErrDecodeFields, name, credential.Params[3])
		}

	case "bcrypt":
		if len(credential.Params) > 1 {
			return credential, fmt.Errorf("%w: invalid bcrypt parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`))
		}
		credential.Derived = []byte(derived)

	default:
		return credential, fmt.Errorf("%w: invalid algorithm for %q: %q", ErrDecodeFields, name, credential.Params[0])
	}

	return credential, nil
}
// splitLast cuts s around the final occurrence of sep.
//
// When sep is non-empty and occurs in s, it returns the text before the last
// occurrence, the text after it, and found == true. Otherwise (sep is empty
// or absent) it returns s, "", false.
func splitLast(s, sep string) (before, after string, found bool) {
	if len(sep) > 0 {
		if i := strings.LastIndex(s, sep); i >= 0 {
			return s[:i], s[i+len(sep):], true
		}
	}
	return s, "", false
}
// ToRecord serializes the credential into seven columns, in this order:
// purpose, name, parameter list, salt, derived value, role list, extra.
//
// The parameter and role lists are joined with single spaces. Salt and
// derived bytes are written as unpadded URL-safe base64 for aes-128-gcm and
// pbkdf2; plain writes the plaintext secret with an empty salt; bcrypt
// writes Derived as-is with an empty salt. An empty purpose is written as
// PurposeDefault, and a non-empty hashID is appended to the name after
// hashIDSep.
//
// It panics when Params is empty or names an unknown algorithm.
func (c Credential) ToRecord() []string {
	var salt, derived string

	switch algo := c.Params[0]; algo {
	case "aes-128-gcm", "pbkdf2":
		salt = base64.RawURLEncoding.EncodeToString(c.Salt)
		derived = base64.RawURLEncoding.EncodeToString(c.Derived)
	case "plain":
		derived = string(c.plain)
	case "bcrypt":
		derived = string(c.Derived)
	default:
		panic(fmt.Errorf("unknown algorithm %q", algo))
	}

	purpose := c.Purpose
	if purpose == "" {
		purpose = PurposeDefault
	}

	name := c.Name
	if len(c.hashID) > 0 {
		name = c.Name + hashIDSep + c.hashID
	}

	return []string{
		purpose,
		name,
		strings.Join(c.Params, " "),
		salt,
		derived,
		strings.Join(c.Roles, " "),
		c.Extra,
	}
}
// Compile-time assertions that *Credential satisfies BasicAuthVerifier and
// auth.BasicPrinciple; the build fails if a required method is missing.
var _ BasicAuthVerifier = (*Credential)(nil)
var _ auth.BasicPrinciple = (*Credential)(nil)