diff --git a/auth/request.go b/auth/request.go index ed58af2..883af62 100644 --- a/auth/request.go +++ b/auth/request.go @@ -17,6 +17,10 @@ var ErrNoCredentials = errors.New("no credentials provided") // // Use NewRequestAuthenticator for sane defaults. type RequestAuthenticator struct { + // Authenticator is the credential verifier called with the extracted + // username/password or token. Must be set before calling Authenticate. + Authenticator BasicAuthenticator + // AuthenticateBasic enables HTTP Basic Auth (Authorization: Basic …). AuthenticateBasic bool @@ -53,7 +57,9 @@ func NewRequestAuthenticator() *RequestAuthenticator { // 4. Query parameters (TokenQueryParams) // // Returns ErrNoCredentials if no credential form is present in the request. -func (ra *RequestAuthenticator) Authenticate(r *http.Request, a BasicAuthenticator) (BasicPrinciple, error) { +func (ra *RequestAuthenticator) Authenticate(r *http.Request) (BasicPrinciple, error) { + a := ra.Authenticator + // 1. Basic Auth if ra.AuthenticateBasic { if username, password, ok := r.BasicAuth(); ok { diff --git a/cmd/auth-proxy/main.go b/cmd/auth-proxy/main.go index ac5d7df..4d1567a 100644 --- a/cmd/auth-proxy/main.go +++ b/cmd/auth-proxy/main.go @@ -594,11 +594,12 @@ func matchPattern(grant, rMethod, rHost, rPath string) bool { func (cli *MainConfig) authenticate(r *http.Request) (auth.BasicPrinciple, error) { ra := auth.RequestAuthenticator{ + Authenticator: creds, AuthorizationSchemes: cli.AuthorizationHeaderSchemes, TokenHeaders: cli.TokenHeaderNames, TokenQueryParams: cli.QueryParamNames, } - cred, err := ra.Authenticate(r, creds) + cred, err := ra.Authenticate(r) if errors.Is(err, auth.ErrNoCredentials) { return nil, ErrNoAuth } diff --git a/cmd/smsapid/main.go b/cmd/smsapid/main.go index 0885b11..9c36505 100644 --- a/cmd/smsapid/main.go +++ b/cmd/smsapid/main.go @@ -102,6 +102,7 @@ func main() { if err := smsAuth.LoadCSV(f, '\t'); err != nil { log.Fatalf("failed to load credentials from %q: %v", credPath, err) } + smsRequestAuth.Authenticator = smsAuth } else { log.Printf("Warning: credentials file %q not found; /api/smsgw routes will return 401: %v", credPath, err) } @@ -148,7 +149,7 @@ func requireSMSPermission(permission string) func(http.Handler) http.Handler { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } - cred, err := smsRequestAuth.Authenticate(r, smsAuth) + cred, err := smsRequestAuth.Authenticate(r) if err != nil || !hasSMSPermission(cred.Permissions(), permission) { w.Header().Set("WWW-Authenticate", `Basic`) http.Error(w, "Unauthorized", http.StatusUnauthorized)