mirror of
https://github.com/therootcompany/golib.git
synced 2025-10-07 01:28:19 +00:00
feat(csvauth): store and verify or retrieve credentials
This commit is contained in:
parent
e8fbe603af
commit
612cd2e53c
148
auth/csvauth/README.md
Normal file
148
auth/csvauth/README.md
Normal file
@ -0,0 +1,148 @@
|
||||
# csvauth
|
||||
|
||||
Simple, non-scalable credentials stored in a tab-separated file. \
|
||||
(logical successor to [envauth](https://github.com/therootcompany/golib/tree/main/auth/envauth))
|
||||
|
||||
1. Login Credentials
|
||||
- Save recoverable (aes or plain) or salted hashed passwords (pbkdf2 or bcrypt)
|
||||
- Great in http middleware, authorizing login or api requests
|
||||
- Stored by _username_
|
||||
2. Service Accounts
|
||||
- Store API keys for services like SMTP and S3
|
||||
- Great for contacting other services
|
||||
- Stored by _purpose_
|
||||
|
||||
Also useful for generating pbkdf2 or bcrypt hashes for manual entry in a _real_ database.
|
||||
|
||||
Can be adapted to pull from a Google Sheets URL (CSV format).
|
||||
|
||||
```sh
|
||||
# create login credentials
|
||||
csvauth store 'bot@example.com'
|
||||
|
||||
# store service account
|
||||
csvauth store --purpose 'postmark_smtp_notifier' 'admin@example.com'
|
||||
```
|
||||
|
||||
`credentials.tsv`:
|
||||
|
||||
```tsv
|
||||
purpose name algo salt derived roles extra
|
||||
ntfy_sh mytopic-1234 plain mytopic-1234
|
||||
s3_files account1 aes xxxxxxxxxxxx xxxxxxxxxxxxxxxx
|
||||
login johndoe pbkdf2 1000 16 SHA-256 5cLjzprCHP3WmMbzfqVaew k-elXFa4B_P4-iZ-Rr9GnA admin
|
||||
login janedoe bcrypt $2a$12$Xbe3OnIapGXUv9eF3k3cSu7sazeZSJquUwGzaovJxb9XQcN54/rte {"foo": "bar"}
|
||||
```
|
||||
|
||||
```go
|
||||
f, err := os.Open("./credentials.tsv")
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
auth, err := csvauth.Load(f)
|
||||
|
||||
// ...
|
||||
|
||||
if err := auth.Verify(username, password); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ...
|
||||
|
||||
account := auth.LoadServiceAccount("account-mailer")
|
||||
req.SetBasicAuth(account.Name, account.Secret())
|
||||
```
|
||||
|
||||
## Login Credentials
|
||||
|
||||
1. Use `csvauth store [options] <username>` to create new login credentials.
|
||||
|
||||
```sh
|
||||
go run ./cmd/csvauth/ store --help
|
||||
```
|
||||
|
||||
```sh
|
||||
go run ./cmd/csvauth/ store 'john.doe@example.com'
|
||||
|
||||
go run ./cmd/csvauth/ store --algorithm aes-128-gcm 'johndoe'
|
||||
go run ./cmd/csvauth/ store --algorithm plain 'johndoe'
|
||||
go run ./cmd/csvauth/ store --algorithm 'pbkdf2 1000 16 SHA-256' 'johndoe'
|
||||
go run ./cmd/csvauth/ store --algorithm 'bcrypt 12' 'john.doe@example.com'
|
||||
|
||||
go run ./cmd/csvauth/ store --ask-password 'john.doe@example.com'
|
||||
go run ./cmd/csvauth/ store --password-file ./password.txt 'johndoe'
|
||||
|
||||
go run ./cmd/csvauth/ store --roles 'admin' --extra '{"foo":"bar"}' 'jimbob'
|
||||
```
|
||||
|
||||
2. Use `github.com/therootcompany/golib/auth/csvauth` to verify credentials
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/therootcompany/golib/auth/csvauth"
|
||||
)
|
||||
|
||||
var auth csvauth.Auth
|
||||
|
||||
func main() {
|
||||
f, _ := os.Open("./credentials.tsv")
|
||||
defer func() { _ = f.Close() }()
|
||||
auth, _ = csvauth.Load(f)
|
||||
|
||||
// ...
|
||||
}
|
||||
|
||||
func handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok || !auth.Verify(username, password) {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
credential, err := auth.LoadCredential(username)
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
## Service Account
|
||||
|
||||
1. Use `csvauth store --purpose <account> [options] <username>` to store API credentials
|
||||
|
||||
```sh
|
||||
go run ./cmd/csvauth/ store --help
|
||||
```
|
||||
|
||||
```sh
|
||||
go run ./cmd/csvauth/ store --purpose ntfy_sh_admins 'acme-admins-1234abcd'
|
||||
```
|
||||
|
||||
2. Use `github.com/therootcompany/golib/auth/csvauth` to verify credentials
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/therootcompany/golib/auth/csvauth"
|
||||
)
|
||||
|
||||
func main() {
|
||||
f, _ := os.Open("./credentials.tsv")
|
||||
defer func() { _ = f.Close() }()
|
||||
auth, _ := csvauth.Load(f)
|
||||
|
||||
// ...
|
||||
|
||||
credential := auth.LoadServiceAccount("ntfy_sh_admins")
|
||||
req, _ := http.NewRequest("POST", "https://ntfy.sh/"+credential.Secret(), bytes.NewBuffer(message))
|
||||
|
||||
// ...
|
||||
}
|
||||
```
|
470
auth/csvauth/cmd/csvauth/main.go
Normal file
470
auth/csvauth/cmd/csvauth/main.go
Normal file
@ -0,0 +1,470 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/csv"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/therootcompany/golib/auth/csvauth"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAESKeyENVName = "CSVAUTH_AES_128_KEY"
|
||||
defaultCSVFileENVName = "CSVAUTH_CSV_FILE"
|
||||
defaultCSVPath = "credentials.tsv"
|
||||
passwordEntropy = 12 // 96-bit
|
||||
)
|
||||
|
||||
var (
|
||||
keyRelPath = filepath.Join(".config", "csvauth", "aes-128.key")
|
||||
)
|
||||
|
||||
func main() {
|
||||
var subcmd string
|
||||
if len(os.Args) > 1 {
|
||||
subcmd = os.Args[1]
|
||||
}
|
||||
if len(os.Args) > 2 {
|
||||
switch os.Args[2] {
|
||||
case "", "help":
|
||||
os.Args[2] = "--help"
|
||||
}
|
||||
} else {
|
||||
os.Args = append(os.Args, "--help")
|
||||
}
|
||||
|
||||
homedir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
return
|
||||
}
|
||||
filename := filepath.Join(homedir, keyRelPath)
|
||||
csvPath := getCSVPath()
|
||||
|
||||
var aesKey []byte
|
||||
var csvFile csvauth.NamedReadCloser
|
||||
switch subcmd {
|
||||
case "store", "check":
|
||||
var keyErr error
|
||||
aesKey, keyErr = getAESKey(defaultAESKeyENVName, filename)
|
||||
if keyErr != nil {
|
||||
if os.IsNotExist(keyErr) {
|
||||
fmt.Fprintf(os.Stderr, "no AES key found, run 'csvauth init' to create it, or provide %s or ~/%s\n", defaultAESKeyENVName, keyRelPath)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", keyErr)
|
||||
}
|
||||
}
|
||||
|
||||
var csvErr error
|
||||
csvFile, csvErr = getCSVFile(csvPath)
|
||||
if csvErr != nil {
|
||||
if os.IsNotExist(csvErr) {
|
||||
fmt.Fprintf(os.Stderr, "no credentials file found, run 'csvauth init' to create it, or provide %s or %s\n", defaultCSVFileENVName, csvPath)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", csvErr)
|
||||
}
|
||||
}
|
||||
|
||||
if keyErr != nil || csvErr != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
}
|
||||
|
||||
switch subcmd {
|
||||
case "init":
|
||||
if err := handleInit(defaultAESKeyENVName, filename, csvPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
case "store":
|
||||
handleSet(os.Args[2:], aesKey, csvFile)
|
||||
case "check":
|
||||
handleCheck(os.Args[2:], aesKey, csvFile)
|
||||
case "--help", "-help", "help", "":
|
||||
fallthrough
|
||||
default:
|
||||
if len(subcmd) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "unknown subcommand %q\n", subcmd)
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "USAGE\n\tcsvauth [store|check] [--help] [--algorithm <aes|plain|pbkdf2[,iters[,size[,hash]]]|bcrypt[,cost]] [--ask-password] [--password-file <filepath>] [--roles 'role1,role2'] [--extra '{\"foo\": \"bar\"}'] <username>\n\n")
|
||||
|
||||
handleSet([]string{"--help"}, nil, nil)
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
handleCheck([]string{"--help"}, nil, nil)
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
switch subcmd {
|
||||
case "--help", "-help", "help":
|
||||
return
|
||||
default:
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getCSVPath() string {
|
||||
path := os.Getenv(defaultCSVFileENVName)
|
||||
if len(path) == 0 {
|
||||
path = defaultCSVPath
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func getOrCreateAESKey(envname, filename string) ([]byte, error) {
|
||||
aesKey, err := getAESKey(envname, filename)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if aesKey != nil {
|
||||
return aesKey, nil
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(filename), 0750); err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory for %s: %v", filename, err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Creating new AES-128 key at %s\n", filename)
|
||||
key := make([]byte, 16)
|
||||
if _, err = io.ReadFull(rand.Reader, key); err != nil {
|
||||
panic(err) // the universe has run out of entropy
|
||||
}
|
||||
hexKey := hex.EncodeToString(key) + "\n"
|
||||
|
||||
if err := os.WriteFile(filename, []byte(hexKey), 0640); err != nil {
|
||||
return nil, fmt.Errorf("failed to write %s: %v", filename, err)
|
||||
}
|
||||
return aesKey, nil
|
||||
}
|
||||
|
||||
func getAESKey(envname, filename string) ([]byte, error) {
|
||||
envKey := os.Getenv(envname)
|
||||
if envKey != "" {
|
||||
key, err := hex.DecodeString(strings.TrimSpace(envKey))
|
||||
if err != nil || len(key) != 16 {
|
||||
return nil, fmt.Errorf("invalid %s: must be 32-char hex string", envname)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Found AES Key in %s\n", envname)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filename); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read %s: %v", filename, err)
|
||||
}
|
||||
key, err := hex.DecodeString(strings.TrimSpace(string(data)))
|
||||
if err != nil || len(key) != 16 {
|
||||
return nil, fmt.Errorf("invalid key in %s: must be 32-char hex string", filename)
|
||||
}
|
||||
// relpath := strings.Replace(filename, homedir, "~", 1)
|
||||
fmt.Fprintf(os.Stderr, "Found AES Key at %s\n", filename)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func getOrCreateCSVFile(csvPath string) (csvauth.NamedReadCloser, error) {
|
||||
r, err := getCSVFile(csvPath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
csvAbs, err := filepath.Abs(csvPath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Creating new credentials csv at %s\n", csvAbs)
|
||||
r, err = os.OpenFile(csvPath, os.O_RDWR|os.O_CREATE, 0640)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func getCSVFile(csvPath string) (csvauth.NamedReadCloser, error) {
|
||||
f, csvErr := os.Open(csvPath)
|
||||
if csvErr != nil {
|
||||
return nil, csvErr
|
||||
}
|
||||
|
||||
csvAbs, err := filepath.Abs(csvPath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Found credentials db at %s\n", csvAbs)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func handleInit(keyenv, keypath, csvpath string) error {
|
||||
_, keyErr := getOrCreateAESKey(keyenv, keypath)
|
||||
_, csvErr := getOrCreateCSVFile(csvpath)
|
||||
|
||||
if keyErr != nil {
|
||||
return keyErr
|
||||
}
|
||||
|
||||
if csvErr != nil {
|
||||
return csvErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleSet(args []string, aesKey []byte, csvFile csvauth.NamedReadCloser) {
|
||||
storeFlags := flag.NewFlagSet("csvauth-store", flag.ContinueOnError)
|
||||
purpose := storeFlags.String("purpose", "login", "'login' for users, or a service account name, such as 'basecamp_api_key'")
|
||||
roleList := storeFlags.String("roles", "", "a space- or comma-separated list of roles (defined by you), such as 'triage audit'")
|
||||
extra := storeFlags.String("extra", "", "free form data to retrieve with the user (hint: JSON might be nice)")
|
||||
algorithm := storeFlags.String("algorithm", "", "Hash algorithm: aes, plain, pbkdf2[,iters[,size[,hash]]], or bcrypt[,cost]")
|
||||
askPassword := storeFlags.Bool("ask-password", false, "Read password from stdin")
|
||||
passwordFile := storeFlags.String("password-file", "", "Read password from file")
|
||||
// storeFlags.StringVar(&tsvPath, "tsv", tsvPath, "Credentials file to use")
|
||||
if err := storeFlags.Parse(args); err != nil {
|
||||
if err == flag.ErrHelp {
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(storeFlags.Args()) > 1 {
|
||||
fmt.Fprintf(os.Stderr, "too many arguments: %q\n", strings.Join(storeFlags.Args(), " "))
|
||||
fmt.Fprintf(os.Stderr, "note: flags should come before arguments\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
name := storeFlags.Arg(0)
|
||||
switch name {
|
||||
case "id", "name", "purpose":
|
||||
fmt.Fprintf(os.Stderr, "invalid username %q\n", name)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(*algorithm) == 0 {
|
||||
if *purpose == "login" {
|
||||
*algorithm = "pbkdf2"
|
||||
} else {
|
||||
// *algorithm = "plain"
|
||||
*algorithm = "aes-128-gcm"
|
||||
}
|
||||
}
|
||||
if *purpose != "login" {
|
||||
*askPassword = true
|
||||
}
|
||||
|
||||
var pass string
|
||||
if len(*passwordFile) > 0 {
|
||||
data, err := os.ReadFile(*passwordFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading password file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
pass = strings.TrimSpace(string(data))
|
||||
} else if *askPassword {
|
||||
fmt.Fprintf(os.Stderr, "New Password: ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
data, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading password from stdin: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
pass = strings.TrimSpace(data)
|
||||
} else {
|
||||
pass = generatePassword()
|
||||
fmt.Println(pass)
|
||||
}
|
||||
|
||||
*algorithm = strings.ReplaceAll(*algorithm, ",", " ")
|
||||
params := strings.Split(*algorithm, " ")
|
||||
switch params[0] {
|
||||
case "aes", "aes128", "aes-128":
|
||||
params[0] = "aes-128-gcm"
|
||||
}
|
||||
|
||||
var roles []string
|
||||
if len(*roleList) > 0 {
|
||||
*roleList = strings.ReplaceAll(*roleList, ",", " ")
|
||||
roles = strings.Split(*roleList, " ")
|
||||
}
|
||||
|
||||
defer func() { _ = csvFile.Close() }()
|
||||
auth := csvauth.New(aesKey)
|
||||
c := auth.NewCredential(*purpose, name, pass, params, roles, *extra)
|
||||
|
||||
if err := auth.LoadCSV(csvFile, '\t'); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading CSV: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
_ = csvFile.Close()
|
||||
|
||||
var exists bool
|
||||
if len(*purpose) > 0 && *purpose != "login" {
|
||||
if _, err := auth.LoadServiceAccount(*purpose); err != nil {
|
||||
if !errors.Is(csvauth.ErrNotFound, err) {
|
||||
fmt.Fprintf(os.Stderr, "could not load %s: %v\n", *purpose, err)
|
||||
}
|
||||
} else {
|
||||
exists = true
|
||||
}
|
||||
c.Purpose = *purpose
|
||||
_ = auth.CacheServiceAccount(*c)
|
||||
} else {
|
||||
if _, err := auth.LoadCredential(name); err != nil {
|
||||
if !errors.Is(csvauth.ErrNotFound, err) {
|
||||
fmt.Fprintf(os.Stderr, "could not load %s: %v\n", name, err)
|
||||
}
|
||||
} else {
|
||||
exists = true
|
||||
}
|
||||
_ = auth.CacheCredential(*c)
|
||||
}
|
||||
|
||||
var records [][]string
|
||||
for _, purpose := range slices.Sorted(auth.ServiceAccountKeys()) {
|
||||
c, _ := auth.LoadServiceAccount(purpose)
|
||||
record := c.ToRecord()
|
||||
records = append(records, record)
|
||||
}
|
||||
for _, u := range slices.Sorted(auth.CredentialKeys()) {
|
||||
c, _ := auth.LoadCredential(u)
|
||||
record := c.ToRecord()
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
writeCSV(csvFile.Name(), records)
|
||||
if exists {
|
||||
fmt.Fprintf(os.Stderr, "Wrote %q with new password for %q\n", csvFile.Name(), name)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Added password for %q to %q\n", name, csvFile.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func handleCheck(args []string, aesKey []byte, csvFile csvauth.NamedReadCloser) {
|
||||
checkFlags := flag.NewFlagSet("csvauth-check", flag.ContinueOnError)
|
||||
purpose := checkFlags.String("purpose", "login", "'login' for users, or a service account name, such as 'basecamp_api_key'")
|
||||
_ = checkFlags.Bool("ask-password", true, "Read password from stdin")
|
||||
passwordFile := checkFlags.String("password-file", "", "Read password from file")
|
||||
// storeFlags.StringVar(&tsvPath, "tsv", tsvPath, "Credentials file to use")
|
||||
if err := checkFlags.Parse(args); err != nil {
|
||||
if err == flag.ErrHelp {
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(checkFlags.Args()) > 1 {
|
||||
fmt.Fprintf(os.Stderr, "too many arguments: %q\n", strings.Join(checkFlags.Args(), " "))
|
||||
fmt.Fprintf(os.Stderr, "note: flags should come before arguments\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
name := checkFlags.Arg(0)
|
||||
switch name {
|
||||
case "id", "name", "purpose":
|
||||
fmt.Fprintf(os.Stderr, "invalid username %q\n", name)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var pass string
|
||||
if len(*passwordFile) > 0 {
|
||||
data, err := os.ReadFile(*passwordFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading password file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
pass = strings.TrimSpace(string(data))
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Current Password: ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
data, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading password from stdin: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
pass = strings.TrimSpace(data)
|
||||
}
|
||||
|
||||
defer func() { _ = csvFile.Close() }()
|
||||
auth := csvauth.New(aesKey)
|
||||
|
||||
if err := auth.LoadCSV(csvFile, '\t'); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading CSV: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var v csvauth.BasicAuthVerifier
|
||||
var err error
|
||||
if *purpose != "login" {
|
||||
v, err = auth.LoadServiceAccount(*purpose)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "couldn't load %s: %v", *purpose, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
v = auth
|
||||
}
|
||||
|
||||
if err := v.Verify(name, pass); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "user '%s' not found or incorrect secret\n", name)
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("verified")
|
||||
}
|
||||
|
||||
func writeCSV(csvPath string, records [][]string) {
|
||||
f, err := os.Create(csvPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating CSV: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
writer := csv.NewWriter(f)
|
||||
writer.Comma = '\t'
|
||||
|
||||
_ = writer.Write([]string{"purpose", "name", "algo", "salt", "derived", "roles", "extra"})
|
||||
for _, record := range records {
|
||||
_ = writer.Write(record)
|
||||
}
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing CSV: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func generatePassword() string {
|
||||
bytes := make([]byte, passwordEntropy)
|
||||
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
|
||||
panic(err) // the universe has run out of entropy
|
||||
}
|
||||
encoded := base64.RawURLEncoding.EncodeToString(bytes)
|
||||
parts := make([]string, 4)
|
||||
start := 0
|
||||
for i := range 4 {
|
||||
parts[i] = encoded[start : start+4]
|
||||
start += 4
|
||||
}
|
||||
return strings.Join(parts, "-")
|
||||
}
|
172
auth/csvauth/credential.go
Normal file
172
auth/csvauth/credential.go
Normal file
@ -0,0 +1,172 @@
|
||||
package csvauth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type BasicAuthVerifier interface {
|
||||
Verify(string, string) error
|
||||
}
|
||||
|
||||
const DefaultPurpose = "login"
|
||||
|
||||
type Purpose = string
|
||||
type Name = string
|
||||
|
||||
// Credential represents a row in the CSV file
|
||||
type Credential struct {
|
||||
Purpose Purpose
|
||||
Name Name
|
||||
plain string
|
||||
Params []string
|
||||
Salt []byte
|
||||
Derived []byte
|
||||
Roles []string
|
||||
Extra string
|
||||
}
|
||||
|
||||
func (c Credential) Secret() string {
|
||||
return c.plain
|
||||
}
|
||||
|
||||
func FromRecord(record []string) (Credential, error) {
|
||||
var roleList, extra string
|
||||
purpose, name, paramList, salt64, derived := record[0], record[1], record[2], record[3], record[4]
|
||||
if len(record) >= 6 {
|
||||
roleList = record[5]
|
||||
}
|
||||
if len(record) >= 7 {
|
||||
extra = record[6]
|
||||
}
|
||||
|
||||
return FromFields(purpose, name, paramList, salt64, derived, roleList, extra)
|
||||
}
|
||||
|
||||
func FromFields(purpose, name, paramList, saltBase64, derived, roleList, extra string) (Credential, error) {
|
||||
var credential Credential
|
||||
credential.Name = name
|
||||
|
||||
if len(purpose) == 0 {
|
||||
purpose = DefaultPurpose
|
||||
}
|
||||
credential.Purpose = purpose
|
||||
|
||||
var roles []string
|
||||
if len(roleList) > 0 {
|
||||
roleList = strings.ReplaceAll(roleList, ",", " ")
|
||||
roles = strings.Split(roleList, " ")
|
||||
}
|
||||
credential.Roles = roles
|
||||
|
||||
credential.Extra = extra
|
||||
|
||||
paramList = strings.ReplaceAll(paramList, ",", " ")
|
||||
credential.Params = strings.Split(paramList, " ")
|
||||
if len(credential.Params) == 0 {
|
||||
fmt.Fprintf(os.Stderr, "no algorithm parameters for %q\n", name)
|
||||
}
|
||||
|
||||
switch credential.Params[0] {
|
||||
case "aes-128-gcm":
|
||||
if len(credential.Params) > 1 {
|
||||
return credential, fmt.Errorf("invalid plain parameters %#v", credential.Params)
|
||||
}
|
||||
|
||||
salt, err := base64.RawURLEncoding.DecodeString(saltBase64)
|
||||
if err != nil {
|
||||
return credential, err
|
||||
}
|
||||
credential.Salt = salt
|
||||
|
||||
bytes, err := base64.RawURLEncoding.DecodeString(derived)
|
||||
if err != nil {
|
||||
return credential, err
|
||||
}
|
||||
credential.Derived = bytes
|
||||
case "plain":
|
||||
if len(credential.Params) > 1 {
|
||||
return credential, fmt.Errorf("invalid plain parameters %#v", credential.Params)
|
||||
}
|
||||
|
||||
credential.plain = derived
|
||||
h := sha256.Sum256([]byte(derived))
|
||||
credential.Derived = h[:]
|
||||
case "pbkdf2":
|
||||
var err error
|
||||
|
||||
credential.Salt, err = base64.RawURLEncoding.DecodeString(saltBase64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "could not decode salt %q for %q\n", saltBase64, name)
|
||||
}
|
||||
|
||||
credential.Derived, err = base64.RawURLEncoding.DecodeString(derived)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "could not decode derived data %q for %q\n", derived, name)
|
||||
}
|
||||
|
||||
iters, err := strconv.Atoi(credential.Params[1])
|
||||
if err != nil {
|
||||
return credential, err
|
||||
}
|
||||
if iters <= 0 {
|
||||
return credential, fmt.Errorf("invalid iterations %s", credential.Params[1])
|
||||
}
|
||||
|
||||
size, err := strconv.Atoi(credential.Params[2])
|
||||
if err != nil {
|
||||
return credential, err
|
||||
}
|
||||
if size < 8 || size > 32 {
|
||||
return credential, fmt.Errorf("invalid size %s", credential.Params[2])
|
||||
}
|
||||
|
||||
if !slices.Contains([]string{"SHA-256", "SHA-1"}, credential.Params[3]) {
|
||||
return credential, fmt.Errorf("invalid hash %s", credential.Params[3])
|
||||
}
|
||||
case "bcrypt":
|
||||
if len(credential.Params) > 1 {
|
||||
return credential, fmt.Errorf("invalid bcrypt parameters %#v", credential.Params)
|
||||
}
|
||||
|
||||
credential.Derived = []byte(derived)
|
||||
default:
|
||||
return credential, fmt.Errorf("invalid algorithm %s", credential.Params[0])
|
||||
}
|
||||
|
||||
return credential, nil
|
||||
}
|
||||
|
||||
func (c Credential) ToRecord() []string {
|
||||
var paramList, salt, derived string
|
||||
|
||||
paramList = strings.Join(c.Params, " ")
|
||||
switch c.Params[0] {
|
||||
case "aes-128-gcm":
|
||||
salt = base64.RawURLEncoding.EncodeToString(c.Salt)
|
||||
derived = base64.RawURLEncoding.EncodeToString(c.Derived)
|
||||
case "plain":
|
||||
salt = ""
|
||||
derived = c.plain
|
||||
case "pbkdf2":
|
||||
salt = base64.RawURLEncoding.EncodeToString(c.Salt)
|
||||
derived = base64.RawURLEncoding.EncodeToString(c.Derived)
|
||||
case "bcrypt":
|
||||
derived = string(c.Derived)
|
||||
default:
|
||||
panic(fmt.Errorf("unknown algorithm %q", c.Params[0]))
|
||||
}
|
||||
|
||||
purpose := c.Purpose
|
||||
if len(purpose) == 0 {
|
||||
purpose = DefaultPurpose
|
||||
}
|
||||
|
||||
record := []string{purpose, c.Name, paramList, salt, derived, strings.Join(c.Roles, " "), c.Extra}
|
||||
return record
|
||||
}
|
9
auth/csvauth/credentials.tsv
Normal file
9
auth/csvauth/credentials.tsv
Normal file
@ -0,0 +1,9 @@
|
||||
purpose name algo salt derived roles extra
|
||||
service1 acme aes-128-gcm 2z92DVgMF9Hn-GBy i37kF34cwa64j3tmnrvlJ5ZSekWD-w token1
|
||||
service2 acme plain token2 token2
|
||||
service3 user3 pbkdf2 1000 16 SHA-256 DYdA9iz1EN81bESTXcSgUg IzkeBCxRVmqybOBeAntfdA token3
|
||||
service4 user4 bcrypt $2a$12$HueMNxFGYIYtNNTySFW/Lu4vAMqpdcchBnJrW.VdYgP9xPQdITipu token4
|
||||
login user1 pbkdf2 1000 16 SHA-256 R-NgfDcY1A6L5a4jO89TNw -Pe9o-NwYvF6M4tlCwhm_g pass1
|
||||
login user2 bcrypt $2a$12$pad8UgUphO43PioF1JlSHOblRPdaX.ikTqjA8D1EfrcBiNGI9WQ/y pass2
|
||||
login user3 aes-128-gcm YC0xno0-W9pWR6rK D9CZFCtGGJecLpCv2Fk1I-wcXmN3 pass3
|
||||
login user4 plain pass4 pass4
|
|
427
auth/csvauth/csvauth.go
Normal file
427
auth/csvauth/csvauth.go
Normal file
@ -0,0 +1,427 @@
|
||||
package csvauth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/pbkdf2"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"iter"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var ErrNotFound = errors.New("not found")
|
||||
var ErrUnauthorized = errors.New("unauthorized")
|
||||
var ErrUnknownAlgorithm = errors.New("unknown algorithm")
|
||||
|
||||
const (
|
||||
defaultIters = 1000 // original 2000 recommendation
|
||||
defaultSize = 16 // 128-bit
|
||||
defaultHash = "SHA-256"
|
||||
defaultBcryptCost = 12
|
||||
gcmNonceSize = 12 // RFC spec
|
||||
)
|
||||
|
||||
// NamedReadCloser provides Name() for debugging of file-like ReadClosers, such as http responses
|
||||
type NamedReadCloser interface {
|
||||
io.ReadCloser
|
||||
Name() string
|
||||
}
|
||||
|
||||
type readNamer struct {
|
||||
io.ReadCloser
|
||||
name string
|
||||
}
|
||||
|
||||
// Name returns the name given to the wrapped ReadCloser to f8ulfill NamedReadCloser
|
||||
func (r *readNamer) Name() string {
|
||||
return r.name
|
||||
}
|
||||
|
||||
// NewNamedReadCloser wraps a ReadCloser with a name which can be referenced when debugging
|
||||
func NewNamedReadCloser(r io.ReadCloser, name string) NamedReadCloser {
|
||||
return &readNamer{
|
||||
ReadCloser: r,
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
// Auth holds user the encryption key and both login and service account credentials
|
||||
type Auth struct {
|
||||
aes128key [16]byte
|
||||
credentials map[Name]Credential
|
||||
serviceAccounts map[Purpose]Credential
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// New initializes an Auth with an encryption key
|
||||
func New(aes128key []byte) *Auth {
|
||||
var aes128Arr [16]byte
|
||||
copy(aes128Arr[:], aes128key)
|
||||
|
||||
return &Auth{
|
||||
aes128key: aes128Arr,
|
||||
credentials: map[Name]Credential{},
|
||||
serviceAccounts: map[Purpose]Credential{},
|
||||
}
|
||||
}
|
||||
|
||||
// Load reads a credentials CSV from the given NamedReadCloser (e.g. file, wrapped http request)
|
||||
func (a *Auth) LoadCSV(f NamedReadCloser, comma rune) error {
|
||||
csvr := csv.NewReader(f)
|
||||
csvr.Comma = comma
|
||||
csvr.Comment = '#'
|
||||
csvr.FieldsPerRecord = -1 // ignore short rows
|
||||
_, _ = csvr.Read() // strip header row
|
||||
for {
|
||||
record, err := csvr.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(record) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(record) == 1 {
|
||||
if len(record[0]) == 0 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(record) < 5 {
|
||||
return fmt.Errorf("invalid %q format: %#v (%d)", f.Name(), record, len(record))
|
||||
}
|
||||
|
||||
credential, err := FromRecord(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(credential.Purpose) == 0 || credential.Purpose == DefaultPurpose {
|
||||
if _, ok := a.credentials[credential.Name]; ok {
|
||||
fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
||||
}
|
||||
a.credentials[credential.Name] = credential
|
||||
} else {
|
||||
if _, ok := a.serviceAccounts[credential.Purpose]; ok {
|
||||
fmt.Fprintf(os.Stderr, "overwriting cache of previous value for %s: %s\n", credential.Purpose, credential.Name)
|
||||
}
|
||||
a.serviceAccounts[credential.Purpose] = credential
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewCredential derives the hashed, encrypted, or raw value from the given secret and sets additional required and provided parameters
|
||||
func (a *Auth) NewCredential(purpose, name, secret string, params []string, roles []string, extra string) *Credential {
|
||||
c := &Credential{
|
||||
Purpose: purpose,
|
||||
Name: name,
|
||||
//plain: secret,
|
||||
Params: params,
|
||||
//Salt: ...
|
||||
//Derived: ...
|
||||
Roles: roles,
|
||||
Extra: extra,
|
||||
}
|
||||
|
||||
switch c.Params[0] {
|
||||
case "plain":
|
||||
if len(params) != 1 {
|
||||
fmt.Fprintf(os.Stderr, "invalid plain algorithm format: %q\n", strings.Join(params, " "))
|
||||
os.Exit(1)
|
||||
}
|
||||
c.plain = secret
|
||||
|
||||
c.Params = []string{"plain"}
|
||||
h := sha256.Sum256([]byte(secret))
|
||||
c.Derived = h[:]
|
||||
case "aes-128-gcm":
|
||||
if len(params) != 1 {
|
||||
fmt.Fprintf(os.Stderr, "invalid aes-128-gcm algorithm format: %q\n", strings.Join(params, " "))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
c.Params = []string{"aes-128-gcm"}
|
||||
nonce := make([]byte, gcmNonceSize)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
c.Salt = nonce
|
||||
|
||||
var err error
|
||||
var salt [12]byte
|
||||
copy(salt[:], c.Salt)
|
||||
c.plain = secret
|
||||
c.Derived, err = gcmEncrypt(a.aes128key, salt, secret)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "could not aes-128-gcm encrypt secret: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
case "pbkdf2":
|
||||
if len(params) > 4 {
|
||||
fmt.Fprintf(os.Stderr, "invalid pbkdf2 algorithm format: %q\n", strings.Join(params, " "))
|
||||
os.Exit(1)
|
||||
}
|
||||
iters := defaultIters
|
||||
if len(params) > 1 {
|
||||
var err error
|
||||
iters, err = strconv.Atoi(params[1])
|
||||
if err != nil || iters <= 0 {
|
||||
fmt.Fprintf(os.Stderr, "invalid iterations %q in %q\n", params[1], strings.Join(params, " "))
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
size := defaultSize
|
||||
if len(params) > 2 {
|
||||
var err error
|
||||
size, err = strconv.Atoi(params[2])
|
||||
if err != nil || size < 8 || size > 32 {
|
||||
fmt.Fprintf(os.Stderr, "invalid size %q in %q\n", params[2], strings.Join(params, " "))
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
hashName := defaultHash
|
||||
if len(params) > 3 {
|
||||
if !slices.Contains([]string{"SHA-256", "SHA-1"}, params[3]) {
|
||||
fmt.Fprintf(os.Stderr, "invalid hash %q in %q\n", params[3], strings.Join(params, " "))
|
||||
os.Exit(1)
|
||||
}
|
||||
hashName = params[3]
|
||||
}
|
||||
c.Params = []string{"pbkdf2", strconv.Itoa(iters), strconv.Itoa(size), hashName}
|
||||
saltBytes := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rand.Reader, saltBytes); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
c.Salt = saltBytes
|
||||
var hasher func() hash.Hash
|
||||
hashNameUpper := strings.ToUpper(hashName)
|
||||
switch hashNameUpper {
|
||||
case "SHA-1", "SHA1":
|
||||
hashName = "SHA-1"
|
||||
hasher = sha1.New
|
||||
case "SHA-256", "SHA256":
|
||||
hashName = "SHA-256"
|
||||
hasher = sha256.New
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "invalid hash %q (expected SHA-1 or SHA-256)\n", hashName)
|
||||
os.Exit(1)
|
||||
}
|
||||
var err error
|
||||
c.Derived, err = pbkdf2.Key(hasher, secret, saltBytes, iters, size)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "invalid pbkdf2 parameters: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
case "bcrypt":
|
||||
if len(params) > 2 {
|
||||
fmt.Fprintf(os.Stderr, "invalid bcrypt algorithm format: %q\n", strings.Join(params, " "))
|
||||
os.Exit(1)
|
||||
}
|
||||
cost := defaultBcryptCost
|
||||
if len(params) > 1 {
|
||||
var err error
|
||||
cost, err = strconv.Atoi(params[1])
|
||||
if err != nil || cost < 4 || cost > 31 {
|
||||
fmt.Fprintf(os.Stderr, "invalid bcrypt cost %q in %q\n", params[1], strings.Join(params, " "))
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
c.Params = []string{"bcrypt"} // cost is included in the digest
|
||||
derived, err := bcrypt.GenerateFromPassword([]byte(secret), cost)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error generating bcrypt hash: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
c.Derived = derived
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "invalid algorithm %q\n", params[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func gcmEncrypt(aes128key [16]byte, gcmNonce [12]byte, secret string) ([]byte, error) {
|
||||
block, err := aes.NewCipher(aes128key[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new aes (encrypt) cipher failed: %v", err)
|
||||
}
|
||||
|
||||
// nonceSize := len(gcmNonce) // should always be 12
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new gcm (encrypt) failed: %v", err)
|
||||
}
|
||||
|
||||
plaintext := []byte(secret)
|
||||
ciphertext := gcm.Seal(nil, gcmNonce[:], plaintext, nil)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// CredentialKeys returns the names that serve as IDs for each of the login credentials
|
||||
func (a *Auth) CredentialKeys() iter.Seq[Name] {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
return maps.Keys(a.credentials)
|
||||
}
|
||||
|
||||
func (a *Auth) LoadCredential(name Name) (Credential, error) {
|
||||
a.mux.Lock()
|
||||
c, ok := a.credentials[name]
|
||||
a.mux.Unlock()
|
||||
if !ok {
|
||||
return c, ErrNotFound
|
||||
}
|
||||
|
||||
var err error
|
||||
if c.plain, err = a.maybeDecryptCredential(c); err != nil {
|
||||
return c, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (a *Auth) CacheCredential(c Credential) error {
|
||||
a.mux.Lock()
|
||||
a.credentials[c.Name] = c
|
||||
a.mux.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialKeys returns the names that serve as IDs for each of the login credentials
|
||||
func (a *Auth) ServiceAccountKeys() iter.Seq[Purpose] {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
return maps.Keys(a.serviceAccounts)
|
||||
}
|
||||
|
||||
func (a *Auth) LoadServiceAccount(purpose Purpose) (Credential, error) {
|
||||
a.mux.Lock()
|
||||
c, ok := a.serviceAccounts[purpose]
|
||||
a.mux.Unlock()
|
||||
if !ok {
|
||||
return c, ErrNotFound
|
||||
}
|
||||
|
||||
var err error
|
||||
if c.plain, err = a.maybeDecryptCredential(c); err != nil {
|
||||
return c, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (a *Auth) maybeDecryptCredential(c Credential) (string, error) {
|
||||
switch c.Params[0] {
|
||||
case "aes-128-gcm":
|
||||
var salt [12]byte
|
||||
copy(salt[:], c.Salt)
|
||||
return a.gcmDecrypt(a.aes128key, salt, c.Derived)
|
||||
default:
|
||||
break
|
||||
}
|
||||
|
||||
return c.plain, nil
|
||||
}
|
||||
|
||||
func (a *Auth) gcmDecrypt(aes128key [16]byte, gcmNonce [12]byte, derived []byte) (string, error) {
|
||||
block, err := aes.NewCipher(aes128key[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("new aes (decrypt) cipher failed: %v", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("new gcm failed: %v", err)
|
||||
}
|
||||
|
||||
plaintext, err := gcm.Open(nil, gcmNonce[:], derived, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("gcm open (decryption) failed: %v", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
func (a *Auth) CacheServiceAccount(c Credential) error {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
a.serviceAccounts[c.Purpose] = c
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify checks Basic Auth credentials
|
||||
func (a *Auth) Verify(name, secret string) error {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
c, ok := a.credentials[name]
|
||||
if !ok {
|
||||
return ErrNotFound
|
||||
}
|
||||
return c.Verify(name, secret)
|
||||
}
|
||||
|
||||
// Verify checks Basic Auth credentials
|
||||
func (c Credential) Verify(name, secret string) error {
|
||||
known := c.Derived
|
||||
var derived []byte
|
||||
switch c.Params[0] {
|
||||
case "aes-128-gcm":
|
||||
knownHash := sha256.Sum256([]byte(c.plain))
|
||||
known = knownHash[:]
|
||||
|
||||
h := sha256.Sum256([]byte(secret))
|
||||
derived = h[:]
|
||||
case "plain":
|
||||
h := sha256.Sum256([]byte(secret))
|
||||
derived = h[:]
|
||||
case "pbkdf2":
|
||||
// these are checked on load
|
||||
iters, _ := strconv.Atoi(c.Params[1])
|
||||
size, _ := strconv.Atoi(c.Params[2])
|
||||
var hasher func() hash.Hash
|
||||
switch c.Params[3] {
|
||||
case "SHA-1":
|
||||
hasher = sha1.New
|
||||
case "SHA-256":
|
||||
hasher = sha256.New
|
||||
default:
|
||||
panic(fmt.Errorf("invalid hash %q", c.Params[3]))
|
||||
}
|
||||
derived, _ = pbkdf2.Key(hasher, secret, c.Salt, iters, size)
|
||||
case "bcrypt":
|
||||
err := bcrypt.CompareHashAndPassword(c.Derived, []byte(secret))
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return ErrUnauthorized
|
||||
default:
|
||||
return ErrUnknownAlgorithm
|
||||
}
|
||||
|
||||
if bytes.Equal(known, derived) {
|
||||
return nil
|
||||
}
|
||||
return ErrUnauthorized
|
||||
}
|
145
auth/csvauth/csvauth_test.go
Normal file
145
auth/csvauth/csvauth_test.go
Normal file
@ -0,0 +1,145 @@
|
||||
package csvauth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCredentialCreationAndVerification(t *testing.T) {
|
||||
type testCase struct {
|
||||
purpose string
|
||||
name string
|
||||
params []string
|
||||
roles []string
|
||||
extra string
|
||||
isLogin bool
|
||||
isRecoverable bool
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{"service1", "acme", []string{"aes-128-gcm"}, nil, "token1", false, true},
|
||||
{"service2", "acme", []string{"plain"}, nil, "token2", false, true},
|
||||
{"service3", "user3", []string{"pbkdf2", "1000", "16", "SHA-256"}, nil, "token3", false, false},
|
||||
{"service4", "user4", []string{"bcrypt"}, []string{"audit", "triage"}, "token4", false, false},
|
||||
{"login", "user1", []string{"pbkdf2", "1000", "16", "SHA-256"}, nil, "pass1", true, false},
|
||||
{"login", "user2", []string{"bcrypt"}, nil, "pass2", true, false},
|
||||
{"login", "user3", []string{"aes-128-gcm"}, nil, "pass3", true, true},
|
||||
{"login", "user4", []string{"plain"}, nil, "pass4", true, true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(fmt.Sprintf("%s/%s", tc.purpose, tc.name), func(t *testing.T) {
|
||||
var key [16]byte
|
||||
a := &Auth{
|
||||
aes128key: key,
|
||||
credentials: make(map[Name]Credential),
|
||||
serviceAccounts: make(map[Purpose]Credential),
|
||||
}
|
||||
secret := tc.extra
|
||||
c := a.NewCredential(tc.purpose, tc.name, secret, tc.params, tc.roles, tc.extra)
|
||||
if c == nil {
|
||||
t.Fatal("NewCredential returned nil")
|
||||
}
|
||||
if tc.isLogin {
|
||||
_ = a.CacheCredential(*c)
|
||||
} else {
|
||||
_ = a.CacheServiceAccount(*c)
|
||||
}
|
||||
record := c.ToRecord()
|
||||
|
||||
// Verify record format
|
||||
if record[0] != tc.purpose {
|
||||
t.Errorf("purpose mismatch: got %q want %q", record[0], tc.purpose)
|
||||
}
|
||||
if record[1] != tc.name {
|
||||
t.Errorf("name mismatch: got %q want %q", record[1], tc.name)
|
||||
}
|
||||
if record[2] != strings.Join(tc.params, " ") {
|
||||
t.Errorf("params mismatch: got %q want %q", record[2], strings.Join(tc.params, " "))
|
||||
}
|
||||
salt64 := record[3]
|
||||
derived64 := record[4]
|
||||
algo := tc.params[0]
|
||||
switch algo {
|
||||
case "plain":
|
||||
if salt64 != "" {
|
||||
t.Errorf("plain salt should be empty, got %q", salt64)
|
||||
}
|
||||
if derived64 != secret {
|
||||
t.Errorf("plain derived mismatch: got %q want %q", derived64, secret)
|
||||
}
|
||||
case "aes-128-gcm":
|
||||
saltb, err := base64.RawURLEncoding.DecodeString(salt64)
|
||||
if err != nil || len(saltb) != 12 {
|
||||
t.Errorf("gcm salt invalid: len %d err %v", len(saltb), err)
|
||||
}
|
||||
derivedb, err := base64.RawURLEncoding.DecodeString(derived64)
|
||||
if err != nil {
|
||||
t.Errorf("gcm derived %q invalid: err %v", derivedb, err)
|
||||
}
|
||||
case "pbkdf2":
|
||||
saltb, err := base64.RawURLEncoding.DecodeString(salt64)
|
||||
if err != nil || len(saltb) != 16 {
|
||||
t.Errorf("pbkdf2 salt invalid: len %d err %v", len(saltb), err)
|
||||
}
|
||||
derivedb, err := base64.RawURLEncoding.DecodeString(derived64)
|
||||
if err != nil || len(derivedb) != 16 {
|
||||
t.Errorf("pbkdf2 derived invalid: len %d err %v", len(derivedb), err)
|
||||
}
|
||||
case "bcrypt":
|
||||
if salt64 != "" {
|
||||
t.Errorf("bcrypt salt should be empty, got %q", salt64)
|
||||
}
|
||||
if !strings.HasPrefix(derived64, "$2a$12$") {
|
||||
t.Errorf("bcrypt derived invalid: got %q", derived64)
|
||||
}
|
||||
}
|
||||
if len(tc.roles) > 0 && record[5] != strings.Join(tc.roles, " ") {
|
||||
t.Errorf("roles mismatch: got %q want %q", record[5], strings.Join(tc.roles, " "))
|
||||
}
|
||||
if len(tc.extra) > 0 && record[6] != tc.extra {
|
||||
t.Errorf("extra mismatch: got %q want %q", record[6], tc.extra)
|
||||
}
|
||||
|
||||
// Verify functionality
|
||||
var c2 Credential
|
||||
var err error
|
||||
if tc.isLogin {
|
||||
if err := a.Verify(tc.name, secret); err != nil {
|
||||
t.Errorf("Auth.Verify failed for %s %s with %s: %v", tc.purpose, tc.name, secret, err)
|
||||
}
|
||||
c2, err = a.LoadCredential(tc.name)
|
||||
if err != nil {
|
||||
t.Errorf("LoadCredential failed for %s %s: %v", tc.purpose, tc.name, err)
|
||||
}
|
||||
} else {
|
||||
c2, err = a.LoadServiceAccount(tc.purpose)
|
||||
if err != nil {
|
||||
t.Errorf("LoadServiceAccount failed for %s %s: %v", tc.purpose, tc.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.isRecoverable {
|
||||
if c2.Secret() != secret {
|
||||
t.Errorf("Secret mismatch: got %q want %q", c2.Secret(), secret)
|
||||
}
|
||||
} else {
|
||||
if c2.Secret() != "" {
|
||||
t.Errorf("Secret should be empty for hashed service account, got %q", c2.Secret())
|
||||
}
|
||||
}
|
||||
|
||||
if err := c2.Verify(tc.name, secret); err != nil {
|
||||
t.Errorf("Auth.Verify failed for %s: %v", tc.name, err)
|
||||
}
|
||||
if err := c2.Verify(tc.name, ""); err == nil {
|
||||
t.Errorf("Auth.Verify incorrectly passed an empty password for %s %s", tc.purpose, tc.name)
|
||||
}
|
||||
if err := c2.Verify(tc.name, "wrong"); err == nil {
|
||||
t.Errorf("Auth.Verify incorrectly passed a wrong password for %s %s", tc.purpose, tc.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
5
auth/csvauth/go.mod
Normal file
5
auth/csvauth/go.mod
Normal file
@ -0,0 +1,5 @@
|
||||
module github.com/therootcompany/golib/auth/csvauth
|
||||
|
||||
go 1.24.3
|
||||
|
||||
require golang.org/x/crypto v0.42.0
|
2
auth/csvauth/go.sum
Normal file
2
auth/csvauth/go.sum
Normal file
@ -0,0 +1,2 @@
|
||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
Loading…
x
Reference in New Issue
Block a user