From 13eeb6793b585df83abce4e97c37128af47cc54a Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Sat, 21 Feb 2026 17:48:17 -0700 Subject: [PATCH] feat: add cmd/auth-proxy to add Basic, Bearer, X-API-Key, or access_token auth to routes --- cmd/auth-proxy/go.mod | 13 + cmd/auth-proxy/go.sum | 12 + cmd/auth-proxy/main.go | 680 +++++++++++++++++++++++++++++++++ cmd/auth-proxy/pattern_test.go | 112 ++++++ 4 files changed, 817 insertions(+) create mode 100644 cmd/auth-proxy/go.mod create mode 100644 cmd/auth-proxy/go.sum create mode 100644 cmd/auth-proxy/main.go create mode 100644 cmd/auth-proxy/pattern_test.go diff --git a/cmd/auth-proxy/go.mod b/cmd/auth-proxy/go.mod new file mode 100644 index 0000000..11b8e93 --- /dev/null +++ b/cmd/auth-proxy/go.mod @@ -0,0 +1,13 @@ +module github.com/therootcompany/golib/cmd/auth-proxy + +go 1.25.0 + +require ( + github.com/joho/godotenv v1.5.1 + github.com/therootcompany/golib/auth/csvauth v1.2.2 +) + +require ( + github.com/therootcompany/golib/auth v1.0.0 // indirect + golang.org/x/crypto v0.42.0 // indirect +) diff --git a/cmd/auth-proxy/go.sum b/cmd/auth-proxy/go.sum new file mode 100644 index 0000000..713140d --- /dev/null +++ b/cmd/auth-proxy/go.sum @@ -0,0 +1,12 @@ +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/therootcompany/golib/auth v1.0.0 h1:17hfwcJO/Efc22/8RcCTKUD49mhCc5tyoHiKonA3Slg= +github.com/therootcompany/golib/auth v1.0.0/go.mod h1:DSw8llmDkMtvMZWrzrTRtcaLPpPMsT6Sg+qwGf5O2U8= +github.com/therootcompany/golib/auth/csvauth v1.2.0 h1:EgCT0pft4AhJz1omOntca5axv/B/1686wGIOgAS1++k= +github.com/therootcompany/golib/auth/csvauth v1.2.0/go.mod h1:e4kjRLWD7DWsuUT8+5MvPSYygXY6Is2UVeA9UpDllWU= +github.com/therootcompany/golib/auth/csvauth v1.2.1 h1:iy6S69nj/+Us0nx5pXke7SevgH1YaEM1vgbne7bnq70= +github.com/therootcompany/golib/auth/csvauth v1.2.1/go.mod h1:e4kjRLWD7DWsuUT8+5MvPSYygXY6Is2UVeA9UpDllWU= +github.com/therootcompany/golib/auth/csvauth v1.2.2 h1:t14Y8fSPhUZS6J+gRj0EFjkfE2JdIq93vOiKYcfItE8= +github.com/therootcompany/golib/auth/csvauth v1.2.2/go.mod h1://Iy5nbRa/bSndWa8b/NM4jclPVgkFWWhYUH6Ppelz0= +golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= +golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= diff --git a/cmd/auth-proxy/main.go b/cmd/auth-proxy/main.go new file mode 100644 index 0000000..59dafc8 --- /dev/null +++ b/cmd/auth-proxy/main.go @@ -0,0 +1,680 @@ +// auth-proxy - A reverse proxy to require Basic Auth, Bearer Token, or access_token +// +// Copyright 2026 AJ ONeal +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// This Source Code Form is "Incompatible With Secondary Licenses", as +// defined by the Mozilla Public License, v. 2.0. +// +// SPDX-License-Identifier: MPL-2.0 + +package main + +import ( + "context" + "encoding/hex" + "errors" + "flag" + "fmt" + "io" + "iter" + "log" + "net/http" + "net/http/httputil" + "net/url" + "os" + "os/signal" + "path/filepath" + "slices" + "strings" + "syscall" + "time" + "unicode/utf8" + + "github.com/joho/godotenv" + + "github.com/therootcompany/golib/auth" + "github.com/therootcompany/golib/auth/csvauth" +) + +const ( + name = "auth-proxy" + licenseYear = "2026" + licenseOwner = "AJ ONeal (https://therootcompany.com)" + licenseType = "CC0-1.0" +) + +// replaced by goreleaser / ldflags +var ( + version = "0.0.0-dev" + commit = "0000000" + date = "0001-01-01" +) + +// printVersion displays the version, commit, and build date. +func printVersion(w io.Writer) { + _, _ = fmt.Fprintf(w, "%s v%s %s (%s)\n", name, version, commit[:7], date) + _, _ = fmt.Fprintf(w, "Copyright (C) %s %s\n", licenseYear, licenseOwner) + _, _ = fmt.Fprintf(w, "Licensed under %s\n", licenseType) +} + +var ( + ErrNoAuth = errors.New("request missing the required form of authorization") +) + +var creds *csvauth.Auth + +const basicAPIKeyName = "" + +type MainConfig struct { + Address string + Port int + CredentialsPath string + ProxyTarget string + AES128KeyPath string + ShowVersion bool + AuthorizationHeaderSchemes []string + TokenHeaderNames []string + QueryParamNames []string + comma rune + commaString string + tokenSchemeList string + tokenHeaderList string + tokenParamList string +} + +func (c *MainConfig) Addr() string { + return fmt.Sprintf("%s:%d", c.Address, c.Port) +} + +func main() { + cli := MainConfig{ + Address: "0.0.0.0", + Port: 8081, + CredentialsPath: "./credentials.tsv", + ProxyTarget: "http://127.0.0.1:8080", + AES128KeyPath: filepath.Join("~", ".config", "csvauth", "aes-128.key"), + comma: '\t', + commaString: "", + tokenSchemeList: "", + tokenHeaderList: "", + tokenParamList: "", + AuthorizationHeaderSchemes: nil, // []string{"Bearer", "Token"} + TokenHeaderNames: nil, // []string{"X-API-Key", "X-Auth-Token", "X-Access-Token"}, + QueryParamNames: nil, // []string{"access_token", "token"}, + } + + // Peek for --envfile early + envPath := peekOption(os.Args[1:], []string{"-envfile", "--envfile"}, ".env") + _ = godotenv.Load(envPath) // silent if missing + + // Override defaults from env + if v := os.Getenv("AUTHPROXY_PORT"); v != "" { + if _, err := fmt.Sscanf(v, "%d", &cli.Port); err != nil { + fmt.Fprintf(os.Stderr, "invalid AUTHPROXY_PORT value: %s\n", v) + os.Exit(1) + } + } + if v := os.Getenv("AUTHPROXY_ADDRESS"); v != "" { + cli.Address = v + } + if v := os.Getenv("AUTHPROXY_CREDENTIALS_FILE"); v != "" { + cli.CredentialsPath = v + } + if v := os.Getenv("AUTHPROXY_TARGET"); v != "" { + cli.ProxyTarget = v + } + + // Flags + fs := flag.NewFlagSet(name, flag.ContinueOnError) + fs.BoolVar(&cli.ShowVersion, "version", false, "show version and exit") + fs.IntVar(&cli.Port, "port", cli.Port, "port to listen on") + fs.StringVar(&cli.Address, "address", cli.Address, "address to bind to (e.g. 127.0.0.1)") + fs.StringVar(&cli.AES128KeyPath, "aes-128-key", cli.AES128KeyPath, "path to credentials TSV/CSV file") + fs.StringVar(&cli.CredentialsPath, "credentials", cli.CredentialsPath, "path to credentials TSV/CSV file") + fs.StringVar(&cli.ProxyTarget, "proxy-target", cli.ProxyTarget, "upstream target to proxy requests to") + fs.StringVar(&cli.commaString, "comma", "\\t", "single-character CSV separator for credentials file (literal characters and escapes accepted)") + fs.StringVar(&cli.tokenSchemeList, "token-schemes", "Bearer,Token", "checks for header 'Authorization: '") + fs.StringVar(&cli.tokenHeaderList, "token-headers", "X-API-Key,X-Auth-Token,X-Access-Token", "checks for header ': '") + fs.StringVar(&cli.tokenParamList, "token-params", "access_token,token", "checks for query param '?='") + + fs.Usage = func() { + fmt.Fprintf(os.Stderr, "USAGE\n %s [flags]\n\n", name) + fmt.Fprintf(os.Stderr, "FLAGS\n") + fs.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nENVIRONMENT\n") + fmt.Fprintf(os.Stderr, " AUTHPROXY_PORT port to listen on\n") + fmt.Fprintf(os.Stderr, " AUTHPROXY_ADDRESS bind address\n") + fmt.Fprintf(os.Stderr, " AUTHPROXY_CREDENTIALS_FILE path to tokens file\n") + fmt.Fprintf(os.Stderr, " AUTHPROXY_TARGET upstream URL\n") + } + + // Special handling for version/help + if len(os.Args) > 1 { + arg := os.Args[1] + if arg == "-V" || arg == "--version" || arg == "version" { + printVersion(os.Stdout) + os.Exit(0) + } + if arg == "help" || arg == "-help" || arg == "--help" { + printVersion(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, "") + fs.SetOutput(os.Stdout) + fs.Usage() + os.Exit(0) + } + } + + printVersion(os.Stderr) + fmt.Fprintln(os.Stderr, "") + + if err := fs.Parse(os.Args[1:]); err != nil { + if err == flag.ErrHelp { + fs.Usage() + os.Exit(0) + } + log.Fatalf("flag parse error: %v", err) + } + + { + homedir, err := os.UserHomeDir() + if err == nil { + var found bool + if cli.AES128KeyPath, found = strings.CutPrefix(cli.AES128KeyPath, "~"); found { + cli.AES128KeyPath = homedir + cli.AES128KeyPath + } + } + } + + // credentials file delimiter + var err error + cli.comma, err = DecodeDelimiter(cli.commaString) + if err != nil { + log.Fatalf("comma parse error: %v", err) + } + + // Authorization: + cli.tokenSchemeList = strings.TrimSpace(cli.tokenSchemeList) + if cli.tokenSchemeList != "" && cli.tokenSchemeList != "none" { + cli.tokenSchemeList = strings.ReplaceAll(cli.tokenSchemeList, ",", " ") + cli.AuthorizationHeaderSchemes = strings.Fields(cli.tokenSchemeList) + if len(cli.AuthorizationHeaderSchemes) == 1 && cli.AuthorizationHeaderSchemes[0] == "" { + cli.AuthorizationHeaderSchemes = nil + } + } + + // : + // trick: this allows `Authorization: ` without the scheme + cli.tokenHeaderList = strings.TrimSpace(cli.tokenHeaderList) + if cli.tokenHeaderList != "" && cli.tokenHeaderList != "none" { + cli.tokenHeaderList = strings.ReplaceAll(cli.tokenHeaderList, ",", " ") + cli.TokenHeaderNames = strings.Fields(cli.tokenHeaderList) + if len(cli.TokenHeaderNames) == 1 && cli.TokenHeaderNames[0] == "" { + cli.TokenHeaderNames = nil + } + } + + // ?= + // trick: this allows `Authorization: ` without the scheme + cli.tokenParamList = strings.TrimSpace(cli.tokenParamList) + if cli.tokenParamList != "" && cli.tokenParamList != "none" { + cli.tokenParamList = strings.ReplaceAll(cli.tokenParamList, ",", " ") + cli.QueryParamNames = strings.Fields(cli.tokenParamList) + if len(cli.QueryParamNames) == 1 && cli.QueryParamNames[0] == "" { + cli.QueryParamNames = nil + } + } + + run(&cli) +} + +const ( + fileSeparator = '\x1c' + groupSeparator = '\x1d' + recordSeparator = '\x1e' + unitSeparator = '\x1f' +) + +func DecodeDelimiter(delimString string) (rune, error) { + switch delimString { + case "^_", "\\x1f": + delimString = string(unitSeparator) + case "^^", "\\x1e": + delimString = string(recordSeparator) + case "^]", "\\x1d": + delimString = string(groupSeparator) + case "^\\", "\\x1c": + delimString = string(fileSeparator) + case "^L", "\\f": + delimString = "\f" + case "^K", "\\v": + delimString = "\v" + case "^I", "\\t": + delimString = "\t" + default: + // it is what it is + } + delim, _ := utf8.DecodeRuneInString(delimString) + return delim, nil +} + +func run(cli *MainConfig) { + defaultAESKeyENVName := "CSVAUTH_AES_128_KEY" + aesKey, err := getAESKey(defaultAESKeyENVName, cli.AES128KeyPath) + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err) + os.Exit(1) + return + } + + // Load credentials from CSV/TSV file once at startup + f, err := os.Open(cli.CredentialsPath) + if err != nil { + log.Fatalf("Failed to open credentials file %q: %v", cli.CredentialsPath, err) + } + defer func() { _ = f.Close() }() + + creds = csvauth.New(aesKey) + if err := creds.LoadCSV(f, cli.comma); err != nil { + log.Fatalf("Failed to load CSV auth: %v", err) + } + + var usableRoles int + for key := range creds.CredentialKeys() { + u, err := creds.LoadCredential(key) + if err != nil { + log.Fatalf("Failed to read users from CSV auth: %v", err) + } + if len(u.Roles) == 0 { + continue + } + if usableRoles == 0 { + fmt.Fprintf(os.Stderr, "Current credentials, tokens, and permissions:\n") + fmt.Fprintf(os.Stderr, " %s\t%s\t%s\n", u.Purpose, u.ID(), strings.Join(u.Roles, " ")) + } + usableRoles += 1 + } + + var warnRoles bool + for key := range creds.CredentialKeys() { + u, _ := creds.LoadCredential(key) + if len(u.Roles) > 0 { + continue + } + if !warnRoles { + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "WARNING - Please Read\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "The following credentials cannot be used because they contain no Roles:\n") + warnRoles = true + } + fmt.Fprintf(os.Stderr, " %q (%s)\n", u.Name, u.Purpose) + } + if warnRoles { + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "Permission must be explicitly granted in the Roles column as a space-separated\n") + fmt.Fprintf(os.Stderr, "list of URI matchers in the form of \"[METHOD:][HOST]/[PATH]\"\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "CLI Examples\n") + fmt.Fprintf(os.Stderr, " authcsv store --roles '/' john.doe\n") + fmt.Fprintf(os.Stderr, " authcsv store --roles 'GET:example.com/mail POST:example.com/feed' --token openclaw\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "CSV Examples\n") + fmt.Fprintf(os.Stderr, " GET:example.com/ # GET-only access to example.com, for all paths\n") + fmt.Fprintf(os.Stderr, " / # Full access to everything\n") + fmt.Fprintf(os.Stderr, " GET:/ POST:/logs # GET anything, POST only to /logs/... \n") + fmt.Fprintf(os.Stderr, " ex1.com/ GET:ex2.net/logs # full access to ex1.com, GET-only for ex2.net\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "\n") + } + if usableRoles == 0 { + fmt.Fprintf(os.Stderr, "Error: no usable credentials found\n") + os.Exit(1) + } + + // Build proxy handler + handler := cli.newAuthProxyHandler(cli.ProxyTarget) + + // Server setup + srv := &http.Server{ + Addr: cli.Addr(), + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + } + + // Graceful shutdown + done := make(chan os.Signal, 1) + signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + go func() { + <-done + log.Println("Shutting down server...") + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + log.Printf("HTTP server shutdown error: %v", err) + } + }() + + log.Printf("Starting %s v%s on %s → %s", name, version, srv.Addr, cli.ProxyTarget) + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("HTTP server failed: %v", err) + } + + log.Println("Server stopped") +} + +func (cli *MainConfig) newAuthProxyHandler(targetURL string) http.Handler { + target, err := url.Parse(targetURL) + if err != nil { + log.Fatalf("invalid proxy target %q: %v", targetURL, err) + } + + proxy := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(target) + r.Out.Host = r.In.Host // preserve original Host header + // X-Forwarded-* headers are preserved from incoming request + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + log.Printf("proxy error: %v", err) + http.Error(w, "Bad Gateway", http.StatusBadGateway) + }, + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !cli.authorize(r) { + // TODO allow --realm for `WWW-Authenticate: Basic realm="My Application"` + w.Header().Set("WWW-Authenticate", `Basic`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + proxy.ServeHTTP(w, r) + }) +} + +func (cli *MainConfig) authorize(r *http.Request) bool { + cred, err := cli.authenticate(r) + if err != nil { + if !errors.Is(err, ErrNoAuth) { + return false + } + cred, err = creds.Authenticate("guest", "") + if err != nil { + return false + } + } + + grants := cred.Permissions() + if len(grants) == 0 { + // must have at least '/' + fmt.Fprintf(os.Stderr, "Warn: user %q correctly authenticated, but no --roles were specified (assign * or / for full access)\n", cred.ID()) + return false + } + + if grants[0] == "*" || grants[0] == "/" { + return true + } + + // GET,POST example.com/path/{$} + for _, grant := range grants { + if matchPattern(grant, r.Method, r.Host, r.URL.Path) { + return true + } + } + + return false +} + +// patternMatch returns true for a grant in the form of a ServeMux pattern matches the current request +// (though : is used instead of space since space is already used as a separator) +// +// grant is in the form `[METHOD:][HOST]/[PATH]`, but METHOD may be a comma-separated list GET,POST... +// rMthod must be ALL_CAPS +// rHost may be HOSTNAME or HOSTNAME:PORT +// rPath must start with / +func matchPattern(grant, rMethod, rHost, rPath string) bool { + // this should have been done already, but... + grant = strings.TrimSpace(grant) + + // must have at least / + // "" + if grant == "" { + fmt.Fprintf(os.Stderr, "DEBUG: missing grant\n") + return false + } + + // / => ["/"] + // /path => ["/path"] + // example.com/path => ["example.com/path"] + // GET:example.com/path => ["GET", "example.com/path"] + { + var methodSep = ":" + var methods []string + grantParts := strings.SplitN(grant, methodSep, 2) + switch len(grantParts) { + case 1: + // no method + case 2: + if len(grantParts) == 2 { + methods = strings.Split(strings.ToUpper(grantParts[0]), ",") + if !slices.Contains(methods, rMethod) { + // TODO maybe propagate method-not-allowed? + fmt.Fprintf(os.Stderr, "DEBUG: method %q != %q\n", rMethod, grantParts[0]) + return false + } + grant = grantParts[1] + } + default: + fmt.Fprintf(os.Stderr, "DEBUG: extraneous spaces in grant %q\n", grant) + return false + } + } + + // / => / + // /path => /path + // example.com/path => /path + idx := strings.Index(grant, "/") + if idx < 0 { + // host without path is invalid + fmt.Fprintf(os.Stderr, "DEBUG: missing leading / from grant %q\n", grant) + return false + } + hostname := grant[:idx] + if hostname != "" { + // example.com:443 => example.com + if h, _, _ := strings.Cut(rHost, ":"); hostname != h { + // hostname doesn't match + fmt.Fprintf(os.Stderr, "DEBUG: hostname %q != %q\n", rHost, hostname) + return false + } + } + grant = grant[idx:] + + // Prefix-only matching + // + // /path => /path + // /path/ => /path + // /path/{var}/bar => /path/{var}/bar + // /path/{var...} = /path/{var} + // /path/{$} => /path + // var exact bool + // grant, exact = strings.CutSuffix(grant, "/{$}") + // grant, _ = strings.CutSuffix(grant, "/") + // rPath, _ = strings.CutSuffix(rPath, "/") + // if len(grantPaths) > len(rPaths) { + // return false + // } else if len(grantPaths) < len(rPaths) { + // if exact { + // return false + // } + // } + + // // TODO replace with pattern matching as per https://pkg.go.dev/net/http#hdr-Patterns-ServeMux + // // /path/{var}/bar matches /path/foo/bar and /path/foo/bar/ + // // /path/{var...} matches /path/ and /path/foo/bar + // // /path/{var}/bar/{$} matches /path/foo/bar and /path/baz/bar/ but not /path/foo/ + // for i := 1; i < len(grantPaths); i++ { + // grantPath := grantPaths[i] + // rPath := rPaths[i] + // if strings.HasPrefix(grantPath, "{") { + // continue + // } + // if rPath != grantPath { + // return false + // } + // } + // return true + + // ServeMux pattern matching + nextGPath, gstop := iter.Pull(strings.SplitSeq(grant, "/")) + nextRPath, rstop := iter.Pull(strings.SplitSeq(rPath, "/")) + defer gstop() + defer rstop() + + for { + gp, gok := nextGPath() + rp, rok := nextRPath() + // everything has matched thus far, and the pattern has ended + if !gok { + return true + } + + // false unless the extra length of the pattern signifies the exact match, disregarding trailing / + if !rok { + // this matches trailing /, {var}, {var}/, {var...}, and {$} + if gp == "" || (strings.HasPrefix(gp, "{") && strings.HasSuffix(gp, "}")) { + gp2, more := nextGPath() + // this allows for one more final trailing /, but nothing else + if !more { + return true + } + if gp2 == "" { + // two trailing slashes are not allowed + _, more := nextGPath() + return !more + } + } + return false + } + + // path parts are only allowed to disagree for trailing slashes and variables + if gp != rp { + // this allows for one more final trailing / on the pattern, but nothing else + if gp == "" { + _, more := nextGPath() + return !more + } + // this allows for a placeholder in the pattern + if strings.HasPrefix(gp, "{") && strings.HasSuffix(gp, "}") { + // normal variables pass + if gp != "{$}" { + continue + } + // trailing slash on exact match passes + if rp == "" { + _, more := nextRPath() + return !more + } + fmt.Fprintf(os.Stderr, "DEBUG: path past {$} %q vs %q\n", rp, gp) + return false + } + fmt.Fprintf(os.Stderr, "DEBUG: path part %q != %q\n", rp, gp) + return false + } + } +} + +func (cli *MainConfig) authenticate(r *http.Request) (auth.BasicPrinciple, error) { + // 1. Try Basic Auth first (cleanest path) + username, password, ok := r.BasicAuth() + if ok { + // Authorization: Basic exists + return creds.Authenticate(username, password) + } + + // 2. Any Authorization: + if len(cli.AuthorizationHeaderSchemes) > 0 { + authHeader := r.Header.Get("Authorization") + if authHeader != "" { + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) == 2 { + if cli.AuthorizationHeaderSchemes[0] == "*" || + slices.Contains(cli.AuthorizationHeaderSchemes, parts[0]) { + token := strings.TrimSpace(parts[1]) + // Authorization: exists + return creds.Authenticate(basicAPIKeyName, token) + } + } + return nil, errors.New("'Authorization' header is not properly formatted") + } + } + + // 3. API-Key / X-API-Key headers + for _, h := range cli.TokenHeaderNames { + if key := r.Header.Get(h); key != "" { + // : exists + return creds.Authenticate(basicAPIKeyName, key) + } + } + + // 4. access_token query param + for _, h := range cli.QueryParamNames { + if token := r.URL.Query().Get(h); token != "" { + // ?= exists + return creds.Authenticate(basicAPIKeyName, token) + } + } + + return nil, ErrNoAuth +} + +// peekOption looks for a flag value without parsing the full set +func peekOption(args []string, names []string, def string) string { + for i := range len(args) { + for _, name := range names { + if args[i] == name { + if i+1 < len(args) { + return args[i+1] + } + } + } + } + return def +} + +// TODO expose this from csvauth +func getAESKey(envname, filename string) ([]byte, error) { + envKey := os.Getenv(envname) + if envKey != "" { + key, err := hex.DecodeString(strings.TrimSpace(envKey)) + if err != nil || len(key) != 16 { + return nil, fmt.Errorf("invalid %s: must be 32-char hex string", envname) + } + fmt.Fprintf(os.Stderr, "Found AES Key in %s\n", envname) + return key, nil + } + + if _, err := os.Stat(filename); err != nil { + return nil, err + } + + data, err := os.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %v", filename, err) + } + key, err := hex.DecodeString(strings.TrimSpace(string(data))) + if err != nil || len(key) != 16 { + return nil, fmt.Errorf("invalid key in %s: must be 32-char hex string", filename) + } + // relpath := strings.Replace(filename, homedir, "~", 1) + fmt.Fprintf(os.Stderr, "Found AES Key at %s\n", filename) + return key, nil +} diff --git a/cmd/auth-proxy/pattern_test.go b/cmd/auth-proxy/pattern_test.go new file mode 100644 index 0000000..a969d0e --- /dev/null +++ b/cmd/auth-proxy/pattern_test.go @@ -0,0 +1,112 @@ +package main + +import ( + "fmt" + "strings" + "testing" +) + +func TestMatchPattern(t *testing.T) { + tests := []struct { + grant string + method string + host string + path string + want bool + }{ + // Basic path matching + {"/", "GET", "example.com", "/", true}, + {"GET:/", "POST", "example.com", "/", false}, + {"/api/users", "GET", "api.example.com", "/api/users", true}, + {"/api/users", "GET", "api.example.com", "/api/users/", true}, + {"/api/users", "GET", "api.example.com", "/api/users", true}, + {"/api/users/", "GET", "", "/api/users", true}, + + // Host matching + {"example.com/", "GET", "example.com", "/", true}, + {"GET:example.com/", "GET", "example.com", "/", true}, + {"whatever.net/", "GET", "example.com", "/", false}, + {"example.comz/", "GET", "example.com", "/", false}, + {"example.com/", "GET", "example.comz", "/", false}, + {"aexample.com/", "GET", "example.com", "/", false}, + {"example.com/", "GET", "aexample.com", "/", false}, + {"example.com/", "GET", "api.example.com", "/", false}, + {"api.example.com/", "GET", "example.com", "/", false}, + {".example.com/", "GET", "api.example.com", "/", false}, + {"api.example.com/", "GET", "", "/", false}, + {"GET:api.example.com/", "GET", "example.com", "/", false}, + {"example.com/", "GET", "example.com:443", "/", true}, + {"GET:example.com/", "GET", "example.com:443", "/", true}, + + // Method lists + {"GET,POST,PUT:/api", "POST", "", "/api", true}, + {"GET,DELETE:/api", "POST", "", "/api", false}, + + // Wildcard / placeholder segments + // bad + {"/users/{id}", "GET", "", "/user", false}, + // good + {"/users/{id}", "GET", "", "/users", true}, + {"/users/{id}", "GET", "", "/users/", true}, + {"/users/{id}", "GET", "", "/users/123", true}, + {"/users/{id}", "GET", "", "/users/123/", true}, + {"/users/{id}", "GET", "", "/users/123/friends", true}, + // bad + {"/users/{id}/", "GET", "", "/user", false}, + // good + {"/users/{id}/", "GET", "", "/users", true}, + {"/users/{id}/", "GET", "", "/users/", true}, + {"/users/{id}/", "GET", "", "/users/123", true}, + {"/users/{id}/", "GET", "", "/users/123/", true}, + {"/users/{id}/", "GET", "", "/users/123/friends", true}, + // good (these are exactly the same as /path/{var} above, but added for completeness) + {"/users/{id...}", "GET", "", "/users", true}, + {"/users/{id...}", "GET", "", "/users/", true}, + {"/users/{id...}", "GET", "", "/users/123", true}, + {"/users/{id...}", "GET", "", "/users/123/", true}, + {"/users/{id...}", "GET", "", "/users/123/friends", true}, + // good + {"/users/{id}", "GET", "", "/users/123/bar", true}, + {"/users/{id}", "GET", "", "/users/123/bar/", true}, + {"/users/{id}", "GET", "", "/users/123/bar/456", true}, + {"/users/{id}", "GET", "", "/users/123/bar", true}, + {"/users/{id}", "GET", "", "/users/123/bar/", true}, + {"/users/{id}", "GET", "", "/users/123/bar/456", true}, + // good + {"/users/{id}/{$}", "GET", "", "/users/123", true}, + {"/users/{id}/{$}", "GET", "", "/users/123/", true}, + // wrong + {"/users/{id}/bar", "GET", "", "/users/123", false}, + {"/users/{id}/bar", "GET", "", "/users/123/", false}, + {"/users/{id}/bar", "GET", "", "/users/123/b", false}, + {"/users/{id}/bar", "GET", "", "/users/123/b/", false}, + {"/users/{id}/bar", "GET", "", "/users/123/b/456", false}, + {"/users/{id}/{$}", "GET", "", "/users/123/b", false}, + // {$} + {"/api/{ver}/items/{$}", "GET", "", "/api/v1/items", true}, + {"/api/{ver}/items/{$}", "GET", "", "/api/v1/items/", true}, + {"/api/{ver}/items/{$}", "GET", "", "/api/v1/items/42", false}, + {"/foo/{$}/baz", "GET", "", "/foo/bar/baz", false}, + + // Edge cases & invalid patterns + {"", "GET", "example.com", "/", false}, + {"", "GET", "", "/", false}, + {"GET", "GET", "example.com", "/", false}, + {"GET:", "GET", "example.com", "/", false}, + {"example.com", "GET", "example.com", "/", false}, + {"GET:example.com", "GET", "example.com", "/", false}, + {"GET:/users ", "GET", "", "/users", true}, + } + + for _, tt := range tests { + // name := "Pattern " + tt.grant + " vs URI " + strings.TrimSpace(fmt.Sprintf("%s %s%s", tt.method, tt.host, tt.path)) + name := tt.grant + " vs " + strings.TrimSpace(fmt.Sprintf("%s %s%s", tt.method, tt.host, tt.path)) + t.Run(name, func(t *testing.T) { + got := matchPattern(tt.grant, tt.method, tt.host, tt.path) + if got != tt.want { + t.Errorf("matchPattern(%q, %q, %q, %q) = %v, want %v", + tt.grant, tt.method, tt.host, tt.path, got, tt.want) + } + }) + } +}