diff --git a/auth/ajwt/fetcher.go b/auth/ajwt/fetcher.go new file mode 100644 index 0000000..23f0987 --- /dev/null +++ b/auth/ajwt/fetcher.go @@ -0,0 +1,122 @@ +// Copyright 2025 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +package ajwt + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" +) + +// cachedIssuer bundles an [*Issuer] with its freshness window. +// Stored atomically in [JWKsFetcher]; immutable after creation. +type cachedIssuer struct { + iss *Issuer + fetchedAt time.Time + expiresAt time.Time // fetchedAt + MaxAge; fresh until this point +} + +// JWKsFetcher lazily fetches and caches JWKS keys from a remote URL, +// returning a fresh [*Issuer] on demand. +// +// Each call to [JWKsFetcher.Issuer] checks freshness and either returns the +// cached Issuer immediately or fetches a new one. There is no background +// goroutine — refresh only happens when a caller requests an Issuer. +// +// Fields must be set before the first call to [JWKsFetcher.Issuer]; do not +// modify them concurrently. +// +// Typical usage: +// +// fetcher := &ajwt.JWKsFetcher{ +// URL: "https://accounts.example.com/.well-known/jwks.json", +// MaxAge: time.Hour, +// StaleAge: 30 * time.Minute, +// KeepOnError: true, +// } +// iss, err := fetcher.Issuer(ctx) +type JWKsFetcher struct { + // URL is the JWKS endpoint to fetch keys from. + URL string + + // MaxAge is how long fetched keys are considered fresh. After MaxAge, + // the next call to Issuer triggers a refresh. Defaults to 1 hour. + MaxAge time.Duration + + // StaleAge is additional time beyond MaxAge during which the old Issuer + // may be returned when a refresh fails. For example, MaxAge=1h and + // StaleAge=30m means keys will be served up to 90 minutes after the last + // successful fetch, if KeepOnError is true and fetches keep failing. + // Defaults to 0 (no stale window). + StaleAge time.Duration + + // KeepOnError causes the previous Issuer to be returned (with an error) + // when a refresh fails, as long as the result is within the stale window + // (expiresAt + StaleAge). If false, any fetch error after MaxAge returns + // (nil, err). + KeepOnError bool + + // RespectHeaders is reserved for future use (honor Cache-Control max-age + // from the JWKS response, capped at MaxAge). + RespectHeaders bool + + mu sync.Mutex + cached atomic.Pointer[cachedIssuer] +} + +// Issuer returns a current [*Issuer] for verifying tokens. +// +// If the cached Issuer is still fresh (within MaxAge), it is returned without +// a network call. If it has expired, a new fetch is performed. On fetch +// failure with KeepOnError=true and within StaleAge, the old Issuer is +// returned alongside a non-nil error; callers may choose to accept it. +func (f *JWKsFetcher) Issuer(ctx context.Context) (*Issuer, error) { + now := time.Now() + + // Fast path: check cached value without locking. + if ci := f.cached.Load(); ci != nil && now.Before(ci.expiresAt) { + return ci.iss, nil + } + + // Slow path: refresh needed. Serialize to avoid stampeding. + f.mu.Lock() + defer f.mu.Unlock() + + // Re-check after acquiring lock — another goroutine may have refreshed. + if ci := f.cached.Load(); ci != nil && now.Before(ci.expiresAt) { + return ci.iss, nil + } + + keys, err := FetchJWKs(ctx, f.URL) + if err != nil { + // On error, serve stale keys within the stale window. + if ci := f.cached.Load(); ci != nil && f.KeepOnError { + staleDeadline := ci.expiresAt.Add(f.StaleAge) + if now.Before(staleDeadline) { + return ci.iss, fmt.Errorf("JWKS refresh failed (serving cached keys): %w", err) + } + } + return nil, fmt.Errorf("fetch JWKS from %s: %w", f.URL, err) + } + + maxAge := f.MaxAge + if maxAge <= 0 { + maxAge = time.Hour + } + + ci := &cachedIssuer{ + iss: New(keys), + fetchedAt: now, + expiresAt: now.Add(maxAge), + } + f.cached.Store(ci) + return ci.iss, nil +} diff --git a/auth/ajwt/jwt.go b/auth/ajwt/jwt.go index 7349ffa..6037821 100644 --- a/auth/ajwt/jwt.go +++ b/auth/ajwt/jwt.go @@ -9,40 +9,52 @@ // Package ajwt is a lightweight JWT/JWS/JWK library designed from first // principles: // -// - [JWS] is a parsed structure only — no Claims interface, no Verified flag. -// - [Issuer] owns key management and signature verification, centralizing -// the key lookup → sig verify → iss check sequence. -// - [Validator] is a stable config object; time is passed at the call site -// so the same validator can be reused across requests. -// - [StandardClaimsSource] is implemented for free by embedding [StandardClaims]. -// - [JWS.UnmarshalClaims] accepts any type — no interface to implement. -// - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384), -// ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037). +// - [Issuer] is immutable — constructed with a fixed key set, safe for concurrent use. +// - [Signer] manages private keys and returns [*Issuer] for verification. +// - [JWKsFetcher] lazily fetches and caches JWKS keys, returning a fresh [*Issuer] on demand. +// - [Validator] and [MultiValidator] validate standard JWT/OIDC claims. +// - [JWS] is a parsed structure — use [Issuer.Verify] or [Issuer.UnsafeVerify] to authenticate. +// - [JWS.UnmarshalClaims] accepts any type — no Claims interface to implement. +// - [StandardClaimsSource] is satisfied for free by embedding [StandardClaims]. // // Typical usage with VerifyAndValidate: // // // At startup: -// iss, err := ajwt.NewWithOIDC(ctx, "https://accounts.example.com", -// &ajwt.Validator{Aud: "my-app", IgnoreIss: true}) +// signer, err := ajwt.NewSigner([]ajwt.NamedSigner{{Signer: privKey}}) +// iss := signer.Issuer() +// v := &ajwt.Validator{Iss: "https://example.com", Aud: "my-app"} +// +// // Sign a token: +// tokenStr, err := signer.Sign(claims) // // // Per request: // var claims AppClaims -// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now()) -// if err != nil { /* hard error: bad sig, expired, etc. */ } -// if len(errs) > 0 { /* soft errors: wrong aud, missing amr, etc. */ } +// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, v, time.Now()) +// if err != nil { /* hard error: bad sig, malformed token */ } +// if len(errs) > 0 { /* soft errors: wrong aud, expired, etc. */ } // -// Typical usage with UnsafeVerify (custom validation only): +// Typical usage with UnsafeVerify (custom validation): // -// iss := ajwt.New("https://example.com", keys, nil) +// iss := ajwt.New(keys) // jws, err := iss.UnsafeVerify(tokenStr) // var claims AppClaims // jws.UnmarshalClaims(&claims) // errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, // ajwt.Validator{Aud: "myapp"}, time.Now()) +// +// Typical usage with JWKsFetcher (dynamic keys from remote): +// +// fetcher := &ajwt.JWKsFetcher{ +// URL: "https://accounts.example.com/.well-known/jwks.json", +// MaxAge: time.Hour, +// StaleAge: time.Hour, +// KeepOnError: true, +// } +// iss, err := fetcher.Issuer(ctx) +// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, v, time.Now()) package ajwt import ( - "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -65,8 +77,8 @@ import ( // // It holds only the parsed structure — header, raw base64url fields, and // decoded signature bytes. It carries no Claims interface and no Verified flag; -// use [Issuer.UnsafeVerify] or [Issuer.VerifyAndValidate] to authenticate the -// token and [JWS.UnmarshalClaims] to decode the payload into a typed struct. +// use [Issuer.Verify] or [Issuer.UnsafeVerify] to authenticate the token and +// [JWS.UnmarshalClaims] to decode the payload into a typed struct. type JWS struct { Protected string // base64url-encoded header Header StandardHeader @@ -120,10 +132,16 @@ type StandardClaimsSource interface { GetStandardClaims() StandardClaims } +// ClaimsValidator validates the standard JWT/OIDC claims in a token. +// Implemented by [*Validator] and [*MultiValidator]. +type ClaimsValidator interface { + Validate(claims StandardClaimsSource, now time.Time) ([]string, error) +} + // Decode parses a compact JWT string (header.payload.signature) into a JWS. // // It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after -// [Issuer.UnsafeVerify] to safely populate a typed claims struct. +// [Issuer.Verify] or [Issuer.UnsafeVerify] to populate a typed claims struct. func Decode(tokenStr string) (*JWS, error) { parts := strings.Split(tokenStr, ".") if len(parts) != 3 { @@ -152,8 +170,8 @@ func Decode(tokenStr string) (*JWS, error) { // UnmarshalClaims decodes the JWT payload into v. // // v must be a pointer to a struct (e.g. *AppClaims). Always call -// [Issuer.UnsafeVerify] before UnmarshalClaims to ensure the signature is -// authenticated before trusting the payload. +// [Issuer.Verify] or [Issuer.UnsafeVerify] before UnmarshalClaims to ensure +// the signature is authenticated before trusting the payload. func (jws *JWS) UnmarshalClaims(v any) error { payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) if err != nil { @@ -256,16 +274,15 @@ func (jws *JWS) Encode() string { return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature) } -// Validator holds claim validation configuration. +// Validator holds claim validation configuration for single-tenant use. // -// Configure once at startup; call [Validator.Validate] per request, passing -// the current time. This keeps the config stable and makes the time dependency -// explicit at the call site. +// Configure once at startup; pass to [Issuer.VerifyAndValidate] or call +// [Validator.Validate] directly per request. // // https://openid.net/specs/openid-connect-core-1_0.html#IDToken type Validator struct { IgnoreIss bool - Iss string // rarely needed — Issuer.UnsafeVerify already checks iss + Iss string IgnoreSub bool Sub string IgnoreAud bool @@ -284,12 +301,104 @@ type Validator struct { Azp string } -// Validate checks the standard JWT/OIDC claim fields in claims against this config. -// -// now is typically time.Now() — passing it explicitly keeps the config stable -// across requests and avoids hidden time dependencies in the validator struct. -func (v Validator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) { - return ValidateStandardClaims(claims.GetStandardClaims(), v, now) +// Validate implements [ClaimsValidator]. +func (v *Validator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) { + return ValidateStandardClaims(claims.GetStandardClaims(), *v, now) +} + +// MultiValidator holds claim validation configuration for multi-tenant use. +// Iss, Aud, and Azp accept slices — the claim value must appear in the slice. +type MultiValidator struct { + Iss []string + IgnoreIss bool + IgnoreSub bool + Aud []string + IgnoreAud bool + IgnoreExp bool + IgnoreIat bool + IgnoreAuthTime bool + MaxAge time.Duration + IgnoreNonce bool + IgnoreAmr bool + RequiredAmrs []string + IgnoreAzp bool + Azp []string + IgnoreJti bool +} + +// Validate implements [ClaimsValidator]. +func (v *MultiValidator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) { + sc := claims.GetStandardClaims() + var errs []string + + if !v.IgnoreIss { + if sc.Iss == "" { + errs = append(errs, "missing or malformed 'iss' (token issuer)") + } else if len(v.Iss) > 0 && !slices.Contains(v.Iss, sc.Iss) { + errs = append(errs, fmt.Sprintf("'iss' %q not in allowed list", sc.Iss)) + } + } + + if !v.IgnoreAud { + if sc.Aud == "" { + errs = append(errs, "missing or malformed 'aud' (audience)") + } else if len(v.Aud) > 0 && !slices.Contains(v.Aud, sc.Aud) { + errs = append(errs, fmt.Sprintf("'aud' %q not in allowed list", sc.Aud)) + } + } + + if !v.IgnoreExp { + if sc.Exp <= 0 { + errs = append(errs, "missing or malformed 'exp' (expiration)") + } else if sc.Exp < now.Unix() { + duration := now.Sub(time.Unix(sc.Exp, 0)) + errs = append(errs, fmt.Sprintf("token expired %s ago", formatDuration(duration))) + } + } + + if !v.IgnoreIat { + if sc.Iat <= 0 { + errs = append(errs, "missing or malformed 'iat' (issued at)") + } else if sc.Iat > now.Unix() { + errs = append(errs, "'iat' is in the future") + } + } + + if v.MaxAge > 0 || !v.IgnoreAuthTime { + if sc.AuthTime == 0 { + errs = append(errs, "missing or malformed 'auth_time'") + } else if sc.AuthTime > now.Unix() { + errs = append(errs, "'auth_time' is in the future") + } else if v.MaxAge > 0 { + age := now.Sub(time.Unix(sc.AuthTime, 0)) + if age > v.MaxAge { + errs = append(errs, fmt.Sprintf("'auth_time' exceeds max age %s by %s", formatDuration(v.MaxAge), formatDuration(age-v.MaxAge))) + } + } + } + + if !v.IgnoreAmr { + if len(sc.Amr) == 0 { + errs = append(errs, "missing or malformed 'amr'") + } else { + for _, req := range v.RequiredAmrs { + if !slices.Contains(sc.Amr, req) { + errs = append(errs, fmt.Sprintf("missing required %q from 'amr'", req)) + } + } + } + } + + if !v.IgnoreAzp { + if len(v.Azp) > 0 && !slices.Contains(v.Azp, sc.Azp) { + errs = append(errs, fmt.Sprintf("'azp' %q not in allowed list", sc.Azp)) + } + } + + if len(errs) > 0 { + return errs, fmt.Errorf("has errors") + } + return nil, nil } // ValidateStandardClaims checks the registered JWT/OIDC claim fields against v. @@ -304,7 +413,7 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ( if len(v.Iss) > 0 || !v.IgnoreIss { if len(claims.Iss) == 0 { errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)") - } else if claims.Iss != v.Iss { + } else if len(v.Iss) > 0 && claims.Iss != v.Iss { errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, v.Iss)) } } @@ -324,7 +433,7 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ( if len(v.Aud) > 0 || !v.IgnoreAud { if len(claims.Aud) == 0 { errs = append(errs, "missing or malformed 'aud' (audience receiving token)") - } else if claims.Aud != v.Aud { + } else if len(v.Aud) > 0 && claims.Aud != v.Aud { errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud)) } } @@ -426,87 +535,63 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ( return nil, nil } -// Issuer holds public keys and optional validation config for a trusted token issuer. +// Issuer holds public keys for a trusted token issuer. // -// Create with [New], [NewWithJWKs], [NewWithOIDC], or [NewWithOAuth2]. -// After construction, Issuer is immutable. -// -// [Issuer.UnsafeVerify] authenticates the token: Decode + key lookup + sig verify + iss check. -// [Issuer.VerifyAndValidate] additionally unmarshals claims and runs the Validator. +// Issuer is immutable after construction — safe for concurrent use with no locking. +// Use [New] to construct with a fixed key set, or use [Signer.Issuer] or +// [JWKsFetcher.Issuer] to obtain one from a signer or remote JWKS endpoint. type Issuer struct { - URL string // issuer URL for iss claim enforcement; empty skips the check - validator *Validator - keys map[string]crypto.PublicKey // kid → key + pubKeys []PublicJWK + keys map[string]crypto.PublicKey // kid → key } -// New creates an Issuer with explicit keys. +// New creates an Issuer with an explicit set of public keys. // -// v is optional — pass nil to use [Issuer.UnsafeVerify] only. -// [Issuer.VerifyAndValidate] requires a non-nil Validator. -func New(issURL string, keys []PublicJWK, v *Validator) *Issuer { +// The returned Issuer is immutable — keys cannot be added or removed after +// construction. For dynamic key rotation, see [JWKsFetcher]. +func New(keys []PublicJWK) *Issuer { m := make(map[string]crypto.PublicKey, len(keys)) for _, k := range keys { m[k.KID] = k.Key } return &Issuer{ - URL: issURL, - validator: v, - keys: m, + pubKeys: keys, + keys: m, } } -// NewWithJWKs creates an Issuer by fetching keys from jwksURL. +// PublicKeys returns the public keys held by this Issuer. +func (iss *Issuer) PublicKeys() []PublicJWK { + return iss.pubKeys +} + +// ToJWKs serializes the Issuer's public keys as a JWKS JSON document. +func (iss *Issuer) ToJWKs() ([]byte, error) { + return MarshalPublicJWKs(iss.pubKeys) +} + +// Verify decodes tokenStr and verifies its signature. // -// The issuer URL (used for iss claim enforcement in [Issuer.UnsafeVerify]) is -// not set; use [New] or [NewWithOIDC]/[NewWithOAuth2] if you need iss enforcement. -// -// v is optional — pass nil to use [Issuer.UnsafeVerify] only. -func NewWithJWKs(ctx context.Context, jwksURL string, v *Validator) (*Issuer, error) { - keys, err := FetchJWKs(ctx, jwksURL) +// Returns (nil, err) on any failure — the caller never receives an +// unauthenticated JWS. For inspecting a JWS despite signature failure +// (e.g., for multi-issuer routing by kid/iss), use [Issuer.UnsafeVerify]. +func (iss *Issuer) Verify(tokenStr string) (*JWS, error) { + jws, err := iss.UnsafeVerify(tokenStr) if err != nil { return nil, err } - return New("", keys, v), nil + return jws, nil } -// NewWithOIDC creates an Issuer using OIDC discovery. +// UnsafeVerify decodes tokenStr and verifies the signature. // -// It fetches {baseURL}/.well-known/openid-configuration and reads the -// jwks_uri and issuer fields. The Issuer URL is set from the discovery -// document's issuer field (not baseURL) because OIDC requires them to match. +// Unlike [Issuer.Verify], UnsafeVerify returns the parsed [*JWS] even when +// signature verification fails — the error is non-nil but the JWS is +// available for inspection (e.g., to read the kid or iss for multi-issuer +// routing). Returns (nil, err) only when the token cannot be parsed at all. // -// v is optional — pass nil to use [Issuer.UnsafeVerify] only. -func NewWithOIDC(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) { - discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration" - keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL) - if err != nil { - return nil, err - } - return New(issURL, keys, v), nil -} - -// NewWithOAuth2 creates an Issuer using OAuth 2.0 authorization server metadata (RFC 8414). -// -// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the -// jwks_uri and issuer fields. The Issuer URL is set from the discovery -// document's issuer field. -// -// v is optional — pass nil to use [Issuer.UnsafeVerify] only. -func NewWithOAuth2(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) { - discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server" - keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL) - if err != nil { - return nil, err - } - return New(issURL, keys, v), nil -} - -// UnsafeVerify decodes tokenStr, verifies the signature, and (if [Issuer.URL] -// is set) checks the iss claim. -// -// "Unsafe" means exp, aud, and other claim values are NOT checked — the token -// is forgery-safe but not semantically validated. Callers are responsible for -// validating claim values, or use [Issuer.VerifyAndValidate]. +// "Unsafe" means exp, aud, iss, and other claim values are NOT checked. +// Use [Issuer.VerifyAndValidate] for full validation. func (iss *Issuer) UnsafeVerify(tokenStr string) (*JWS, error) { jws, err := Decode(tokenStr) if err != nil { @@ -514,56 +599,34 @@ func (iss *Issuer) UnsafeVerify(tokenStr string) (*JWS, error) { } if jws.Header.Kid == "" { - return nil, fmt.Errorf("missing 'kid' header") + return jws, fmt.Errorf("missing 'kid' header") } key, ok := iss.keys[jws.Header.Kid] if !ok { - return nil, fmt.Errorf("unknown kid: %q", jws.Header.Kid) + return jws, fmt.Errorf("unknown kid: %q", jws.Header.Kid) } signingInput := jws.Protected + "." + jws.Payload if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil { - return nil, fmt.Errorf("signature verification failed: %w", err) - } - - // Signature verified — now safe to inspect the payload for iss check. - if iss.URL != "" { - payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) - if err != nil { - return nil, fmt.Errorf("invalid claims encoding: %w", err) - } - var partial struct { - Iss string `json:"iss"` - } - if err := json.Unmarshal(payload, &partial); err != nil { - return nil, fmt.Errorf("invalid claims JSON: %w", err) - } - if partial.Iss != iss.URL { - return nil, fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL) - } + return jws, fmt.Errorf("signature verification failed: %w", err) } return jws, nil } -// VerifyAndValidate verifies the token signature and iss, unmarshals the claims -// into claims, and runs the [Validator]. +// VerifyAndValidate verifies the token signature, unmarshals the claims +// into claims, and runs v. // -// Returns a hard error (err != nil) for signature failures, decoding errors, -// and nil Validator. Returns soft errors (errs != nil) for claim validation -// failures (wrong aud, expired token, etc.). +// Returns a hard error (err != nil) for signature failures and decoding errors. +// Returns soft errors (errs != nil) for claim validation failures (wrong aud, +// expired token, etc.). If v is nil, claims are unmarshalled but not validated. // -// claims must be a pointer whose underlying type embeds [StandardClaims] (or -// otherwise implements [StandardClaimsSource]): +// claims must be a pointer whose underlying type embeds [StandardClaims]: // // var claims AppClaims -// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now()) -func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSource, now time.Time) (*JWS, []string, error) { - if iss.validator == nil { - return nil, nil, fmt.Errorf("VerifyAndValidate requires a non-nil Validator; use UnsafeVerify for signature-only verification") - } - - jws, err := iss.UnsafeVerify(tokenStr) +// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, v, time.Now()) +func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSource, v ClaimsValidator, now time.Time) (*JWS, []string, error) { + jws, err := iss.Verify(tokenStr) if err != nil { return nil, nil, err } @@ -572,8 +635,12 @@ func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSourc return nil, nil, err } - errs, err := iss.validator.Validate(claims, now) - return jws, errs, err + if v == nil { + return jws, nil, nil + } + + errs, _ := v.Validate(claims, now) // discard sentinel; callers check len(errs) > 0 + return jws, errs, nil } // verifyWith checks a JWS signature using the given algorithm and public key. diff --git a/auth/ajwt/jwt_test.go b/auth/ajwt/jwt_test.go index f4c7cac..e7b6c0b 100644 --- a/auth/ajwt/jwt_test.go +++ b/auth/ajwt/jwt_test.go @@ -67,11 +67,11 @@ func goodClaims() AppClaims { } } -// goodValidator configures the validator. IgnoreIss is true because -// Issuer.UnsafeVerify already enforces the iss claim — no need to check twice. +// goodValidator configures the validator with iss set to "https://example.com". +// Iss checking is now the Validator's responsibility, not the Issuer's. func goodValidator() *ajwt.Validator { return &ajwt.Validator{ - IgnoreIss: true, // UnsafeVerify handles iss + Iss: "https://example.com", Sub: "user123", Aud: "myapp", Jti: "abc123", @@ -82,15 +82,13 @@ func goodValidator() *ajwt.Validator { } func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer { - return ajwt.New("https://example.com", []ajwt.PublicJWK{pub}, goodValidator()) + return ajwt.New([]ajwt.PublicJWK{pub}) } // TestRoundTrip is the primary happy path using ES256. // It demonstrates the full VerifyAndValidate flow: // // New → VerifyAndValidate → custom claim access -// -// No Claims interface, no Verified flag, no type assertions on jws. func TestRoundTrip(t *testing.T) { privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -115,7 +113,7 @@ func TestRoundTrip(t *testing.T) { iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"}) var decoded AppClaims - jws2, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now()) + jws2, errs, err := iss.VerifyAndValidate(token, &decoded, goodValidator(), time.Now()) if err != nil { t.Fatalf("VerifyAndValidate failed: %v", err) } @@ -125,7 +123,7 @@ func TestRoundTrip(t *testing.T) { if jws2.Header.Alg != "ES256" { t.Errorf("expected ES256 alg in jws, got %s", jws2.Header.Alg) } - // Direct field access — no type assertion needed, no jws.Claims interface. + // Direct field access — no type assertion needed. if decoded.Email != claims.Email { t.Errorf("email: got %s, want %s", decoded.Email, claims.Email) } @@ -156,7 +154,7 @@ func TestRoundTripRS256(t *testing.T) { iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"}) var decoded AppClaims - _, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now()) + _, errs, err := iss.VerifyAndValidate(token, &decoded, goodValidator(), time.Now()) if err != nil { t.Fatalf("VerifyAndValidate failed: %v", err) } @@ -190,7 +188,7 @@ func TestRoundTripEdDSA(t *testing.T) { iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"}) var decoded AppClaims - _, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now()) + _, errs, err := iss.VerifyAndValidate(token, &decoded, goodValidator(), time.Now()) if err != nil { t.Fatalf("VerifyAndValidate failed: %v", err) } @@ -209,8 +207,7 @@ func TestUnsafeVerifyFlow(t *testing.T) { _, _ = jws.Sign(privKey) token := jws.Encode() - // Create issuer without validator — UnsafeVerify only. - iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil) + iss := ajwt.New([]ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}) jws2, err := iss.UnsafeVerify(token) if err != nil { @@ -228,6 +225,34 @@ func TestUnsafeVerifyFlow(t *testing.T) { } } +// TestUnsafeVerifyReturnsJWSOnSigFailure verifies that UnsafeVerify returns a +// non-nil *JWS even when signature verification fails, so callers can inspect +// the header (kid, iss) for routing. +func TestUnsafeVerifyReturnsJWSOnSigFailure(t *testing.T) { + signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + claims := goodClaims() + jws, _ := ajwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(signingKey) + token := jws.Encode() + + // Issuer has wrong public key — sig verification will fail. + iss := ajwt.New([]ajwt.PublicJWK{{Key: &wrongKey.PublicKey, KID: "k"}}) + + result, err := iss.UnsafeVerify(token) + if err == nil { + t.Fatal("expected error for wrong key") + } + // UnsafeVerify must return the JWS despite sig failure. + if result == nil { + t.Fatal("UnsafeVerify should return non-nil JWS on sig failure") + } + if result.Header.Kid != "k" { + t.Errorf("expected kid %q, got %q", "k", result.Header.Kid) + } +} + // TestCustomValidation demonstrates that ValidateStandardClaims is called // explicitly and custom fields are validated without any Claims interface. func TestCustomValidation(t *testing.T) { @@ -264,8 +289,8 @@ func TestCustomValidation(t *testing.T) { } } -// TestVerifyAndValidateNilValidator confirms that VerifyAndValidate fails loudly -// when no Validator was provided at construction time. +// TestVerifyAndValidateNilValidator confirms that passing a nil ClaimsValidator +// skips validation but still returns the verified JWS and unmarshalled claims. func TestVerifyAndValidateNilValidator(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) c := goodClaims() @@ -273,11 +298,21 @@ func TestVerifyAndValidateNilValidator(t *testing.T) { _, _ = jws.Sign(privKey) token := jws.Encode() - iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil) + iss := ajwt.New([]ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}) var claims AppClaims - if _, _, err := iss.VerifyAndValidate(token, &claims, time.Now()); err == nil { - t.Fatal("expected VerifyAndValidate to error with nil validator") + jws2, errs, err := iss.VerifyAndValidate(token, &claims, nil, time.Now()) + if err != nil { + t.Fatalf("expected success with nil validator: %v", err) + } + if len(errs) > 0 { + t.Fatalf("expected no validation errors with nil validator: %v", errs) + } + if jws2 == nil { + t.Fatal("expected non-nil JWS") + } + if claims.Email != c.Email { + t.Errorf("claims not unmarshalled: email got %q, want %q", claims.Email, c.Email) } } @@ -293,8 +328,8 @@ func TestIssuerWrongKey(t *testing.T) { iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"}) - if _, err := iss.UnsafeVerify(token); err == nil { - t.Fatal("expected UnsafeVerify to fail with wrong key") + if _, err := iss.Verify(token); err == nil { + t.Fatal("expected Verify to fail with wrong key") } } @@ -309,27 +344,47 @@ func TestIssuerUnknownKid(t *testing.T) { iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"}) - if _, err := iss.UnsafeVerify(token); err == nil { - t.Fatal("expected UnsafeVerify to fail for unknown kid") + if _, err := iss.Verify(token); err == nil { + t.Fatal("expected Verify to fail for unknown kid") } } -// TestIssuerIssMismatch confirms that a token with a mismatched iss is rejected -// even if the signature is valid. +// TestIssuerIssMismatch confirms that a token with a mismatched iss is caught +// by the Validator, not the Issuer. Signature verification succeeds; the iss +// mismatch appears as a soft validation error. func TestIssuerIssMismatch(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) claims := goodClaims() - claims.Iss = "https://evil.example.com" // not the issuer URL + claims.Iss = "https://evil.example.com" jws, _ := ajwt.NewJWSFromClaims(&claims, "k") _, _ = jws.Sign(privKey) token := jws.Encode() - // Issuer expects "https://example.com" but token says "https://evil.example.com" iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"}) - if _, err := iss.UnsafeVerify(token); err == nil { - t.Fatal("expected UnsafeVerify to fail: iss mismatch") + // UnsafeVerify succeeds — iss is not checked at the Issuer level. + if _, err := iss.UnsafeVerify(token); err != nil { + t.Fatalf("UnsafeVerify should succeed (no iss check): %v", err) + } + + // VerifyAndValidate with a Validator that enforces iss catches the mismatch. + var decoded AppClaims + _, errs, err := iss.VerifyAndValidate(token, &decoded, goodValidator(), time.Now()) + if err != nil { + t.Fatalf("unexpected hard error: %v", err) + } + if len(errs) == 0 { + t.Fatal("expected validation errors for iss mismatch") + } + found := false + for _, e := range errs { + if strings.Contains(e, "iss") { + found = true + } + } + if !found { + t.Fatalf("expected iss error in validation errors: %v", errs) } } @@ -352,8 +407,133 @@ func TestVerifyTamperedAlg(t *testing.T) { parts := strings.SplitN(token, ".", 3) tamperedToken := noneHeader + "." + parts[1] + "." + parts[2] - if _, err := iss.UnsafeVerify(tamperedToken); err == nil { - t.Fatal("expected UnsafeVerify to fail for tampered alg") + if _, err := iss.Verify(tamperedToken); err == nil { + t.Fatal("expected Verify to fail for tampered alg") + } +} + +// TestSignerRoundTrip verifies the Signer → Sign → Issuer → VerifyAndValidate flow. +func TestSignerRoundTrip(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + signer, err := ajwt.NewSigner([]ajwt.NamedSigner{{KID: "k1", Signer: privKey}}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + tokenStr, err := signer.Sign(&claims) + if err != nil { + t.Fatal(err) + } + + iss := signer.Issuer() + var decoded AppClaims + _, errs, err := iss.VerifyAndValidate(tokenStr, &decoded, goodValidator(), time.Now()) + if err != nil { + t.Fatalf("VerifyAndValidate failed: %v", err) + } + if len(errs) > 0 { + t.Fatalf("claim validation failed: %v", errs) + } + if decoded.Email != claims.Email { + t.Errorf("email: got %s, want %s", decoded.Email, claims.Email) + } +} + +// TestSignerAutoKID verifies that KID is auto-computed from the key thumbprint +// when NamedSigner.KID is empty. +func TestSignerAutoKID(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + signer, err := ajwt.NewSigner([]ajwt.NamedSigner{{Signer: privKey}}) + if err != nil { + t.Fatal(err) + } + + keys := signer.PublicKeys() + if len(keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(keys)) + } + if keys[0].KID == "" { + t.Fatal("KID should be auto-computed from thumbprint") + } + + // Token should verify with the auto-KID issuer. + iss := signer.Issuer() + claims := goodClaims() + tokenStr, _ := signer.Sign(&claims) + + if _, err := iss.Verify(tokenStr); err != nil { + t.Fatalf("Verify failed: %v", err) + } +} + +// TestSignerRoundRobin verifies that signing round-robins across keys and that +// all resulting tokens verify with the combined Issuer. +func TestSignerRoundRobin(t *testing.T) { + key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + key2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + signer, err := ajwt.NewSigner([]ajwt.NamedSigner{ + {KID: "k1", Signer: key1}, + {KID: "k2", Signer: key2}, + }) + if err != nil { + t.Fatal(err) + } + + iss := signer.Issuer() + v := goodValidator() + + for i := range 4 { + claims := goodClaims() + tokenStr, err := signer.Sign(&claims) + if err != nil { + t.Fatalf("Sign[%d] failed: %v", i, err) + } + var decoded AppClaims + if _, _, err := iss.VerifyAndValidate(tokenStr, &decoded, v, time.Now()); err != nil { + t.Fatalf("VerifyAndValidate[%d] failed: %v", i, err) + } + } +} + +// TestIssuerToJWKs verifies JWKS serialization and round-trip parsing. +func TestIssuerToJWKs(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + signer, err := ajwt.NewSigner([]ajwt.NamedSigner{{KID: "k1", Signer: privKey}}) + if err != nil { + t.Fatal(err) + } + + iss := signer.Issuer() + jwksBytes, err := iss.ToJWKs() + if err != nil { + t.Fatal(err) + } + + // Round-trip: parse the JWKS JSON and verify it produces a working Issuer. + keys, err := ajwt.UnmarshalPublicJWKs(jwksBytes) + if err != nil { + t.Fatal(err) + } + if len(keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(keys)) + } + if keys[0].KID != "k1" { + t.Errorf("expected kid 'k1', got %q", keys[0].KID) + } + + iss2 := ajwt.New(keys) + claims := goodClaims() + tokenStr, _ := signer.Sign(&claims) + if _, err := iss2.Verify(tokenStr); err != nil { + t.Fatalf("Verify on round-tripped JWKS failed: %v", err) } } diff --git a/auth/ajwt/pub.go b/auth/ajwt/pub.go index ed3e473..c0e6d75 100644 --- a/auth/ajwt/pub.go +++ b/auth/ajwt/pub.go @@ -158,6 +158,74 @@ type JWKsJSON struct { Keys []PublicJWKJSON `json:"keys"` } +// EncodePublicJWK converts a [PublicJWK] to its JSON representation. +// +// Supported key types: *ecdsa.PublicKey (EC), *rsa.PublicKey (RSA), ed25519.PublicKey (OKP). +func EncodePublicJWK(k PublicJWK) (PublicJWKJSON, error) { + switch key := k.Key.(type) { + case *ecdsa.PublicKey: + var crv string + switch key.Curve { + case elliptic.P256(): + crv = "P-256" + case elliptic.P384(): + crv = "P-384" + case elliptic.P521(): + crv = "P-521" + default: + return PublicJWKJSON{}, fmt.Errorf("EncodePublicJWK: unsupported EC curve %s", key.Curve.Params().Name) + } + byteLen := (key.Curve.Params().BitSize + 7) / 8 + xBytes := make([]byte, byteLen) + yBytes := make([]byte, byteLen) + key.X.FillBytes(xBytes) + key.Y.FillBytes(yBytes) + return PublicJWKJSON{ + Kty: "EC", + KID: k.KID, + Crv: crv, + X: base64.RawURLEncoding.EncodeToString(xBytes), + Y: base64.RawURLEncoding.EncodeToString(yBytes), + Use: k.Use, + }, nil + + case *rsa.PublicKey: + eInt := big.NewInt(int64(key.E)) + return PublicJWKJSON{ + Kty: "RSA", + KID: k.KID, + N: base64.RawURLEncoding.EncodeToString(key.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(eInt.Bytes()), + Use: k.Use, + }, nil + + case ed25519.PublicKey: + return PublicJWKJSON{ + Kty: "OKP", + KID: k.KID, + Crv: "Ed25519", + X: base64.RawURLEncoding.EncodeToString([]byte(key)), + Use: k.Use, + }, nil + + default: + return PublicJWKJSON{}, fmt.Errorf("EncodePublicJWK: unsupported key type %T", k.Key) + } +} + +// MarshalPublicJWKs serializes a slice of [PublicJWK] as a JWKS JSON document. +func MarshalPublicJWKs(keys []PublicJWK) ([]byte, error) { + jsonKeys := make([]PublicJWKJSON, 0, len(keys)) + for _, k := range keys { + jk, err := EncodePublicJWK(k) + if err != nil { + return nil, err + } + jsonKeys = append(jsonKeys, jk) + } + return json.Marshal(JWKsJSON{Keys: jsonKeys}) +} + // FetchJWKs retrieves and parses a JWKS document from jwksURL. // // ctx is used for the HTTP request timeout and cancellation. diff --git a/auth/ajwt/sign.go b/auth/ajwt/sign.go new file mode 100644 index 0000000..47bafc4 --- /dev/null +++ b/auth/ajwt/sign.go @@ -0,0 +1,98 @@ +// Copyright 2025 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +package ajwt + +import ( + "crypto" + "fmt" + "sync/atomic" +) + +// NamedSigner pairs a [crypto.Signer] with a key ID (KID). +// +// If KID is empty, it is auto-computed from the RFC 7638 thumbprint of the +// public key when passed to [NewSigner]. +type NamedSigner struct { + KID string + Signer crypto.Signer +} + +// Signer manages one or more private signing keys and issues JWTs by +// round-robining across them. +// +// Do not copy a Signer after first use — it contains an atomic counter. +type Signer struct { + signers []NamedSigner + signerIdx atomic.Uint64 +} + +// NewSigner creates a Signer from the provided signing keys. +// +// If a NamedSigner's KID is empty, it is auto-computed from the RFC 7638 +// thumbprint of the public key. Returns an error if the slice is empty or +// a thumbprint cannot be computed. +func NewSigner(signers []NamedSigner) (*Signer, error) { + if len(signers) == 0 { + return nil, fmt.Errorf("NewSigner: at least one signer is required") + } + // Copy so the caller can't mutate after construction. + ss := make([]NamedSigner, len(signers)) + copy(ss, signers) + for i, ns := range ss { + if ns.KID == "" { + jwk := PublicJWK{Key: ns.Signer.Public()} + thumb, err := jwk.Thumbprint() + if err != nil { + return nil, fmt.Errorf("NewSigner: compute thumbprint for signer[%d]: %w", i, err) + } + ss[i].KID = thumb + } + } + return &Signer{signers: ss}, nil +} + +// Sign creates and signs a compact JWT from claims, using the next signing key +// in round-robin order. The caller is responsible for setting the "iss" field +// in claims if issuer identification is needed. +func (s *Signer) Sign(claims any) (string, error) { + idx := s.signerIdx.Add(1) - 1 + ns := s.signers[idx%uint64(len(s.signers))] + + jws, err := NewJWSFromClaims(claims, ns.KID) + if err != nil { + return "", err + } + if _, err := jws.Sign(ns.Signer); err != nil { + return "", err + } + return jws.Encode(), nil +} + +// Issuer returns a new [*Issuer] containing the public keys of all signing keys. +// +// Use this to construct an Issuer for verifying tokens signed by this Signer. +// For key rotation, combine with old public keys: +// +// iss := ajwt.New(append(signer.PublicKeys(), oldKeys...)) +func (s *Signer) Issuer() *Issuer { + return New(s.PublicKeys()) +} + +// PublicKeys returns the public-key side of each signing key, in the same order +// as the signers were provided to [NewSigner]. +func (s *Signer) PublicKeys() []PublicJWK { + keys := make([]PublicJWK, len(s.signers)) + for i, ns := range s.signers { + keys[i] = PublicJWK{ + Key: ns.Signer.Public(), + KID: ns.KID, + } + } + return keys +}