ref(cmd/auth-proxy): consolidate generic token logic in auth package

This commit is contained in:
AJ ONeal 2026-03-03 03:08:27 -07:00
parent c32acd5a74
commit 4bda5b4580
No known key found for this signature in database
2 changed files with 23 additions and 48 deletions

View File

@ -4,10 +4,13 @@ go 1.25.0
require ( require (
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/therootcompany/golib/auth/csvauth v1.2.2 github.com/therootcompany/golib/auth v1.1.1
github.com/therootcompany/golib/auth/csvauth v1.2.4
) )
require ( require golang.org/x/crypto v0.42.0 // indirect
github.com/therootcompany/golib/auth v1.0.0 // indirect
golang.org/x/crypto v0.42.0 // indirect replace (
github.com/therootcompany/golib/auth => ../../auth
github.com/therootcompany/golib/auth/csvauth => ../../auth/csvauth
) )

View File

@ -67,8 +67,6 @@ var (
var creds *csvauth.Auth var creds *csvauth.Auth
const basicAPIKeyName = ""
type MainConfig struct { type MainConfig struct {
Address string Address string
Port int Port int
@ -76,6 +74,7 @@ type MainConfig struct {
ProxyTarget string ProxyTarget string
AES128KeyPath string AES128KeyPath string
ShowVersion bool ShowVersion bool
BasicRealm string
AuthorizationHeaderSchemes []string AuthorizationHeaderSchemes []string
TokenHeaderNames []string TokenHeaderNames []string
QueryParamNames []string QueryParamNames []string
@ -84,6 +83,7 @@ type MainConfig struct {
tokenSchemeList string tokenSchemeList string
tokenHeaderList string tokenHeaderList string
tokenParamList string tokenParamList string
ra *auth.BasicRequestAuthenticator
} }
func (c *MainConfig) Addr() string { func (c *MainConfig) Addr() string {
@ -102,6 +102,7 @@ func main() {
tokenSchemeList: "", tokenSchemeList: "",
tokenHeaderList: "", tokenHeaderList: "",
tokenParamList: "", tokenParamList: "",
BasicRealm: "Basic",
AuthorizationHeaderSchemes: nil, // []string{"Bearer", "Token"} AuthorizationHeaderSchemes: nil, // []string{"Bearer", "Token"}
TokenHeaderNames: nil, // []string{"X-API-Key", "X-Auth-Token", "X-Access-Token"}, TokenHeaderNames: nil, // []string{"X-API-Key", "X-Auth-Token", "X-Access-Token"},
QueryParamNames: nil, // []string{"access_token", "token"}, QueryParamNames: nil, // []string{"access_token", "token"},
@ -282,6 +283,14 @@ func run(cli *MainConfig) {
log.Fatalf("Failed to load CSV auth: %v", err) log.Fatalf("Failed to load CSV auth: %v", err)
} }
cli.ra = &auth.BasicRequestAuthenticator{
Authenticator: creds,
AuthorizationSchemes: cli.AuthorizationHeaderSchemes,
TokenHeaders: cli.TokenHeaderNames,
TokenQueryParams: cli.QueryParamNames,
BasicRealm: cli.BasicRealm,
}
var usableRoles int var usableRoles int
for key := range creds.CredentialKeys() { for key := range creds.CredentialKeys() {
u, err := creds.LoadCredential(key) u, err := creds.LoadCredential(key)
@ -391,8 +400,7 @@ func (cli *MainConfig) newAuthProxyHandler(targetURL string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !cli.authorize(r) { if !cli.authorize(r) {
// TODO allow --realm for `WWW-Authenticate: Basic realm="My Application"` w.Header().Set("WWW-Authenticate", cli.ra.BasicRealm)
w.Header().Set("WWW-Authenticate", `Basic`)
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }
@ -593,47 +601,11 @@ func matchPattern(grant, rMethod, rHost, rPath string) bool {
} }
func (cli *MainConfig) authenticate(r *http.Request) (auth.BasicPrinciple, error) { func (cli *MainConfig) authenticate(r *http.Request) (auth.BasicPrinciple, error) {
// 1. Try Basic Auth first (cleanest path) cred, err := cli.ra.Authenticate(r)
username, password, ok := r.BasicAuth() if errors.Is(err, auth.ErrNoCredentials) {
if ok {
// Authorization: Basic <Auth> exists
return creds.Authenticate(username, password)
}
// 2. Any Authorization: <scheme> <token>
if len(cli.AuthorizationHeaderSchemes) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 {
if cli.AuthorizationHeaderSchemes[0] == "*" ||
slices.Contains(cli.AuthorizationHeaderSchemes, parts[0]) {
token := strings.TrimSpace(parts[1])
// Authorization: <Scheme> <Token> exists
return creds.Authenticate(basicAPIKeyName, token)
}
}
return nil, errors.New("'Authorization' header is not properly formatted")
}
}
// 3. API-Key / X-API-Key headers
for _, h := range cli.TokenHeaderNames {
if key := r.Header.Get(h); key != "" {
// <TokenHeader>: <Token> exists
return creds.Authenticate(basicAPIKeyName, key)
}
}
// 4. access_token query param
for _, h := range cli.QueryParamNames {
if token := r.URL.Query().Get(h); token != "" {
// <query_param>?=<Token> exists
return creds.Authenticate(basicAPIKeyName, token)
}
}
return nil, ErrNoAuth return nil, ErrNoAuth
}
return cred, err
} }
// peekOption looks for a flag value without parsing the full set // peekOption looks for a flag value without parsing the full set