package csvauth import ( "crypto/sha256" "encoding/base64" "errors" "fmt" "slices" "strconv" "strings" "github.com/therootcompany/golib/auth" ) var ErrDecodeFields = errors.New("could not decode credential") type BasicAuthVerifier interface { Verify(string, string) error } const ( // deprecated, misspelling of PurposeDefault DefaultPurpose = "login" PurposeDefault = "login" PurposeToken = "token" hashIDSep = "~" ) type Purpose = string type Name = string type secretValue string func (s secretValue) String() string { return "[redacted]" } func (s secretValue) GoString() string { return `"[redacted]"` } func (s secretValue) MarshalText() string { return s.String() } // Credential represents a row in the CSV file type Credential struct { Purpose Purpose Name Name plain secretValue Params []string Salt []byte Derived []byte Roles []string Extra string hashID string } func (c *Credential) ID() string { if c.Purpose == PurposeToken { return c.Name + hashIDSep + c.hashID } return c.Name } func (c *Credential) Permissions() []string { return c.Roles } func (c Credential) Secret() string { return string(c.plain) } func FromRecord(record []string) (Credential, error) { var roleList, extra string purpose, name, paramList, salt64, derived := record[0], record[1], record[2], record[3], record[4] if len(record) >= 6 { roleList = record[5] } if len(record) >= 7 { extra = record[6] } return FromFields(purpose, name, paramList, salt64, derived, roleList, extra) } func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra string) (Credential, error) { var credential Credential if len(purpose) == 0 { purpose = PurposeDefault } credential.Purpose = purpose if credential.Purpose == PurposeToken { name, hashID, _ := splitLast(name, hashIDSep) credential.Name = name // this can only be verified if plain or aes credential.hashID = hashID } else { credential.Name = name } var roles []string if len(roleList) > 0 { roleList = strings.ReplaceAll(roleList, ",", " ") roles = strings.Split(roleList, " ") } credential.Roles = roles credential.Extra = extra paramList = strings.ReplaceAll(paramList, ",", " ") credential.Params = strings.Split(paramList, " ") if len(credential.Params) == 0 { return Credential{}, fmt.Errorf("%w: no algorithm parameters for %q", ErrDecodeFields, name) } switch credential.Params[0] { case "aes-128-gcm": if len(credential.Params) > 1 { return credential, fmt.Errorf("%w: invalid plain parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`)) } salt, err := base64.RawURLEncoding.DecodeString(saltBase64) if err != nil { return credential, err } credential.Salt = salt bytes, err := base64.RawURLEncoding.DecodeString(derived) if err != nil { return credential, err } credential.Derived = bytes case "plain": if len(credential.Params) > 1 { return credential, fmt.Errorf("invalid plain parameters %#v", credential.Params) } credential.plain = secretValue(derived) h := sha256.Sum256([]byte(derived)) credential.Derived = h[:] case "pbkdf2": var err error credential.Salt, err = base64.RawURLEncoding.DecodeString(saltBase64) if err != nil { return credential, fmt.Errorf("%w: bad salt for %q: %q", ErrDecodeFields, name, saltBase64) } credential.Derived, err = base64.RawURLEncoding.DecodeString(derived) if err != nil { return credential, fmt.Errorf("%w: bad derived data for %q: %q", ErrDecodeFields, name, derived) } iters, err := strconv.Atoi(credential.Params[1]) if err != nil { return credential, err } if iters <= 0 { return credential, fmt.Errorf("%w: invalid iterations for %q: %q", ErrDecodeFields, name, credential.Params[1]) } size, err := strconv.Atoi(credential.Params[2]) if err != nil { return credential, err } if size < 8 || size > 32 { return credential, fmt.Errorf("%w: invalid size for %q: %q", ErrDecodeFields, name, credential.Params[2]) } if !slices.Contains([]string{"SHA-256", "SHA-1"}, credential.Params[3]) { return credential, fmt.Errorf("%w: invalid hash for %q: %q", ErrDecodeFields, name, credential.Params[3]) } case "bcrypt": if len(credential.Params) > 1 { return credential, fmt.Errorf("%w: invalid bcrypt parameters for %q: %q", ErrDecodeFields, name, strings.Join(credential.Params, `", "`)) } credential.Derived = []byte(derived) default: return credential, fmt.Errorf("%w: invalid algorithm for %q: %q", ErrDecodeFields, name, credential.Params[0]) } return credential, nil } func splitLast(s, sep string) (before, after string, found bool) { if sep == "" { return s, "", false } idx := strings.LastIndex(s, sep) if idx == -1 { return s, "", false } return s[:idx], s[idx+len(sep):], true } func (c Credential) ToRecord() []string { var paramList, salt, derived string paramList = strings.Join(c.Params, " ") switch c.Params[0] { case "aes-128-gcm": salt = base64.RawURLEncoding.EncodeToString(c.Salt) derived = base64.RawURLEncoding.EncodeToString(c.Derived) case "plain": salt = "" derived = string(c.plain) case "pbkdf2": salt = base64.RawURLEncoding.EncodeToString(c.Salt) derived = base64.RawURLEncoding.EncodeToString(c.Derived) case "bcrypt": derived = string(c.Derived) default: panic(fmt.Errorf("unknown algorithm %q", c.Params[0])) } purpose := c.Purpose if len(purpose) == 0 { purpose = PurposeDefault } name := c.Name if c.hashID != "" { name += hashIDSep + c.hashID } record := []string{purpose, name, paramList, salt, derived, strings.Join(c.Roles, " "), c.Extra} return record } var _ BasicAuthVerifier = (*Credential)(nil) var _ auth.BasicPrinciple = (*Credential)(nil)