feat: add WWWAuthenticate field + Challenge(w) method to RequestAuthenticator

Co-authored-by: coolaj86 <122831+coolaj86@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2026-03-02 07:58:27 +00:00
parent 33960a5f2b
commit 9e7a50b443
3 changed files with 32 additions and 12 deletions

View File

@ -36,17 +36,32 @@ type RequestAuthenticator struct {
// TokenQueryParams lists query parameter names checked for tokens, // TokenQueryParams lists query parameter names checked for tokens,
// e.g. []string{"access_token", "token"}. // e.g. []string{"access_token", "token"}.
TokenQueryParams []string TokenQueryParams []string
// WWWAuthenticate is the value sent in the WWW-Authenticate response header
// when Challenge is called. An empty string disables the header.
// NewRequestAuthenticator sets this to "Basic".
WWWAuthenticate string
} }
// NewRequestAuthenticator returns a RequestAuthenticator with sane defaults: // NewRequestAuthenticator returns a RequestAuthenticator with sane defaults:
// Basic Auth enabled, Bearer/Token Authorization schemes, common API-key // Basic Auth enabled, Bearer/Token Authorization schemes, common API-key
// headers, and access_token/token query params. // headers, access_token/token query params, and WWW-Authenticate: Basic.
func NewRequestAuthenticator() *RequestAuthenticator { func NewRequestAuthenticator() *RequestAuthenticator {
return &RequestAuthenticator{ return &RequestAuthenticator{
AuthenticateBasic: true, AuthenticateBasic: true,
AuthorizationSchemes: []string{"Bearer", "Token"}, AuthorizationSchemes: []string{"Bearer", "Token"},
TokenHeaders: []string{"X-API-Key", "X-Auth-Token", "X-Access-Token"}, TokenHeaders: []string{"X-API-Key", "X-Auth-Token", "X-Access-Token"},
TokenQueryParams: []string{"access_token", "token"}, TokenQueryParams: []string{"access_token", "token"},
WWWAuthenticate: "Basic",
}
}
// Challenge sets the WWW-Authenticate response header to ra.WWWAuthenticate
// when it is non-empty. Call this before writing a 401 Unauthorized response
// so that clients know which auth scheme to use.
func (ra *RequestAuthenticator) Challenge(w http.ResponseWriter) {
if ra.WWWAuthenticate != "" {
w.Header().Set("WWW-Authenticate", ra.WWWAuthenticate)
} }
} }

View File

@ -76,6 +76,7 @@ type MainConfig struct {
ProxyTarget string ProxyTarget string
AES128KeyPath string AES128KeyPath string
ShowVersion bool ShowVersion bool
WWWAuthenticate string
AuthorizationHeaderSchemes []string AuthorizationHeaderSchemes []string
TokenHeaderNames []string TokenHeaderNames []string
QueryParamNames []string QueryParamNames []string
@ -84,6 +85,7 @@ type MainConfig struct {
tokenSchemeList string tokenSchemeList string
tokenHeaderList string tokenHeaderList string
tokenParamList string tokenParamList string
ra *auth.RequestAuthenticator
} }
func (c *MainConfig) Addr() string { func (c *MainConfig) Addr() string {
@ -102,6 +104,7 @@ func main() {
tokenSchemeList: "", tokenSchemeList: "",
tokenHeaderList: "", tokenHeaderList: "",
tokenParamList: "", tokenParamList: "",
WWWAuthenticate: "Basic",
AuthorizationHeaderSchemes: nil, // []string{"Bearer", "Token"} AuthorizationHeaderSchemes: nil, // []string{"Bearer", "Token"}
TokenHeaderNames: nil, // []string{"X-API-Key", "X-Auth-Token", "X-Access-Token"}, TokenHeaderNames: nil, // []string{"X-API-Key", "X-Auth-Token", "X-Access-Token"},
QueryParamNames: nil, // []string{"access_token", "token"}, QueryParamNames: nil, // []string{"access_token", "token"},
@ -282,6 +285,14 @@ func run(cli *MainConfig) {
log.Fatalf("Failed to load CSV auth: %v", err) log.Fatalf("Failed to load CSV auth: %v", err)
} }
cli.ra = &auth.RequestAuthenticator{
Authenticator: creds,
AuthorizationSchemes: cli.AuthorizationHeaderSchemes,
TokenHeaders: cli.TokenHeaderNames,
TokenQueryParams: cli.QueryParamNames,
WWWAuthenticate: cli.WWWAuthenticate,
}
var usableRoles int var usableRoles int
for key := range creds.CredentialKeys() { for key := range creds.CredentialKeys() {
u, err := creds.LoadCredential(key) u, err := creds.LoadCredential(key)
@ -391,8 +402,7 @@ func (cli *MainConfig) newAuthProxyHandler(targetURL string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !cli.authorize(r) { if !cli.authorize(r) {
// TODO allow --realm for `WWW-Authenticate: Basic realm="My Application"` cli.ra.Challenge(w)
w.Header().Set("WWW-Authenticate", `Basic`)
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }
@ -593,13 +603,7 @@ func matchPattern(grant, rMethod, rHost, rPath string) bool {
} }
func (cli *MainConfig) authenticate(r *http.Request) (auth.BasicPrinciple, error) { func (cli *MainConfig) authenticate(r *http.Request) (auth.BasicPrinciple, error) {
ra := auth.RequestAuthenticator{ cred, err := cli.ra.Authenticate(r)
Authenticator: creds,
AuthorizationSchemes: cli.AuthorizationHeaderSchemes,
TokenHeaders: cli.TokenHeaderNames,
TokenQueryParams: cli.QueryParamNames,
}
cred, err := ra.Authenticate(r)
if errors.Is(err, auth.ErrNoCredentials) { if errors.Is(err, auth.ErrNoCredentials) {
return nil, ErrNoAuth return nil, ErrNoAuth
} }

View File

@ -39,6 +39,7 @@ var smsRequestAuth = &auth.RequestAuthenticator{
AuthenticateBasic: true, AuthenticateBasic: true,
AuthorizationSchemes: []string{"*"}, AuthorizationSchemes: []string{"*"},
TokenHeaders: []string{"API-Key", "X-API-Key"}, TokenHeaders: []string{"API-Key", "X-API-Key"},
WWWAuthenticate: "Basic",
} }
func main() { func main() {
@ -145,13 +146,13 @@ func requireSMSPermission(permission string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if smsAuth == nil { if smsAuth == nil {
w.Header().Set("WWW-Authenticate", `Basic`) smsRequestAuth.Challenge(w)
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }
cred, err := smsRequestAuth.Authenticate(r) cred, err := smsRequestAuth.Authenticate(r)
if err != nil || !hasSMSPermission(cred.Permissions(), permission) { if err != nil || !hasSMSPermission(cred.Permissions(), permission) {
w.Header().Set("WWW-Authenticate", `Basic`) smsRequestAuth.Challenge(w)
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }