ajwt: implement redesigned API from REDESIGN.md

Rename ValidateParams → Validator, make Issuer immutable after construction.

Key changes:
- StandardClaims.GetStandardClaims() + StandardClaimsSource interface: any
  struct embedding StandardClaims satisfies the interface for free via
  Go's method promotion — zero boilerplate for callers
- Issuer is now immutable after construction; keys and validator are
  unexported; Params field removed
- New constructors: New, NewWithJWKs, NewWithOIDC, NewWithOAuth2
- UnsafeVerify(tokenStr string) (*JWS, error): Decode + sig verify + iss
  check; "unsafe" means exp/aud/etc. are NOT checked
- VerifyAndValidate(tokenStr, claims, now): full pipeline requiring non-nil
  Validator; fails loudly with nil Validator
- FetchJWKs(ctx, url), FetchJWKsFromOIDC(ctx, base),
  FetchJWKsFromOAuth2(ctx, base): standalone fetch functions with context
- PublicJWK.Thumbprint(): RFC 7638 SHA-256 thumbprint, canonical field
  ordering per spec (EC: crv/kty/x/y, RSA: e/kty/n, OKP: crv/kty/x)
- DecodePublicJWKsJSON: auto-populates KID from Thumbprint when absent
- Tests: 14 pass, covering VerifyAndValidate, UnsafeVerify, nil-validator
  error, all alg round trips, tampered alg, Thumbprint, auto-KID
This commit is contained in:
AJ ONeal 2026-03-13 10:28:47 -06:00
parent 1f0b36fc6d
commit 3f7985317f
No known key found for this signature in database
3 changed files with 554 additions and 223 deletions

View File

@ -12,25 +12,33 @@
// - [JWS] is a parsed structure only — no Claims interface, no Verified flag.
// - [Issuer] owns key management and signature verification, centralizing
// the key lookup → sig verify → iss check sequence.
// - [ValidateParams] is a stable config object; time is passed at the call
// site so the same params can be reused across requests.
// - [Validator] is a stable config object; time is passed at the call site
// so the same validator can be reused across requests.
// - [StandardClaimsSource] is implemented for free by embedding [StandardClaims].
// - [JWS.UnmarshalClaims] accepts any type — no interface to implement.
// - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384),
// ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037).
//
// Typical usage:
// Typical usage with VerifyAndValidate:
//
// // At startup:
// iss := ajwt.NewIssuer("https://accounts.example.com")
// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true}
// if err := iss.FetchKeys(ctx); err != nil { ... }
// iss, err := ajwt.NewWithOIDC(ctx, "https://accounts.example.com",
// &ajwt.Validator{Aud: "my-app", IgnoreIss: true})
//
// // Per request:
// jws, err := ajwt.Decode(tokenStr)
// if err := iss.Verify(jws); err != nil { ... } // sig + iss check
// var claims AppClaims
// if err := jws.UnmarshalClaims(&claims); err != nil { ... }
// if errs, err := iss.Params.Validate(claims.StandardClaims, time.Now()); err != nil { ... }
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now())
// if err != nil { /* hard error: bad sig, expired, etc. */ }
// if len(errs) > 0 { /* soft errors: wrong aud, missing amr, etc. */ }
//
// Typical usage with UnsafeVerify (custom validation only):
//
// iss := ajwt.New("https://example.com", keys, nil)
// jws, err := iss.UnsafeVerify(tokenStr)
// var claims AppClaims
// jws.UnmarshalClaims(&claims)
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims,
// ajwt.Validator{Aud: "myapp"}, time.Now())
package ajwt
import (
@ -48,7 +56,6 @@ import (
"encoding/json"
"fmt"
"math/big"
"net/http"
"slices"
"strings"
"time"
@ -58,10 +65,10 @@ import (
//
// It holds only the parsed structure — header, raw base64url fields, and
// decoded signature bytes. It carries no Claims interface and no Verified flag;
// use [Issuer.Verify] to authenticate the token and [JWS.UnmarshalClaims] to
// decode the payload into a typed struct.
// use [Issuer.UnsafeVerify] or [Issuer.VerifyAndValidate] to authenticate the
// token and [JWS.UnmarshalClaims] to decode the payload into a typed struct.
type JWS struct {
Protected string // base64url-encoded header
Protected string // base64url-encoded header
Header StandardHeader
Payload string // base64url-encoded claims
Signature []byte
@ -77,12 +84,15 @@ type StandardHeader struct {
// StandardClaims holds the registered JWT claim names defined in RFC 7519
// and extended by OpenID Connect Core.
//
// Embed StandardClaims in your own claims struct:
// Embed StandardClaims in your own claims struct to satisfy [StandardClaimsSource]
// for free via Go's method promotion — zero boilerplate:
//
// type AppClaims struct {
// ajwt.StandardClaims
// ajwt.StandardClaims // promotes GetStandardClaims()
// Email string `json:"email"`
// Roles []string `json:"roles"`
// }
// // AppClaims now satisfies StandardClaimsSource automatically.
type StandardClaims struct {
Iss string `json:"iss"`
Sub string `json:"sub"`
@ -96,10 +106,24 @@ type StandardClaims struct {
Jti string `json:"jti"`
}
// GetStandardClaims implements [StandardClaimsSource].
// Any struct embedding StandardClaims gets this method for free via promotion.
func (sc StandardClaims) GetStandardClaims() StandardClaims { return sc }
// StandardClaimsSource is implemented for free by any struct that embeds [StandardClaims].
//
// type AppClaims struct {
// ajwt.StandardClaims // promotes GetStandardClaims() — zero boilerplate
// Email string `json:"email"`
// }
type StandardClaimsSource interface {
GetStandardClaims() StandardClaims
}
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
//
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
// [Issuer.Verify] to populate a typed claims struct.
// [Issuer.UnsafeVerify] to safely populate a typed claims struct.
func Decode(tokenStr string) (*JWS, error) {
parts := strings.Split(tokenStr, ".")
if len(parts) != 3 {
@ -128,7 +152,7 @@ func Decode(tokenStr string) (*JWS, error) {
// UnmarshalClaims decodes the JWT payload into v.
//
// v must be a pointer to a struct (e.g. *AppClaims). Always call
// [Issuer.Verify] before UnmarshalClaims to ensure the signature is
// [Issuer.UnsafeVerify] before UnmarshalClaims to ensure the signature is
// authenticated before trusting the payload.
func (jws *JWS) UnmarshalClaims(v any) error {
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
@ -232,16 +256,16 @@ func (jws *JWS) Encode() string {
return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature)
}
// ValidateParams holds claim validation configuration.
// Validator holds claim validation configuration.
//
// Configure once at startup; call [ValidateParams.Validate] per request,
// passing the current time. This keeps the config stable and makes the
// time dependency explicit at the call site.
// Configure once at startup; call [Validator.Validate] per request, passing
// the current time. This keeps the config stable and makes the time dependency
// explicit at the call site.
//
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
type ValidateParams struct {
type Validator struct {
IgnoreIss bool
Iss string
Iss string // rarely needed — Issuer.UnsafeVerify already checks iss
IgnoreSub bool
Sub string
IgnoreAud bool
@ -260,53 +284,53 @@ type ValidateParams struct {
Azp string
}
// Validate checks the standard JWT/OIDC claim fields against this config.
// Validate checks the standard JWT/OIDC claim fields in claims against this config.
//
// now is typically time.Now() — passing it explicitly keeps the config stable
// across requests and avoids hidden time dependencies in the params struct.
func (p ValidateParams) Validate(claims StandardClaims, now time.Time) ([]string, error) {
return ValidateStandardClaims(claims, p, now)
// across requests and avoids hidden time dependencies in the validator struct.
func (v Validator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) {
return ValidateStandardClaims(claims.GetStandardClaims(), v, now)
}
// ValidateStandardClaims checks the registered JWT/OIDC claim fields against params.
// ValidateStandardClaims checks the registered JWT/OIDC claim fields against v.
//
// Exported so callers can use it directly without a [ValidateParams] receiver:
// Exported so callers can use it directly without a [Validator] receiver:
//
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, params, time.Now())
func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now time.Time) ([]string, error) {
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, v, time.Now())
func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ([]string, error) {
var errs []string
// Required to exist and match
if len(params.Iss) > 0 || !params.IgnoreIss {
if len(v.Iss) > 0 || !v.IgnoreIss {
if len(claims.Iss) == 0 {
errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
} else if claims.Iss != params.Iss {
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss))
} else if claims.Iss != v.Iss {
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, v.Iss))
}
}
// Required to exist, optional match
if len(claims.Sub) == 0 {
if !params.IgnoreSub {
if !v.IgnoreSub {
errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)")
}
} else if len(params.Sub) > 0 {
if params.Sub != claims.Sub {
errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, params.Sub))
} else if len(v.Sub) > 0 {
if v.Sub != claims.Sub {
errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, v.Sub))
}
}
// Required to exist and match
if len(params.Aud) > 0 || !params.IgnoreAud {
if len(v.Aud) > 0 || !v.IgnoreAud {
if len(claims.Aud) == 0 {
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
} else if claims.Aud != params.Aud {
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, params.Aud))
} else if claims.Aud != v.Aud {
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud))
}
}
// Required to exist and not be in the past
if !params.IgnoreExp {
if !v.IgnoreExp {
if claims.Exp <= 0 {
errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)")
} else if claims.Exp < now.Unix() {
@ -317,7 +341,7 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
}
// Required to exist and not be in the future
if !params.IgnoreIat {
if !v.IgnoreIat {
if claims.Iat <= 0 {
errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)")
} else if claims.Iat > now.Unix() {
@ -328,53 +352,53 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
}
// Should exist, in the past, with optional max age
if params.MaxAge > 0 || !params.IgnoreAuthTime {
if v.MaxAge > 0 || !v.IgnoreAuthTime {
if claims.AuthTime == 0 {
errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
} else {
authTime := time.Unix(claims.AuthTime, 0)
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
age := now.Sub(authTime)
diff := age - params.MaxAge
diff := age - v.MaxAge
if claims.AuthTime > now.Unix() {
fromNow := time.Unix(claims.AuthTime, 0).Sub(now)
errs = append(errs, fmt.Sprintf(
"'auth_time' of %s is %s in the future (server time %s)",
authTimeStr, formatDuration(fromNow), now.Format("2006-01-02 15:04:05 MST")),
)
} else if params.MaxAge > 0 && age > params.MaxAge {
} else if v.MaxAge > 0 && age > v.MaxAge {
errs = append(errs, fmt.Sprintf(
"'auth_time' of %s is %s old, exceeding max age %s by %s",
authTimeStr, formatDuration(age), formatDuration(params.MaxAge), formatDuration(diff)),
authTimeStr, formatDuration(age), formatDuration(v.MaxAge), formatDuration(diff)),
)
}
}
}
// Optional exact match
if params.Jti != claims.Jti {
if len(params.Jti) > 0 {
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti))
} else if !params.IgnoreJti {
if v.Jti != claims.Jti {
if len(v.Jti) > 0 {
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti))
} else if !v.IgnoreJti {
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
}
}
// Optional exact match
if params.Nonce != claims.Nonce {
if len(params.Nonce) > 0 {
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, params.Nonce))
} else if !params.IgnoreNonce {
if v.Nonce != claims.Nonce {
if len(v.Nonce) > 0 {
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce))
} else if !v.IgnoreNonce {
errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce))
}
}
// Should exist, optional required-set check
if !params.IgnoreAmr {
if !v.IgnoreAmr {
if len(claims.Amr) == 0 {
errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)")
} else if len(params.RequiredAmrs) > 0 {
for _, required := range params.RequiredAmrs {
} else if len(v.RequiredAmrs) > 0 {
for _, required := range v.RequiredAmrs {
if !slices.Contains(claims.Amr, required) {
errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required))
}
@ -383,10 +407,10 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
}
// Optional, match if present
if params.Azp != claims.Azp {
if len(params.Azp) > 0 {
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, params.Azp))
} else if !params.IgnoreAzp {
if v.Azp != claims.Azp {
if len(v.Azp) > 0 {
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp))
} else if !v.IgnoreAzp {
errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp))
}
}
@ -402,111 +426,154 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
return nil, nil
}
// Issuer holds public keys and validation config for a trusted token issuer.
// Issuer holds public keys and optional validation config for a trusted token issuer.
//
// [Issuer.FetchKeys] loads keys from the issuer's JWKS endpoint.
// [Issuer.SetKeys] injects keys directly (useful in tests).
// [Issuer.Verify] authenticates the token: key lookup → sig verify → iss check.
// Create with [New], [NewWithJWKs], [NewWithOIDC], or [NewWithOAuth2].
// After construction, Issuer is immutable.
//
// Typical setup:
//
// iss := ajwt.NewIssuer("https://accounts.example.com")
// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true}
// if err := iss.FetchKeys(ctx); err != nil { ... }
// [Issuer.UnsafeVerify] authenticates the token: Decode + key lookup + sig verify + iss check.
// [Issuer.VerifyAndValidate] additionally unmarshals claims and runs the Validator.
type Issuer struct {
URL string
JWKsURL string // optional; defaults to URL + "/.well-known/jwks.json"
Params ValidateParams
keys map[string]crypto.PublicKey // kid → key
URL string // issuer URL for iss claim enforcement; empty skips the check
validator *Validator
keys map[string]crypto.PublicKey // kid → key
}
// NewIssuer creates an Issuer for the given base URL.
func NewIssuer(url string) *Issuer {
return &Issuer{
URL: url,
keys: make(map[string]crypto.PublicKey),
}
}
// SetKeys stores public keys by their KID, replacing any previously stored keys.
// Useful for injecting keys in tests without an HTTP round-trip.
func (iss *Issuer) SetKeys(keys []PublicJWK) {
// New creates an Issuer with explicit keys.
//
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
// [Issuer.VerifyAndValidate] requires a non-nil Validator.
func New(issURL string, keys []PublicJWK, v *Validator) *Issuer {
m := make(map[string]crypto.PublicKey, len(keys))
for _, k := range keys {
m[k.KID] = k.Key
}
iss.keys = m
return &Issuer{
URL: issURL,
validator: v,
keys: m,
}
}
// FetchKeys retrieves and stores the JWKS from the issuer's endpoint.
// If JWKsURL is empty, it defaults to URL + "/.well-known/jwks.json".
func (iss *Issuer) FetchKeys(ctx context.Context) error {
url := iss.JWKsURL
if url == "" {
url = strings.TrimRight(iss.URL, "/") + "/.well-known/jwks.json"
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("fetch JWKS: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("fetch JWKS: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode)
}
keys, err := DecodePublicJWKs(resp.Body)
if err != nil {
return fmt.Errorf("parse JWKS: %w", err)
}
iss.SetKeys(keys)
return nil
}
// Verify authenticates jws against this issuer:
// 1. Looks up the signing key by jws.Header.Kid.
// 2. Verifies the signature before trusting any payload data.
// 3. Checks that the token's "iss" claim matches iss.URL.
// NewWithJWKs creates an Issuer by fetching keys from jwksURL.
//
// Call [JWS.UnmarshalClaims] after Verify to safely decode the payload into a
// typed struct, then [ValidateParams.Validate] to check claim values.
func (iss *Issuer) Verify(jws *JWS) error {
// The issuer URL (used for iss claim enforcement in [Issuer.UnsafeVerify]) is
// not set; use [New] or [NewWithOIDC]/[NewWithOAuth2] if you need iss enforcement.
//
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
func NewWithJWKs(ctx context.Context, jwksURL string, v *Validator) (*Issuer, error) {
keys, err := FetchJWKs(ctx, jwksURL)
if err != nil {
return nil, err
}
return New("", keys, v), nil
}
// NewWithOIDC creates an Issuer using OIDC discovery.
//
// It fetches {baseURL}/.well-known/openid-configuration and reads the
// jwks_uri and issuer fields. The Issuer URL is set from the discovery
// document's issuer field (not baseURL) because OIDC requires them to match.
//
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
func NewWithOIDC(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) {
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
if err != nil {
return nil, err
}
return New(issURL, keys, v), nil
}
// NewWithOAuth2 creates an Issuer using OAuth 2.0 authorization server metadata (RFC 8414).
//
// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the
// jwks_uri and issuer fields. The Issuer URL is set from the discovery
// document's issuer field.
//
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
func NewWithOAuth2(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) {
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
if err != nil {
return nil, err
}
return New(issURL, keys, v), nil
}
// UnsafeVerify decodes tokenStr, verifies the signature, and (if [Issuer.URL]
// is set) checks the iss claim.
//
// "Unsafe" means exp, aud, and other claim values are NOT checked — the token
// is forgery-safe but not semantically validated. Callers are responsible for
// validating claim values, or use [Issuer.VerifyAndValidate].
func (iss *Issuer) UnsafeVerify(tokenStr string) (*JWS, error) {
jws, err := Decode(tokenStr)
if err != nil {
return nil, err
}
if jws.Header.Kid == "" {
return fmt.Errorf("missing 'kid' header")
return nil, fmt.Errorf("missing 'kid' header")
}
key, ok := iss.keys[jws.Header.Kid]
if !ok {
return fmt.Errorf("unknown kid: %q", jws.Header.Kid)
return nil, fmt.Errorf("unknown kid: %q", jws.Header.Kid)
}
signingInput := jws.Protected + "." + jws.Payload
if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
return nil, fmt.Errorf("signature verification failed: %w", err)
}
// Signature verified — now safe to inspect the payload.
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
// Signature verified — now safe to inspect the payload for iss check.
if iss.URL != "" {
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return nil, fmt.Errorf("invalid claims encoding: %w", err)
}
var partial struct {
Iss string `json:"iss"`
}
if err := json.Unmarshal(payload, &partial); err != nil {
return nil, fmt.Errorf("invalid claims JSON: %w", err)
}
if partial.Iss != iss.URL {
return nil, fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL)
}
}
return jws, nil
}
// VerifyAndValidate verifies the token signature and iss, unmarshals the claims
// into claims, and runs the [Validator].
//
// Returns a hard error (err != nil) for signature failures, decoding errors,
// and nil Validator. Returns soft errors (errs != nil) for claim validation
// failures (wrong aud, expired token, etc.).
//
// claims must be a pointer whose underlying type embeds [StandardClaims] (or
// otherwise implements [StandardClaimsSource]):
//
// var claims AppClaims
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now())
func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSource, now time.Time) (*JWS, []string, error) {
if iss.validator == nil {
return nil, nil, fmt.Errorf("VerifyAndValidate requires a non-nil Validator; use UnsafeVerify for signature-only verification")
}
jws, err := iss.UnsafeVerify(tokenStr)
if err != nil {
return fmt.Errorf("invalid claims encoding: %w", err)
}
var partial struct {
Iss string `json:"iss"`
}
if err := json.Unmarshal(payload, &partial); err != nil {
return fmt.Errorf("invalid claims JSON: %w", err)
}
if partial.Iss != iss.URL {
return fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL)
return nil, nil, err
}
return nil
if err := jws.UnmarshalClaims(claims); err != nil {
return nil, nil, err
}
errs, err := iss.validator.Validate(claims, now)
return jws, errs, err
}
// verifyWith checks a JWS signature using the given algorithm and public key.

View File

@ -14,7 +14,9 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"fmt"
"strings"
"testing"
"time"
@ -23,9 +25,8 @@ import (
// AppClaims embeds StandardClaims and adds application-specific fields.
//
// Unlike embeddedjwt and bestjwt, AppClaims does NOT implement a Validate
// interface — there is none. Validation is explicit: call
// ValidateStandardClaims or ValidateParams.Validate at the call site.
// Because StandardClaims is embedded, AppClaims satisfies StandardClaimsSource
// for free via Go's method promotion — no interface to implement.
type AppClaims struct {
ajwt.StandardClaims
Email string `json:"email"`
@ -33,9 +34,10 @@ type AppClaims struct {
}
// validateAppClaims is a plain function — not a method satisfying an interface.
// Custom validation logic lives here, calling ValidateStandardClaims directly.
func validateAppClaims(c AppClaims, params ajwt.ValidateParams, now time.Time) ([]string, error) {
errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, params, now)
// It demonstrates the UnsafeVerify pattern: custom validation logic lives here,
// calling ValidateStandardClaims directly and adding app-specific checks.
func validateAppClaims(c AppClaims, v ajwt.Validator, now time.Time) ([]string, error) {
errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, v, now)
if c.Email == "" {
errs = append(errs, "missing email claim")
}
@ -65,11 +67,11 @@ func goodClaims() AppClaims {
}
}
// goodParams configures the validator. Iss is omitted because Issuer.Verify
// already enforces the iss claim — no need to check it twice.
func goodParams() ajwt.ValidateParams {
return ajwt.ValidateParams{
IgnoreIss: true, // Issuer.Verify handles iss
// goodValidator configures the validator. IgnoreIss is true because
// Issuer.UnsafeVerify already enforces the iss claim — no need to check twice.
func goodValidator() *ajwt.Validator {
return &ajwt.Validator{
IgnoreIss: true, // UnsafeVerify handles iss
Sub: "user123",
Aud: "myapp",
Jti: "abc123",
@ -80,16 +82,13 @@ func goodParams() ajwt.ValidateParams {
}
func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer {
iss := ajwt.NewIssuer("https://example.com")
iss.Params = goodParams()
iss.SetKeys([]ajwt.PublicJWK{pub})
return iss
return ajwt.New("https://example.com", []ajwt.PublicJWK{pub}, goodValidator())
}
// TestRoundTrip is the primary happy path using ES256.
// It demonstrates the full Issuer-based flow:
// It demonstrates the full VerifyAndValidate flow:
//
// Decode → Issuer.Verify → UnmarshalClaims → Params.Validate
// New → VerifyAndValidate → custom claim access
//
// No Claims interface, no Verified flag, no type assertions on jws.
func TestRoundTrip(t *testing.T) {
@ -115,20 +114,16 @@ func TestRoundTrip(t *testing.T) {
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
jws2, err := ajwt.Decode(token)
if err != nil {
t.Fatal(err)
}
if err = iss.Verify(jws2); err != nil {
t.Fatalf("Verify failed: %v", err)
}
var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatal(err)
jws2, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
if err != nil {
t.Fatalf("VerifyAndValidate failed: %v", err)
}
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
t.Fatalf("validation failed: %v", errs)
if len(errs) > 0 {
t.Fatalf("claim validation failed: %v", errs)
}
if jws2.Header.Alg != "ES256" {
t.Errorf("expected ES256 alg in jws, got %s", jws2.Header.Alg)
}
// Direct field access — no type assertion needed, no jws.Claims interface.
if decoded.Email != claims.Email {
@ -160,20 +155,13 @@ func TestRoundTripRS256(t *testing.T) {
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
jws2, err := ajwt.Decode(token)
if err != nil {
t.Fatal(err)
}
if err = iss.Verify(jws2); err != nil {
t.Fatalf("Verify failed: %v", err)
}
var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatal(err)
_, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
if err != nil {
t.Fatalf("VerifyAndValidate failed: %v", err)
}
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
t.Fatalf("validation failed: %v", errs)
if len(errs) > 0 {
t.Fatalf("claim validation failed: %v", errs)
}
}
@ -201,25 +189,47 @@ func TestRoundTripEdDSA(t *testing.T) {
iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"})
jws2, err := ajwt.Decode(token)
if err != nil {
t.Fatal(err)
}
if err = iss.Verify(jws2); err != nil {
t.Fatalf("Verify failed: %v", err)
}
var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatal(err)
_, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
if err != nil {
t.Fatalf("VerifyAndValidate failed: %v", err)
}
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
t.Fatalf("validation failed: %v", errs)
if len(errs) > 0 {
t.Fatalf("claim validation failed: %v", errs)
}
}
// TestCustomValidation demonstrates custom claim validation without any interface.
// The caller owns the validation logic and calls ValidateStandardClaims directly.
// TestUnsafeVerifyFlow demonstrates the UnsafeVerify + custom validation pattern.
// The caller owns the full validation pipeline.
func TestUnsafeVerifyFlow(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
claims := goodClaims()
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(privKey)
token := jws.Encode()
// Create issuer without validator — UnsafeVerify only.
iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil)
jws2, err := iss.UnsafeVerify(token)
if err != nil {
t.Fatalf("UnsafeVerify failed: %v", err)
}
var decoded AppClaims
if err := jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatalf("UnmarshalClaims failed: %v", err)
}
errs, err := ajwt.ValidateStandardClaims(decoded.StandardClaims, *goodValidator(), time.Now())
if err != nil {
t.Fatalf("ValidateStandardClaims failed: %v — errs: %v", err, errs)
}
}
// TestCustomValidation demonstrates that ValidateStandardClaims is called
// explicitly and custom fields are validated without any Claims interface.
func TestCustomValidation(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -231,12 +241,15 @@ func TestCustomValidation(t *testing.T) {
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
_ = iss.Verify(jws2)
jws2, err := iss.UnsafeVerify(token)
if err != nil {
t.Fatalf("UnsafeVerify failed unexpectedly: %v", err)
}
var decoded AppClaims
_ = jws2.UnmarshalClaims(&decoded)
errs, err := validateAppClaims(decoded, goodParams(), time.Now())
errs, err := validateAppClaims(decoded, *goodValidator(), time.Now())
if err == nil {
t.Fatal("expected validation to fail: email is empty")
}
@ -251,6 +264,23 @@ func TestCustomValidation(t *testing.T) {
}
}
// TestVerifyAndValidateNilValidator confirms that VerifyAndValidate fails loudly
// when no Validator was provided at construction time.
func TestVerifyAndValidateNilValidator(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
c := goodClaims()
jws, _ := ajwt.NewJWSFromClaims(&c, "k")
_, _ = jws.Sign(privKey)
token := jws.Encode()
iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil)
var claims AppClaims
if _, _, err := iss.VerifyAndValidate(token, &claims, time.Now()); err == nil {
t.Fatal("expected VerifyAndValidate to error with nil validator")
}
}
// TestIssuerWrongKey confirms that a different key's public key is rejected.
func TestIssuerWrongKey(t *testing.T) {
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -262,10 +292,9 @@ func TestIssuerWrongKey(t *testing.T) {
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
if err := iss.Verify(jws2); err == nil {
t.Fatal("expected Verify to fail with wrong key")
if _, err := iss.UnsafeVerify(token); err == nil {
t.Fatal("expected UnsafeVerify to fail with wrong key")
}
}
@ -279,10 +308,9 @@ func TestIssuerUnknownKid(t *testing.T) {
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"})
jws2, _ := ajwt.Decode(token)
if err := iss.Verify(jws2); err == nil {
t.Fatal("expected Verify to fail for unknown kid")
if _, err := iss.UnsafeVerify(token); err == nil {
t.Fatal("expected UnsafeVerify to fail for unknown kid")
}
}
@ -299,14 +327,15 @@ func TestIssuerIssMismatch(t *testing.T) {
// Issuer expects "https://example.com" but token says "https://evil.example.com"
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
if err := iss.Verify(jws2); err == nil {
t.Fatal("expected Verify to fail: iss mismatch")
if _, err := iss.UnsafeVerify(token); err == nil {
t.Fatal("expected UnsafeVerify to fail: iss mismatch")
}
}
// TestVerifyTamperedAlg confirms that a tampered alg header is rejected.
// TestVerifyTamperedAlg confirms that a tampered alg header ("none") is rejected.
// The token is reconstructed with a replaced protected header; the original
// ES256 signature is kept, making the signing input mismatch detectable.
func TestVerifyTamperedAlg(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -316,11 +345,15 @@ func TestVerifyTamperedAlg(t *testing.T) {
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
jws2.Header.Alg = "none" // tamper
if err := iss.Verify(jws2); err == nil {
t.Fatal("expected Verify to fail for tampered alg")
// Replace the protected header with one that has alg:"none".
// The original ES256 signature stays — the signing input will mismatch.
noneHeader := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","kid":"k","typ":"JWT"}`))
parts := strings.SplitN(token, ".", 3)
tamperedToken := noneHeader + "." + parts[1] + "." + parts[2]
if _, err := iss.UnsafeVerify(tamperedToken); err == nil {
t.Fatal("expected UnsafeVerify to fail for tampered alg")
}
}
@ -408,3 +441,79 @@ func TestDecodePublicJWKJSON(t *testing.T) {
t.Errorf("expected 1 RSA key, got %d", rsaCount)
}
}
// TestThumbprint verifies that Thumbprint returns a non-empty base64url string
// for EC, RSA, and Ed25519 keys, and that two equal keys produce the same thumbprint.
func TestThumbprint(t *testing.T) {
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
edPub, _, _ := ed25519.GenerateKey(rand.Reader)
tests := []struct {
name string
jwk ajwt.PublicJWK
}{
{"EC P-256", ajwt.PublicJWK{Key: &ecKey.PublicKey}},
{"RSA 2048", ajwt.PublicJWK{Key: &rsaKey.PublicKey}},
{"Ed25519", ajwt.PublicJWK{Key: edPub}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
thumb, err := tt.jwk.Thumbprint()
if err != nil {
t.Fatalf("Thumbprint() error: %v", err)
}
if thumb == "" {
t.Fatal("Thumbprint() returned empty string")
}
// Must be valid base64url (no padding, no +/)
if strings.Contains(thumb, "+") || strings.Contains(thumb, "/") || strings.Contains(thumb, "=") {
t.Errorf("Thumbprint() contains non-base64url characters: %s", thumb)
}
// Same key → same thumbprint
thumb2, _ := tt.jwk.Thumbprint()
if thumb != thumb2 {
t.Errorf("Thumbprint() not deterministic: %s != %s", thumb, thumb2)
}
})
}
}
// TestNoKidAutoThumbprint verifies that a JWKS key without a "kid" field gets
// its KID auto-populated from the RFC 7638 thumbprint.
func TestNoKidAutoThumbprint(t *testing.T) {
// EC key with no "kid" field in the JWKS
jwksJSON := []byte(`{"keys":[
{"kty":"EC","crv":"P-256",
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
"use":"sig"}
]}`)
keys, err := ajwt.UnmarshalPublicJWKs(jwksJSON)
if err != nil {
t.Fatal(err)
}
if len(keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(keys))
}
if keys[0].KID == "" {
t.Fatal("KID should be auto-populated from Thumbprint when absent in JWKS")
}
// The auto-KID should be a valid base64url string.
kid := keys[0].KID
if strings.Contains(kid, "+") || strings.Contains(kid, "/") || strings.Contains(kid, "=") {
t.Errorf("auto-KID contains non-base64url characters: %s", kid)
}
// Round-trip: compute Thumbprint directly and compare.
thumb, err := keys[0].Thumbprint()
if err != nil {
t.Fatalf("Thumbprint() error: %v", err)
}
if kid != thumb {
t.Errorf("auto-KID %q != direct Thumbprint %q", kid, thumb)
}
}

View File

@ -9,11 +9,13 @@
package ajwt
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
@ -21,6 +23,7 @@ import (
"math/big"
"net/http"
"os"
"strings"
"time"
)
@ -54,6 +57,90 @@ func (k PublicJWK) EdDSA() (ed25519.PublicKey, bool) {
return key, ok
}
// Thumbprint computes the RFC 7638 JWK Thumbprint (SHA-256 of the canonical
// key JSON with fields in lexicographic order). The result is base64url-encoded.
//
// Canonical forms per RFC 7638:
// - EC: {"crv":…, "kty":"EC", "x":…, "y":…}
// - RSA: {"e":…, "kty":"RSA", "n":…}
// - OKP: {"crv":"Ed25519", "kty":"OKP", "x":…}
//
// Use Thumbprint as KID when none is provided in the JWKS source.
func (k PublicJWK) Thumbprint() (string, error) {
var canonical []byte
var err error
switch key := k.Key.(type) {
case *ecdsa.PublicKey:
byteLen := (key.Curve.Params().BitSize + 7) / 8
xBytes := make([]byte, byteLen)
yBytes := make([]byte, byteLen)
key.X.FillBytes(xBytes)
key.Y.FillBytes(yBytes)
var crv string
switch key.Curve {
case elliptic.P256():
crv = "P-256"
case elliptic.P384():
crv = "P-384"
case elliptic.P521():
crv = "P-521"
default:
return "", fmt.Errorf("Thumbprint: unsupported EC curve %s", key.Curve.Params().Name)
}
// Fields in lexicographic order: crv, kty, x, y
canonical, err = json.Marshal(struct {
Crv string `json:"crv"`
Kty string `json:"kty"`
X string `json:"x"`
Y string `json:"y"`
}{
Crv: crv,
Kty: "EC",
X: base64.RawURLEncoding.EncodeToString(xBytes),
Y: base64.RawURLEncoding.EncodeToString(yBytes),
})
case *rsa.PublicKey:
eInt := big.NewInt(int64(key.E))
// Fields in lexicographic order: e, kty, n
canonical, err = json.Marshal(struct {
E string `json:"e"`
Kty string `json:"kty"`
N string `json:"n"`
}{
E: base64.RawURLEncoding.EncodeToString(eInt.Bytes()),
Kty: "RSA",
N: base64.RawURLEncoding.EncodeToString(key.N.Bytes()),
})
case ed25519.PublicKey:
// Fields in lexicographic order: crv, kty, x
canonical, err = json.Marshal(struct {
Crv string `json:"crv"`
Kty string `json:"kty"`
X string `json:"x"`
}{
Crv: "Ed25519",
Kty: "OKP",
X: base64.RawURLEncoding.EncodeToString([]byte(key)),
})
default:
return "", fmt.Errorf("Thumbprint: unsupported key type %T", k.Key)
}
if err != nil {
return "", fmt.Errorf("Thumbprint: marshal canonical JSON: %w", err)
}
sum := sha256.Sum256(canonical)
return base64.RawURLEncoding.EncodeToString(sum[:]), nil
}
// PublicJWKJSON is the JSON representation of a single key in a JWKS document.
type PublicJWKJSON struct {
Kty string `json:"kty"`
@ -71,24 +158,83 @@ type JWKsJSON struct {
Keys []PublicJWKJSON `json:"keys"`
}
// FetchPublicJWKs retrieves and parses a JWKS document from url.
// FetchJWKs retrieves and parses a JWKS document from jwksURL.
//
// For issuer-scoped key management with context support, use
// [Issuer.FetchKeys] instead.
func FetchPublicJWKs(url string) ([]PublicJWK, error) {
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(url)
// ctx is used for the HTTP request timeout and cancellation.
func FetchJWKs(ctx context.Context, jwksURL string) ([]PublicJWK, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
return nil, fmt.Errorf("fetch JWKS: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("fetch JWKS: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
return nil, fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode)
}
return DecodePublicJWKs(resp.Body)
}
// FetchJWKsFromOIDC fetches JWKS via OIDC discovery from baseURL.
//
// It fetches {baseURL}/.well-known/openid-configuration and reads the jwks_uri field.
func FetchJWKsFromOIDC(ctx context.Context, baseURL string) ([]PublicJWK, error) {
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
keys, _, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
return keys, err
}
// FetchJWKsFromOAuth2 fetches JWKS via OAuth 2.0 authorization server metadata (RFC 8414)
// from baseURL.
//
// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the jwks_uri field.
func FetchJWKsFromOAuth2(ctx context.Context, baseURL string) ([]PublicJWK, error) {
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
keys, _, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
return keys, err
}
// fetchJWKsFromDiscovery fetches a discovery document from discoveryURL, then
// fetches the JWKS from the jwks_uri field. Returns the keys and the issuer
// URL from the discovery document's "issuer" field.
func fetchJWKsFromDiscovery(ctx context.Context, discoveryURL string) ([]PublicJWK, string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
if err != nil {
return nil, "", fmt.Errorf("fetch discovery: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, "", fmt.Errorf("fetch discovery: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("fetch discovery: unexpected status %d", resp.StatusCode)
}
var doc struct {
Issuer string `json:"issuer"`
JWKsURI string `json:"jwks_uri"`
}
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
return nil, "", fmt.Errorf("parse discovery doc: %w", err)
}
if doc.JWKsURI == "" {
return nil, "", fmt.Errorf("discovery doc missing jwks_uri field")
}
keys, err := FetchJWKs(ctx, doc.JWKsURI)
if err != nil {
return nil, "", err
}
return keys, doc.Issuer, nil
}
// ReadPublicJWKs reads and parses a JWKS document from a file path.
func ReadPublicJWKs(filePath string) ([]PublicJWK, error) {
file, err := os.Open(filePath)
@ -118,6 +264,9 @@ func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) {
}
// DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys.
//
// If a key has no kid field in the source document, the KID is auto-populated
// from [PublicJWK.Thumbprint] per RFC 7638.
func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
var keys []PublicJWK
for _, jwk := range jwks.Keys {
@ -125,6 +274,12 @@ func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
if err != nil {
return nil, fmt.Errorf("failed to parse public jwk %q: %w", jwk.KID, err)
}
if key.KID == "" {
key.KID, err = key.Thumbprint()
if err != nil {
return nil, fmt.Errorf("compute thumbprint for kid-less key: %w", err)
}
}
keys = append(keys, *key)
}
if len(keys) == 0 {