From 3f7985317fb3ff8e8367b539971ee16e0561f655 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Fri, 13 Mar 2026 10:28:47 -0600 Subject: [PATCH] ajwt: implement redesigned API from REDESIGN.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename ValidateParams → Validator, make Issuer immutable after construction. Key changes: - StandardClaims.GetStandardClaims() + StandardClaimsSource interface: any struct embedding StandardClaims satisfies the interface for free via Go's method promotion — zero boilerplate for callers - Issuer is now immutable after construction; keys and validator are unexported; Params field removed - New constructors: New, NewWithJWKs, NewWithOIDC, NewWithOAuth2 - UnsafeVerify(tokenStr string) (*JWS, error): Decode + sig verify + iss check; "unsafe" means exp/aud/etc. are NOT checked - VerifyAndValidate(tokenStr, claims, now): full pipeline requiring non-nil Validator; fails loudly with nil Validator - FetchJWKs(ctx, url), FetchJWKsFromOIDC(ctx, base), FetchJWKsFromOAuth2(ctx, base): standalone fetch functions with context - PublicJWK.Thumbprint(): RFC 7638 SHA-256 thumbprint, canonical field ordering per spec (EC: crv/kty/x/y, RSA: e/kty/n, OKP: crv/kty/x) - DecodePublicJWKsJSON: auto-populates KID from Thumbprint when absent - Tests: 14 pass, covering VerifyAndValidate, UnsafeVerify, nil-validator error, all alg round trips, tampered alg, Thumbprint, auto-KID --- auth/ajwt/jwt.go | 353 +++++++++++++++++++++++++----------------- auth/ajwt/jwt_test.go | 253 +++++++++++++++++++++--------- auth/ajwt/pub.go | 171 +++++++++++++++++++- 3 files changed, 554 insertions(+), 223 deletions(-) diff --git a/auth/ajwt/jwt.go b/auth/ajwt/jwt.go index 1f77d47..7349ffa 100644 --- a/auth/ajwt/jwt.go +++ b/auth/ajwt/jwt.go @@ -12,25 +12,33 @@ // - [JWS] is a parsed structure only — no Claims interface, no Verified flag. // - [Issuer] owns key management and signature verification, centralizing // the key lookup → sig verify → iss check sequence. -// - [ValidateParams] is a stable config object; time is passed at the call -// site so the same params can be reused across requests. +// - [Validator] is a stable config object; time is passed at the call site +// so the same validator can be reused across requests. +// - [StandardClaimsSource] is implemented for free by embedding [StandardClaims]. // - [JWS.UnmarshalClaims] accepts any type — no interface to implement. // - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384), // ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037). // -// Typical usage: +// Typical usage with VerifyAndValidate: // // // At startup: -// iss := ajwt.NewIssuer("https://accounts.example.com") -// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true} -// if err := iss.FetchKeys(ctx); err != nil { ... } +// iss, err := ajwt.NewWithOIDC(ctx, "https://accounts.example.com", +// &ajwt.Validator{Aud: "my-app", IgnoreIss: true}) // // // Per request: -// jws, err := ajwt.Decode(tokenStr) -// if err := iss.Verify(jws); err != nil { ... } // sig + iss check // var claims AppClaims -// if err := jws.UnmarshalClaims(&claims); err != nil { ... } -// if errs, err := iss.Params.Validate(claims.StandardClaims, time.Now()); err != nil { ... } +// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now()) +// if err != nil { /* hard error: bad sig, expired, etc. */ } +// if len(errs) > 0 { /* soft errors: wrong aud, missing amr, etc. */ } +// +// Typical usage with UnsafeVerify (custom validation only): +// +// iss := ajwt.New("https://example.com", keys, nil) +// jws, err := iss.UnsafeVerify(tokenStr) +// var claims AppClaims +// jws.UnmarshalClaims(&claims) +// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, +// ajwt.Validator{Aud: "myapp"}, time.Now()) package ajwt import ( @@ -48,7 +56,6 @@ import ( "encoding/json" "fmt" "math/big" - "net/http" "slices" "strings" "time" @@ -58,10 +65,10 @@ import ( // // It holds only the parsed structure — header, raw base64url fields, and // decoded signature bytes. It carries no Claims interface and no Verified flag; -// use [Issuer.Verify] to authenticate the token and [JWS.UnmarshalClaims] to -// decode the payload into a typed struct. +// use [Issuer.UnsafeVerify] or [Issuer.VerifyAndValidate] to authenticate the +// token and [JWS.UnmarshalClaims] to decode the payload into a typed struct. type JWS struct { - Protected string // base64url-encoded header + Protected string // base64url-encoded header Header StandardHeader Payload string // base64url-encoded claims Signature []byte @@ -77,12 +84,15 @@ type StandardHeader struct { // StandardClaims holds the registered JWT claim names defined in RFC 7519 // and extended by OpenID Connect Core. // -// Embed StandardClaims in your own claims struct: +// Embed StandardClaims in your own claims struct to satisfy [StandardClaimsSource] +// for free via Go's method promotion — zero boilerplate: // // type AppClaims struct { -// ajwt.StandardClaims +// ajwt.StandardClaims // promotes GetStandardClaims() // Email string `json:"email"` +// Roles []string `json:"roles"` // } +// // AppClaims now satisfies StandardClaimsSource automatically. type StandardClaims struct { Iss string `json:"iss"` Sub string `json:"sub"` @@ -96,10 +106,24 @@ type StandardClaims struct { Jti string `json:"jti"` } +// GetStandardClaims implements [StandardClaimsSource]. +// Any struct embedding StandardClaims gets this method for free via promotion. +func (sc StandardClaims) GetStandardClaims() StandardClaims { return sc } + +// StandardClaimsSource is implemented for free by any struct that embeds [StandardClaims]. +// +// type AppClaims struct { +// ajwt.StandardClaims // promotes GetStandardClaims() — zero boilerplate +// Email string `json:"email"` +// } +type StandardClaimsSource interface { + GetStandardClaims() StandardClaims +} + // Decode parses a compact JWT string (header.payload.signature) into a JWS. // // It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after -// [Issuer.Verify] to populate a typed claims struct. +// [Issuer.UnsafeVerify] to safely populate a typed claims struct. func Decode(tokenStr string) (*JWS, error) { parts := strings.Split(tokenStr, ".") if len(parts) != 3 { @@ -128,7 +152,7 @@ func Decode(tokenStr string) (*JWS, error) { // UnmarshalClaims decodes the JWT payload into v. // // v must be a pointer to a struct (e.g. *AppClaims). Always call -// [Issuer.Verify] before UnmarshalClaims to ensure the signature is +// [Issuer.UnsafeVerify] before UnmarshalClaims to ensure the signature is // authenticated before trusting the payload. func (jws *JWS) UnmarshalClaims(v any) error { payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) @@ -232,16 +256,16 @@ func (jws *JWS) Encode() string { return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature) } -// ValidateParams holds claim validation configuration. +// Validator holds claim validation configuration. // -// Configure once at startup; call [ValidateParams.Validate] per request, -// passing the current time. This keeps the config stable and makes the -// time dependency explicit at the call site. +// Configure once at startup; call [Validator.Validate] per request, passing +// the current time. This keeps the config stable and makes the time dependency +// explicit at the call site. // // https://openid.net/specs/openid-connect-core-1_0.html#IDToken -type ValidateParams struct { +type Validator struct { IgnoreIss bool - Iss string + Iss string // rarely needed — Issuer.UnsafeVerify already checks iss IgnoreSub bool Sub string IgnoreAud bool @@ -260,53 +284,53 @@ type ValidateParams struct { Azp string } -// Validate checks the standard JWT/OIDC claim fields against this config. +// Validate checks the standard JWT/OIDC claim fields in claims against this config. // // now is typically time.Now() — passing it explicitly keeps the config stable -// across requests and avoids hidden time dependencies in the params struct. -func (p ValidateParams) Validate(claims StandardClaims, now time.Time) ([]string, error) { - return ValidateStandardClaims(claims, p, now) +// across requests and avoids hidden time dependencies in the validator struct. +func (v Validator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) { + return ValidateStandardClaims(claims.GetStandardClaims(), v, now) } -// ValidateStandardClaims checks the registered JWT/OIDC claim fields against params. +// ValidateStandardClaims checks the registered JWT/OIDC claim fields against v. // -// Exported so callers can use it directly without a [ValidateParams] receiver: +// Exported so callers can use it directly without a [Validator] receiver: // -// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, params, time.Now()) -func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now time.Time) ([]string, error) { +// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, v, time.Now()) +func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ([]string, error) { var errs []string // Required to exist and match - if len(params.Iss) > 0 || !params.IgnoreIss { + if len(v.Iss) > 0 || !v.IgnoreIss { if len(claims.Iss) == 0 { errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)") - } else if claims.Iss != params.Iss { - errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss)) + } else if claims.Iss != v.Iss { + errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, v.Iss)) } } // Required to exist, optional match if len(claims.Sub) == 0 { - if !params.IgnoreSub { + if !v.IgnoreSub { errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)") } - } else if len(params.Sub) > 0 { - if params.Sub != claims.Sub { - errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, params.Sub)) + } else if len(v.Sub) > 0 { + if v.Sub != claims.Sub { + errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, v.Sub)) } } // Required to exist and match - if len(params.Aud) > 0 || !params.IgnoreAud { + if len(v.Aud) > 0 || !v.IgnoreAud { if len(claims.Aud) == 0 { errs = append(errs, "missing or malformed 'aud' (audience receiving token)") - } else if claims.Aud != params.Aud { - errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, params.Aud)) + } else if claims.Aud != v.Aud { + errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud)) } } // Required to exist and not be in the past - if !params.IgnoreExp { + if !v.IgnoreExp { if claims.Exp <= 0 { errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)") } else if claims.Exp < now.Unix() { @@ -317,7 +341,7 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti } // Required to exist and not be in the future - if !params.IgnoreIat { + if !v.IgnoreIat { if claims.Iat <= 0 { errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)") } else if claims.Iat > now.Unix() { @@ -328,53 +352,53 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti } // Should exist, in the past, with optional max age - if params.MaxAge > 0 || !params.IgnoreAuthTime { + if v.MaxAge > 0 || !v.IgnoreAuthTime { if claims.AuthTime == 0 { errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)") } else { authTime := time.Unix(claims.AuthTime, 0) authTimeStr := authTime.Format("2006-01-02 15:04:05 MST") age := now.Sub(authTime) - diff := age - params.MaxAge + diff := age - v.MaxAge if claims.AuthTime > now.Unix() { fromNow := time.Unix(claims.AuthTime, 0).Sub(now) errs = append(errs, fmt.Sprintf( "'auth_time' of %s is %s in the future (server time %s)", authTimeStr, formatDuration(fromNow), now.Format("2006-01-02 15:04:05 MST")), ) - } else if params.MaxAge > 0 && age > params.MaxAge { + } else if v.MaxAge > 0 && age > v.MaxAge { errs = append(errs, fmt.Sprintf( "'auth_time' of %s is %s old, exceeding max age %s by %s", - authTimeStr, formatDuration(age), formatDuration(params.MaxAge), formatDuration(diff)), + authTimeStr, formatDuration(age), formatDuration(v.MaxAge), formatDuration(diff)), ) } } } // Optional exact match - if params.Jti != claims.Jti { - if len(params.Jti) > 0 { - errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti)) - } else if !params.IgnoreJti { + if v.Jti != claims.Jti { + if len(v.Jti) > 0 { + errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti)) + } else if !v.IgnoreJti { errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti)) } } // Optional exact match - if params.Nonce != claims.Nonce { - if len(params.Nonce) > 0 { - errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, params.Nonce)) - } else if !params.IgnoreNonce { + if v.Nonce != claims.Nonce { + if len(v.Nonce) > 0 { + errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce)) + } else if !v.IgnoreNonce { errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce)) } } // Should exist, optional required-set check - if !params.IgnoreAmr { + if !v.IgnoreAmr { if len(claims.Amr) == 0 { errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)") - } else if len(params.RequiredAmrs) > 0 { - for _, required := range params.RequiredAmrs { + } else if len(v.RequiredAmrs) > 0 { + for _, required := range v.RequiredAmrs { if !slices.Contains(claims.Amr, required) { errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required)) } @@ -383,10 +407,10 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti } // Optional, match if present - if params.Azp != claims.Azp { - if len(params.Azp) > 0 { - errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, params.Azp)) - } else if !params.IgnoreAzp { + if v.Azp != claims.Azp { + if len(v.Azp) > 0 { + errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp)) + } else if !v.IgnoreAzp { errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp)) } } @@ -402,111 +426,154 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti return nil, nil } -// Issuer holds public keys and validation config for a trusted token issuer. +// Issuer holds public keys and optional validation config for a trusted token issuer. // -// [Issuer.FetchKeys] loads keys from the issuer's JWKS endpoint. -// [Issuer.SetKeys] injects keys directly (useful in tests). -// [Issuer.Verify] authenticates the token: key lookup → sig verify → iss check. +// Create with [New], [NewWithJWKs], [NewWithOIDC], or [NewWithOAuth2]. +// After construction, Issuer is immutable. // -// Typical setup: -// -// iss := ajwt.NewIssuer("https://accounts.example.com") -// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true} -// if err := iss.FetchKeys(ctx); err != nil { ... } +// [Issuer.UnsafeVerify] authenticates the token: Decode + key lookup + sig verify + iss check. +// [Issuer.VerifyAndValidate] additionally unmarshals claims and runs the Validator. type Issuer struct { - URL string - JWKsURL string // optional; defaults to URL + "/.well-known/jwks.json" - Params ValidateParams - keys map[string]crypto.PublicKey // kid → key + URL string // issuer URL for iss claim enforcement; empty skips the check + validator *Validator + keys map[string]crypto.PublicKey // kid → key } -// NewIssuer creates an Issuer for the given base URL. -func NewIssuer(url string) *Issuer { - return &Issuer{ - URL: url, - keys: make(map[string]crypto.PublicKey), - } -} - -// SetKeys stores public keys by their KID, replacing any previously stored keys. -// Useful for injecting keys in tests without an HTTP round-trip. -func (iss *Issuer) SetKeys(keys []PublicJWK) { +// New creates an Issuer with explicit keys. +// +// v is optional — pass nil to use [Issuer.UnsafeVerify] only. +// [Issuer.VerifyAndValidate] requires a non-nil Validator. +func New(issURL string, keys []PublicJWK, v *Validator) *Issuer { m := make(map[string]crypto.PublicKey, len(keys)) for _, k := range keys { m[k.KID] = k.Key } - iss.keys = m + return &Issuer{ + URL: issURL, + validator: v, + keys: m, + } } -// FetchKeys retrieves and stores the JWKS from the issuer's endpoint. -// If JWKsURL is empty, it defaults to URL + "/.well-known/jwks.json". -func (iss *Issuer) FetchKeys(ctx context.Context) error { - url := iss.JWKsURL - if url == "" { - url = strings.TrimRight(iss.URL, "/") + "/.well-known/jwks.json" - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return fmt.Errorf("fetch JWKS: %w", err) - } - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("fetch JWKS: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode) - } - - keys, err := DecodePublicJWKs(resp.Body) - if err != nil { - return fmt.Errorf("parse JWKS: %w", err) - } - - iss.SetKeys(keys) - return nil -} - -// Verify authenticates jws against this issuer: -// 1. Looks up the signing key by jws.Header.Kid. -// 2. Verifies the signature before trusting any payload data. -// 3. Checks that the token's "iss" claim matches iss.URL. +// NewWithJWKs creates an Issuer by fetching keys from jwksURL. // -// Call [JWS.UnmarshalClaims] after Verify to safely decode the payload into a -// typed struct, then [ValidateParams.Validate] to check claim values. -func (iss *Issuer) Verify(jws *JWS) error { +// The issuer URL (used for iss claim enforcement in [Issuer.UnsafeVerify]) is +// not set; use [New] or [NewWithOIDC]/[NewWithOAuth2] if you need iss enforcement. +// +// v is optional — pass nil to use [Issuer.UnsafeVerify] only. +func NewWithJWKs(ctx context.Context, jwksURL string, v *Validator) (*Issuer, error) { + keys, err := FetchJWKs(ctx, jwksURL) + if err != nil { + return nil, err + } + return New("", keys, v), nil +} + +// NewWithOIDC creates an Issuer using OIDC discovery. +// +// It fetches {baseURL}/.well-known/openid-configuration and reads the +// jwks_uri and issuer fields. The Issuer URL is set from the discovery +// document's issuer field (not baseURL) because OIDC requires them to match. +// +// v is optional — pass nil to use [Issuer.UnsafeVerify] only. +func NewWithOIDC(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) { + discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration" + keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL) + if err != nil { + return nil, err + } + return New(issURL, keys, v), nil +} + +// NewWithOAuth2 creates an Issuer using OAuth 2.0 authorization server metadata (RFC 8414). +// +// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the +// jwks_uri and issuer fields. The Issuer URL is set from the discovery +// document's issuer field. +// +// v is optional — pass nil to use [Issuer.UnsafeVerify] only. +func NewWithOAuth2(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) { + discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server" + keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL) + if err != nil { + return nil, err + } + return New(issURL, keys, v), nil +} + +// UnsafeVerify decodes tokenStr, verifies the signature, and (if [Issuer.URL] +// is set) checks the iss claim. +// +// "Unsafe" means exp, aud, and other claim values are NOT checked — the token +// is forgery-safe but not semantically validated. Callers are responsible for +// validating claim values, or use [Issuer.VerifyAndValidate]. +func (iss *Issuer) UnsafeVerify(tokenStr string) (*JWS, error) { + jws, err := Decode(tokenStr) + if err != nil { + return nil, err + } + if jws.Header.Kid == "" { - return fmt.Errorf("missing 'kid' header") + return nil, fmt.Errorf("missing 'kid' header") } key, ok := iss.keys[jws.Header.Kid] if !ok { - return fmt.Errorf("unknown kid: %q", jws.Header.Kid) + return nil, fmt.Errorf("unknown kid: %q", jws.Header.Kid) } signingInput := jws.Protected + "." + jws.Payload if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil { - return fmt.Errorf("signature verification failed: %w", err) + return nil, fmt.Errorf("signature verification failed: %w", err) } - // Signature verified — now safe to inspect the payload. - payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) + // Signature verified — now safe to inspect the payload for iss check. + if iss.URL != "" { + payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) + if err != nil { + return nil, fmt.Errorf("invalid claims encoding: %w", err) + } + var partial struct { + Iss string `json:"iss"` + } + if err := json.Unmarshal(payload, &partial); err != nil { + return nil, fmt.Errorf("invalid claims JSON: %w", err) + } + if partial.Iss != iss.URL { + return nil, fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL) + } + } + + return jws, nil +} + +// VerifyAndValidate verifies the token signature and iss, unmarshals the claims +// into claims, and runs the [Validator]. +// +// Returns a hard error (err != nil) for signature failures, decoding errors, +// and nil Validator. Returns soft errors (errs != nil) for claim validation +// failures (wrong aud, expired token, etc.). +// +// claims must be a pointer whose underlying type embeds [StandardClaims] (or +// otherwise implements [StandardClaimsSource]): +// +// var claims AppClaims +// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now()) +func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSource, now time.Time) (*JWS, []string, error) { + if iss.validator == nil { + return nil, nil, fmt.Errorf("VerifyAndValidate requires a non-nil Validator; use UnsafeVerify for signature-only verification") + } + + jws, err := iss.UnsafeVerify(tokenStr) if err != nil { - return fmt.Errorf("invalid claims encoding: %w", err) - } - var partial struct { - Iss string `json:"iss"` - } - if err := json.Unmarshal(payload, &partial); err != nil { - return fmt.Errorf("invalid claims JSON: %w", err) - } - if partial.Iss != iss.URL { - return fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL) + return nil, nil, err } - return nil + if err := jws.UnmarshalClaims(claims); err != nil { + return nil, nil, err + } + + errs, err := iss.validator.Validate(claims, now) + return jws, errs, err } // verifyWith checks a JWS signature using the given algorithm and public key. diff --git a/auth/ajwt/jwt_test.go b/auth/ajwt/jwt_test.go index e8319a2..f4c7cac 100644 --- a/auth/ajwt/jwt_test.go +++ b/auth/ajwt/jwt_test.go @@ -14,7 +14,9 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/rsa" + "encoding/base64" "fmt" + "strings" "testing" "time" @@ -23,9 +25,8 @@ import ( // AppClaims embeds StandardClaims and adds application-specific fields. // -// Unlike embeddedjwt and bestjwt, AppClaims does NOT implement a Validate -// interface — there is none. Validation is explicit: call -// ValidateStandardClaims or ValidateParams.Validate at the call site. +// Because StandardClaims is embedded, AppClaims satisfies StandardClaimsSource +// for free via Go's method promotion — no interface to implement. type AppClaims struct { ajwt.StandardClaims Email string `json:"email"` @@ -33,9 +34,10 @@ type AppClaims struct { } // validateAppClaims is a plain function — not a method satisfying an interface. -// Custom validation logic lives here, calling ValidateStandardClaims directly. -func validateAppClaims(c AppClaims, params ajwt.ValidateParams, now time.Time) ([]string, error) { - errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, params, now) +// It demonstrates the UnsafeVerify pattern: custom validation logic lives here, +// calling ValidateStandardClaims directly and adding app-specific checks. +func validateAppClaims(c AppClaims, v ajwt.Validator, now time.Time) ([]string, error) { + errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, v, now) if c.Email == "" { errs = append(errs, "missing email claim") } @@ -65,11 +67,11 @@ func goodClaims() AppClaims { } } -// goodParams configures the validator. Iss is omitted because Issuer.Verify -// already enforces the iss claim — no need to check it twice. -func goodParams() ajwt.ValidateParams { - return ajwt.ValidateParams{ - IgnoreIss: true, // Issuer.Verify handles iss +// goodValidator configures the validator. IgnoreIss is true because +// Issuer.UnsafeVerify already enforces the iss claim — no need to check twice. +func goodValidator() *ajwt.Validator { + return &ajwt.Validator{ + IgnoreIss: true, // UnsafeVerify handles iss Sub: "user123", Aud: "myapp", Jti: "abc123", @@ -80,16 +82,13 @@ func goodParams() ajwt.ValidateParams { } func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer { - iss := ajwt.NewIssuer("https://example.com") - iss.Params = goodParams() - iss.SetKeys([]ajwt.PublicJWK{pub}) - return iss + return ajwt.New("https://example.com", []ajwt.PublicJWK{pub}, goodValidator()) } // TestRoundTrip is the primary happy path using ES256. -// It demonstrates the full Issuer-based flow: +// It demonstrates the full VerifyAndValidate flow: // -// Decode → Issuer.Verify → UnmarshalClaims → Params.Validate +// New → VerifyAndValidate → custom claim access // // No Claims interface, no Verified flag, no type assertions on jws. func TestRoundTrip(t *testing.T) { @@ -115,20 +114,16 @@ func TestRoundTrip(t *testing.T) { iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"}) - jws2, err := ajwt.Decode(token) - if err != nil { - t.Fatal(err) - } - if err = iss.Verify(jws2); err != nil { - t.Fatalf("Verify failed: %v", err) - } - var decoded AppClaims - if err = jws2.UnmarshalClaims(&decoded); err != nil { - t.Fatal(err) + jws2, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now()) + if err != nil { + t.Fatalf("VerifyAndValidate failed: %v", err) } - if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil { - t.Fatalf("validation failed: %v", errs) + if len(errs) > 0 { + t.Fatalf("claim validation failed: %v", errs) + } + if jws2.Header.Alg != "ES256" { + t.Errorf("expected ES256 alg in jws, got %s", jws2.Header.Alg) } // Direct field access — no type assertion needed, no jws.Claims interface. if decoded.Email != claims.Email { @@ -160,20 +155,13 @@ func TestRoundTripRS256(t *testing.T) { iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"}) - jws2, err := ajwt.Decode(token) - if err != nil { - t.Fatal(err) - } - if err = iss.Verify(jws2); err != nil { - t.Fatalf("Verify failed: %v", err) - } - var decoded AppClaims - if err = jws2.UnmarshalClaims(&decoded); err != nil { - t.Fatal(err) + _, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now()) + if err != nil { + t.Fatalf("VerifyAndValidate failed: %v", err) } - if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil { - t.Fatalf("validation failed: %v", errs) + if len(errs) > 0 { + t.Fatalf("claim validation failed: %v", errs) } } @@ -201,25 +189,47 @@ func TestRoundTripEdDSA(t *testing.T) { iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"}) - jws2, err := ajwt.Decode(token) - if err != nil { - t.Fatal(err) - } - if err = iss.Verify(jws2); err != nil { - t.Fatalf("Verify failed: %v", err) - } - var decoded AppClaims - if err = jws2.UnmarshalClaims(&decoded); err != nil { - t.Fatal(err) + _, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now()) + if err != nil { + t.Fatalf("VerifyAndValidate failed: %v", err) } - if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil { - t.Fatalf("validation failed: %v", errs) + if len(errs) > 0 { + t.Fatalf("claim validation failed: %v", errs) } } -// TestCustomValidation demonstrates custom claim validation without any interface. -// The caller owns the validation logic and calls ValidateStandardClaims directly. +// TestUnsafeVerifyFlow demonstrates the UnsafeVerify + custom validation pattern. +// The caller owns the full validation pipeline. +func TestUnsafeVerifyFlow(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + claims := goodClaims() + jws, _ := ajwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(privKey) + token := jws.Encode() + + // Create issuer without validator — UnsafeVerify only. + iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil) + + jws2, err := iss.UnsafeVerify(token) + if err != nil { + t.Fatalf("UnsafeVerify failed: %v", err) + } + + var decoded AppClaims + if err := jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims failed: %v", err) + } + + errs, err := ajwt.ValidateStandardClaims(decoded.StandardClaims, *goodValidator(), time.Now()) + if err != nil { + t.Fatalf("ValidateStandardClaims failed: %v — errs: %v", err, errs) + } +} + +// TestCustomValidation demonstrates that ValidateStandardClaims is called +// explicitly and custom fields are validated without any Claims interface. func TestCustomValidation(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -231,12 +241,15 @@ func TestCustomValidation(t *testing.T) { token := jws.Encode() iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"}) - jws2, _ := ajwt.Decode(token) - _ = iss.Verify(jws2) + jws2, err := iss.UnsafeVerify(token) + if err != nil { + t.Fatalf("UnsafeVerify failed unexpectedly: %v", err) + } + var decoded AppClaims _ = jws2.UnmarshalClaims(&decoded) - errs, err := validateAppClaims(decoded, goodParams(), time.Now()) + errs, err := validateAppClaims(decoded, *goodValidator(), time.Now()) if err == nil { t.Fatal("expected validation to fail: email is empty") } @@ -251,6 +264,23 @@ func TestCustomValidation(t *testing.T) { } } +// TestVerifyAndValidateNilValidator confirms that VerifyAndValidate fails loudly +// when no Validator was provided at construction time. +func TestVerifyAndValidateNilValidator(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + c := goodClaims() + jws, _ := ajwt.NewJWSFromClaims(&c, "k") + _, _ = jws.Sign(privKey) + token := jws.Encode() + + iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil) + + var claims AppClaims + if _, _, err := iss.VerifyAndValidate(token, &claims, time.Now()); err == nil { + t.Fatal("expected VerifyAndValidate to error with nil validator") + } +} + // TestIssuerWrongKey confirms that a different key's public key is rejected. func TestIssuerWrongKey(t *testing.T) { signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -262,10 +292,9 @@ func TestIssuerWrongKey(t *testing.T) { token := jws.Encode() iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"}) - jws2, _ := ajwt.Decode(token) - if err := iss.Verify(jws2); err == nil { - t.Fatal("expected Verify to fail with wrong key") + if _, err := iss.UnsafeVerify(token); err == nil { + t.Fatal("expected UnsafeVerify to fail with wrong key") } } @@ -279,10 +308,9 @@ func TestIssuerUnknownKid(t *testing.T) { token := jws.Encode() iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"}) - jws2, _ := ajwt.Decode(token) - if err := iss.Verify(jws2); err == nil { - t.Fatal("expected Verify to fail for unknown kid") + if _, err := iss.UnsafeVerify(token); err == nil { + t.Fatal("expected UnsafeVerify to fail for unknown kid") } } @@ -299,14 +327,15 @@ func TestIssuerIssMismatch(t *testing.T) { // Issuer expects "https://example.com" but token says "https://evil.example.com" iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"}) - jws2, _ := ajwt.Decode(token) - if err := iss.Verify(jws2); err == nil { - t.Fatal("expected Verify to fail: iss mismatch") + if _, err := iss.UnsafeVerify(token); err == nil { + t.Fatal("expected UnsafeVerify to fail: iss mismatch") } } -// TestVerifyTamperedAlg confirms that a tampered alg header is rejected. +// TestVerifyTamperedAlg confirms that a tampered alg header ("none") is rejected. +// The token is reconstructed with a replaced protected header; the original +// ES256 signature is kept, making the signing input mismatch detectable. func TestVerifyTamperedAlg(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -316,11 +345,15 @@ func TestVerifyTamperedAlg(t *testing.T) { token := jws.Encode() iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"}) - jws2, _ := ajwt.Decode(token) - jws2.Header.Alg = "none" // tamper - if err := iss.Verify(jws2); err == nil { - t.Fatal("expected Verify to fail for tampered alg") + // Replace the protected header with one that has alg:"none". + // The original ES256 signature stays — the signing input will mismatch. + noneHeader := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","kid":"k","typ":"JWT"}`)) + parts := strings.SplitN(token, ".", 3) + tamperedToken := noneHeader + "." + parts[1] + "." + parts[2] + + if _, err := iss.UnsafeVerify(tamperedToken); err == nil { + t.Fatal("expected UnsafeVerify to fail for tampered alg") } } @@ -408,3 +441,79 @@ func TestDecodePublicJWKJSON(t *testing.T) { t.Errorf("expected 1 RSA key, got %d", rsaCount) } } + +// TestThumbprint verifies that Thumbprint returns a non-empty base64url string +// for EC, RSA, and Ed25519 keys, and that two equal keys produce the same thumbprint. +func TestThumbprint(t *testing.T) { + ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) + edPub, _, _ := ed25519.GenerateKey(rand.Reader) + + tests := []struct { + name string + jwk ajwt.PublicJWK + }{ + {"EC P-256", ajwt.PublicJWK{Key: &ecKey.PublicKey}}, + {"RSA 2048", ajwt.PublicJWK{Key: &rsaKey.PublicKey}}, + {"Ed25519", ajwt.PublicJWK{Key: edPub}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + thumb, err := tt.jwk.Thumbprint() + if err != nil { + t.Fatalf("Thumbprint() error: %v", err) + } + if thumb == "" { + t.Fatal("Thumbprint() returned empty string") + } + // Must be valid base64url (no padding, no +/) + if strings.Contains(thumb, "+") || strings.Contains(thumb, "/") || strings.Contains(thumb, "=") { + t.Errorf("Thumbprint() contains non-base64url characters: %s", thumb) + } + // Same key → same thumbprint + thumb2, _ := tt.jwk.Thumbprint() + if thumb != thumb2 { + t.Errorf("Thumbprint() not deterministic: %s != %s", thumb, thumb2) + } + }) + } +} + +// TestNoKidAutoThumbprint verifies that a JWKS key without a "kid" field gets +// its KID auto-populated from the RFC 7638 thumbprint. +func TestNoKidAutoThumbprint(t *testing.T) { + // EC key with no "kid" field in the JWKS + jwksJSON := []byte(`{"keys":[ + {"kty":"EC","crv":"P-256", + "x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", + "y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", + "use":"sig"} + ]}`) + + keys, err := ajwt.UnmarshalPublicJWKs(jwksJSON) + if err != nil { + t.Fatal(err) + } + if len(keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(keys)) + } + if keys[0].KID == "" { + t.Fatal("KID should be auto-populated from Thumbprint when absent in JWKS") + } + + // The auto-KID should be a valid base64url string. + kid := keys[0].KID + if strings.Contains(kid, "+") || strings.Contains(kid, "/") || strings.Contains(kid, "=") { + t.Errorf("auto-KID contains non-base64url characters: %s", kid) + } + + // Round-trip: compute Thumbprint directly and compare. + thumb, err := keys[0].Thumbprint() + if err != nil { + t.Fatalf("Thumbprint() error: %v", err) + } + if kid != thumb { + t.Errorf("auto-KID %q != direct Thumbprint %q", kid, thumb) + } +} diff --git a/auth/ajwt/pub.go b/auth/ajwt/pub.go index 0114556..ed3e473 100644 --- a/auth/ajwt/pub.go +++ b/auth/ajwt/pub.go @@ -9,11 +9,13 @@ package ajwt import ( + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rsa" + "crypto/sha256" "encoding/base64" "encoding/json" "fmt" @@ -21,6 +23,7 @@ import ( "math/big" "net/http" "os" + "strings" "time" ) @@ -54,6 +57,90 @@ func (k PublicJWK) EdDSA() (ed25519.PublicKey, bool) { return key, ok } +// Thumbprint computes the RFC 7638 JWK Thumbprint (SHA-256 of the canonical +// key JSON with fields in lexicographic order). The result is base64url-encoded. +// +// Canonical forms per RFC 7638: +// - EC: {"crv":…, "kty":"EC", "x":…, "y":…} +// - RSA: {"e":…, "kty":"RSA", "n":…} +// - OKP: {"crv":"Ed25519", "kty":"OKP", "x":…} +// +// Use Thumbprint as KID when none is provided in the JWKS source. +func (k PublicJWK) Thumbprint() (string, error) { + var canonical []byte + var err error + + switch key := k.Key.(type) { + case *ecdsa.PublicKey: + byteLen := (key.Curve.Params().BitSize + 7) / 8 + xBytes := make([]byte, byteLen) + yBytes := make([]byte, byteLen) + key.X.FillBytes(xBytes) + key.Y.FillBytes(yBytes) + + var crv string + switch key.Curve { + case elliptic.P256(): + crv = "P-256" + case elliptic.P384(): + crv = "P-384" + case elliptic.P521(): + crv = "P-521" + default: + return "", fmt.Errorf("Thumbprint: unsupported EC curve %s", key.Curve.Params().Name) + } + + // Fields in lexicographic order: crv, kty, x, y + canonical, err = json.Marshal(struct { + Crv string `json:"crv"` + Kty string `json:"kty"` + X string `json:"x"` + Y string `json:"y"` + }{ + Crv: crv, + Kty: "EC", + X: base64.RawURLEncoding.EncodeToString(xBytes), + Y: base64.RawURLEncoding.EncodeToString(yBytes), + }) + + case *rsa.PublicKey: + eInt := big.NewInt(int64(key.E)) + + // Fields in lexicographic order: e, kty, n + canonical, err = json.Marshal(struct { + E string `json:"e"` + Kty string `json:"kty"` + N string `json:"n"` + }{ + E: base64.RawURLEncoding.EncodeToString(eInt.Bytes()), + Kty: "RSA", + N: base64.RawURLEncoding.EncodeToString(key.N.Bytes()), + }) + + case ed25519.PublicKey: + // Fields in lexicographic order: crv, kty, x + canonical, err = json.Marshal(struct { + Crv string `json:"crv"` + Kty string `json:"kty"` + X string `json:"x"` + }{ + Crv: "Ed25519", + Kty: "OKP", + X: base64.RawURLEncoding.EncodeToString([]byte(key)), + }) + + default: + return "", fmt.Errorf("Thumbprint: unsupported key type %T", k.Key) + } + + if err != nil { + return "", fmt.Errorf("Thumbprint: marshal canonical JSON: %w", err) + } + + sum := sha256.Sum256(canonical) + return base64.RawURLEncoding.EncodeToString(sum[:]), nil +} + // PublicJWKJSON is the JSON representation of a single key in a JWKS document. type PublicJWKJSON struct { Kty string `json:"kty"` @@ -71,24 +158,83 @@ type JWKsJSON struct { Keys []PublicJWKJSON `json:"keys"` } -// FetchPublicJWKs retrieves and parses a JWKS document from url. +// FetchJWKs retrieves and parses a JWKS document from jwksURL. // -// For issuer-scoped key management with context support, use -// [Issuer.FetchKeys] instead. -func FetchPublicJWKs(url string) ([]PublicJWK, error) { - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Get(url) +// ctx is used for the HTTP request timeout and cancellation. +func FetchJWKs(ctx context.Context, jwksURL string) ([]PublicJWK, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil) if err != nil { - return nil, fmt.Errorf("failed to fetch JWKS: %w", err) + return nil, fmt.Errorf("fetch JWKS: %w", err) + } + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("fetch JWKS: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + return nil, fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode) } return DecodePublicJWKs(resp.Body) } +// FetchJWKsFromOIDC fetches JWKS via OIDC discovery from baseURL. +// +// It fetches {baseURL}/.well-known/openid-configuration and reads the jwks_uri field. +func FetchJWKsFromOIDC(ctx context.Context, baseURL string) ([]PublicJWK, error) { + discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration" + keys, _, err := fetchJWKsFromDiscovery(ctx, discoveryURL) + return keys, err +} + +// FetchJWKsFromOAuth2 fetches JWKS via OAuth 2.0 authorization server metadata (RFC 8414) +// from baseURL. +// +// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the jwks_uri field. +func FetchJWKsFromOAuth2(ctx context.Context, baseURL string) ([]PublicJWK, error) { + discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server" + keys, _, err := fetchJWKsFromDiscovery(ctx, discoveryURL) + return keys, err +} + +// fetchJWKsFromDiscovery fetches a discovery document from discoveryURL, then +// fetches the JWKS from the jwks_uri field. Returns the keys and the issuer +// URL from the discovery document's "issuer" field. +func fetchJWKsFromDiscovery(ctx context.Context, discoveryURL string) ([]PublicJWK, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil) + if err != nil { + return nil, "", fmt.Errorf("fetch discovery: %w", err) + } + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, "", fmt.Errorf("fetch discovery: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, "", fmt.Errorf("fetch discovery: unexpected status %d", resp.StatusCode) + } + + var doc struct { + Issuer string `json:"issuer"` + JWKsURI string `json:"jwks_uri"` + } + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + return nil, "", fmt.Errorf("parse discovery doc: %w", err) + } + if doc.JWKsURI == "" { + return nil, "", fmt.Errorf("discovery doc missing jwks_uri field") + } + + keys, err := FetchJWKs(ctx, doc.JWKsURI) + if err != nil { + return nil, "", err + } + return keys, doc.Issuer, nil +} + // ReadPublicJWKs reads and parses a JWKS document from a file path. func ReadPublicJWKs(filePath string) ([]PublicJWK, error) { file, err := os.Open(filePath) @@ -118,6 +264,9 @@ func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) { } // DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys. +// +// If a key has no kid field in the source document, the KID is auto-populated +// from [PublicJWK.Thumbprint] per RFC 7638. func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) { var keys []PublicJWK for _, jwk := range jwks.Keys { @@ -125,6 +274,12 @@ func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) { if err != nil { return nil, fmt.Errorf("failed to parse public jwk %q: %w", jwk.KID, err) } + if key.KID == "" { + key.KID, err = key.Thumbprint() + if err != nil { + return nil, fmt.Errorf("compute thumbprint for kid-less key: %w", err) + } + } keys = append(keys, *key) } if len(keys) == 0 {