golib/auth/csvauth/credential.go

218 lines
5.0 KiB
Go

package csvauth
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"os"
"slices"
"strconv"
"strings"
)
// BasicAuthVerifier is implemented by types that can check a pair of
// credentials, reporting the outcome through the returned error.
// NOTE(review): presumably the arguments are the HTTP Basic Auth username
// and password, and a nil error means the pair was accepted — confirm
// against the implementations.
type BasicAuthVerifier interface {
	Verify(name, secret string) error
}
const (
	// DefaultPurpose is the old name of PurposeDefault.
	//
	// Deprecated: use PurposeDefault instead.
	DefaultPurpose = PurposeDefault

	// PurposeDefault is the purpose given to credentials whose purpose
	// column is empty.
	PurposeDefault = "login"
	// PurposeToken is the purpose of token credentials; their names may
	// carry a hash ID suffix separated by hashIDSep.
	PurposeToken = "token"

	// hashIDSep separates a token credential's name from its hash ID.
	hashIDSep = "~"
)
type (
	// Purpose is the value of a credential's purpose column
	// (e.g. PurposeDefault or PurposeToken).
	Purpose = string
	// Name is the value of a credential's name column.
	Name = string
)
// secretValue holds a sensitive string, such as a plaintext password, and
// redacts it whenever it is formatted or marshaled as text.
type secretValue string

// String implements fmt.Stringer so that %v, %s and %q print a placeholder
// instead of the secret.
func (s secretValue) String() string {
	return "[redacted]"
}

// GoString implements fmt.GoStringer so that %#v prints a quoted
// placeholder instead of the secret.
func (s secretValue) GoString() string {
	return `"[redacted]"`
}

// MarshalText implements encoding.TextMarshaler. Encoders that honor that
// interface (encoding/json, encoding/xml, slog's JSON handler, ...) then
// emit the placeholder rather than the secret. The ([]byte, error)
// signature is required; a method returning only a string does not satisfy
// the interface and would be silently ignored by those encoders.
func (s secretValue) MarshalText() ([]byte, error) {
	return []byte(s.String()), nil
}
// Credential represents a row in the CSV file.
//
// Rows are parsed by FromRecord/FromFields and written back by ToRecord.
type Credential struct {
	// Purpose is the first column; FromFields replaces an empty value with
	// PurposeDefault ("login"). PurposeToken ("token") marks token rows.
	Purpose Purpose
	// Name is the second column. For token rows FromFields strips the
	// suffix after the last hashIDSep ("~") and stores it in hashID.
	Name Name
	// plain is the plaintext secret; FromFields only sets it for the
	// "plain" algorithm. secretValue redacts it when formatted.
	plain secretValue
	// Params is the algorithm name followed by its parameters,
	// e.g. pbkdf2 followed by iterations, size and hash name.
	Params []string
	// Salt is the decoded salt column for "aes-128-gcm" and "pbkdf2".
	Salt []byte
	// Derived depends on the algorithm: the base64url-decoded derived
	// column for "aes-128-gcm" and "pbkdf2", the SHA-256 digest of the
	// secret for "plain", and the raw derived column for "bcrypt".
	Derived []byte
	// Roles are the entries of the optional sixth column.
	Roles []string
	// Extra is the optional seventh column, kept verbatim.
	Extra string
	// hashID is the suffix split off a token's name; ToRecord re-appends
	// it to Name using hashIDSep.
	hashID string
}
// Secret returns the credential's plaintext secret as an ordinary
// (unredacted) string. FromFields only populates it for the "plain"
// algorithm; for other algorithms it is empty unless set elsewhere.
func (c Credential) Secret() string {
	plaintext := string(c.plain)
	return plaintext
}
// FromRecord parses one CSV row into a Credential.
//
// The row must have at least five columns: purpose, name, params, salt and
// derived. The optional sixth and seventh columns hold the role list and
// extra data; any further columns are ignored. A row with fewer than five
// columns yields an error rather than an index-out-of-range panic. See
// FromFields for how each column is interpreted.
func FromRecord(record []string) (Credential, error) {
	if len(record) < 5 {
		return Credential{}, fmt.Errorf("credential record has %d fields, want at least 5", len(record))
	}
	var roleList, extra string
	if len(record) >= 6 {
		roleList = record[5]
	}
	if len(record) >= 7 {
		extra = record[6]
	}
	return FromFields(record[0], record[1], record[2], record[3], record[4], roleList, extra)
}
// FromFields builds a Credential from the individual columns of a CSV row.
//
// An empty purpose is replaced by PurposeDefault. When the purpose is
// PurposeToken, any suffix after the last hashIDSep in name is split off and
// kept as the credential's hash ID. roleList and paramList are lists
// separated by commas and/or whitespace; empty entries are dropped.
//
// The first entry of paramList selects the algorithm:
//
//   - "aes-128-gcm": no further parameters. saltBase64 and derived are
//     unpadded base64url and are decoded into Salt and Derived.
//   - "plain": no further parameters. derived is the plaintext secret, and
//     Derived is set to its SHA-256 digest.
//   - "pbkdf2": at least three further parameters, namely iterations (> 0),
//     size (8..32) and hash ("SHA-256" or "SHA-1"). saltBase64 and derived
//     are decoded as unpadded base64url; a decode failure is reported on
//     stderr but does not abort parsing.
//   - "bcrypt": no further parameters. derived is stored verbatim in
//     Derived.
//
// A missing or unknown algorithm, or invalid parameters, produce an error.
// On error the returned Credential may be partially populated.
func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra string) (Credential, error) {
	var credential Credential
	if len(purpose) == 0 {
		purpose = PurposeDefault
	}
	credential.Purpose = purpose
	credential.Name = name
	if credential.Purpose == PurposeToken {
		// the hash ID can only be verified if the algorithm is plain or aes
		credential.Name, credential.hashID, _ = splitLast(name, hashIDSep)
	}
	credential.Roles = splitCredentialList(roleList)
	credential.Extra = extra

	credential.Params = splitCredentialList(paramList)
	if len(credential.Params) == 0 {
		return credential, fmt.Errorf("no algorithm parameters for %q", name)
	}

	switch algorithm := credential.Params[0]; algorithm {
	case "aes-128-gcm":
		if len(credential.Params) > 1 {
			return credential, fmt.Errorf("invalid aes-128-gcm parameters %#v", credential.Params)
		}
		salt, err := base64.RawURLEncoding.DecodeString(saltBase64)
		if err != nil {
			return credential, fmt.Errorf("decoding salt for %q: %w", name, err)
		}
		credential.Salt = salt
		data, err := base64.RawURLEncoding.DecodeString(derived)
		if err != nil {
			return credential, fmt.Errorf("decoding derived data for %q: %w", name, err)
		}
		credential.Derived = data
	case "plain":
		if len(credential.Params) > 1 {
			return credential, fmt.Errorf("invalid plain parameters %#v", credential.Params)
		}
		credential.plain = secretValue(derived)
		h := sha256.Sum256([]byte(derived))
		credential.Derived = h[:]
	case "pbkdf2":
		// Params[1], Params[2] and Params[3] are indexed below; reject short
		// lists instead of panicking.
		if len(credential.Params) < 4 {
			return credential, fmt.Errorf("invalid pbkdf2 parameters %#v", credential.Params)
		}
		var err error
		// Decode failures are only reported on stderr; parsing continues.
		credential.Salt, err = base64.RawURLEncoding.DecodeString(saltBase64)
		if err != nil {
			fmt.Fprintf(os.Stderr, "could not decode salt %q for %q\n", saltBase64, name)
		}
		credential.Derived, err = base64.RawURLEncoding.DecodeString(derived)
		if err != nil {
			fmt.Fprintf(os.Stderr, "could not decode derived data %q for %q\n", derived, name)
		}
		iterations, err := strconv.Atoi(credential.Params[1])
		if err != nil {
			return credential, fmt.Errorf("invalid iterations %s: %w", credential.Params[1], err)
		}
		if iterations <= 0 {
			return credential, fmt.Errorf("invalid iterations %s", credential.Params[1])
		}
		size, err := strconv.Atoi(credential.Params[2])
		if err != nil {
			return credential, fmt.Errorf("invalid size %s: %w", credential.Params[2], err)
		}
		if size < 8 || size > 32 {
			return credential, fmt.Errorf("invalid size %s", credential.Params[2])
		}
		if !slices.Contains([]string{"SHA-256", "SHA-1"}, credential.Params[3]) {
			return credential, fmt.Errorf("invalid hash %s", credential.Params[3])
		}
	case "bcrypt":
		if len(credential.Params) > 1 {
			return credential, fmt.Errorf("invalid bcrypt parameters %#v", credential.Params)
		}
		credential.Derived = []byte(derived)
	default:
		return credential, fmt.Errorf("invalid algorithm %s", algorithm)
	}
	return credential, nil
}

// splitCredentialList splits a list column on commas and whitespace,
// dropping empty entries. It returns nil when the column has no entries.
func splitCredentialList(s string) []string {
	fields := strings.Fields(strings.ReplaceAll(s, ",", " "))
	if len(fields) == 0 {
		return nil
	}
	return fields
}
// splitLast slices s around the final occurrence of sep, returning the text
// before and after it and found=true. If sep is empty or does not occur in
// s, it returns s, "" and false.
func splitLast(s, sep string) (before, after string, found bool) {
	cut := -1
	if sep != "" {
		cut = strings.LastIndex(s, sep)
	}
	if cut < 0 {
		return s, "", false
	}
	before, after = s[:cut], s[cut+len(sep):]
	return before, after, true
}
// ToRecord serializes the credential as a seven-column CSV row in the order
// purpose, name, params, salt, derived, roles, extra. Params and Roles are
// joined with single spaces. An empty Purpose is written as PurposeDefault,
// and a non-empty hash ID is re-attached to the name with hashIDSep. Salt
// and Derived are base64url-encoded (unpadded) for "aes-128-gcm" and
// "pbkdf2"; "plain" writes the plaintext secret and "bcrypt" writes Derived
// verbatim. ToRecord panics if Params[0] is not one of those algorithms.
func (c Credential) ToRecord() []string {
	var salt, derived string
	switch alg := c.Params[0]; alg {
	case "aes-128-gcm", "pbkdf2":
		salt = base64.RawURLEncoding.EncodeToString(c.Salt)
		derived = base64.RawURLEncoding.EncodeToString(c.Derived)
	case "plain":
		derived = string(c.plain)
	case "bcrypt":
		derived = string(c.Derived)
	default:
		panic(fmt.Errorf("unknown algorithm %q", alg))
	}

	purpose := c.Purpose
	if purpose == "" {
		purpose = PurposeDefault
	}

	name := c.Name
	if c.hashID != "" {
		name = name + hashIDSep + c.hashID
	}

	return []string{
		purpose,
		name,
		strings.Join(c.Params, " "),
		salt,
		derived,
		strings.Join(c.Roles, " "),
		c.Extra,
	}
}