ajwt: implement redesigned API from REDESIGN.md

Rename ValidateParams → Validator, make Issuer immutable after construction.

Key changes:
- StandardClaims.GetStandardClaims() + StandardClaimsSource interface: any
  struct embedding StandardClaims satisfies the interface for free via
  Go's method promotion — zero boilerplate for callers
- Issuer is now immutable after construction; keys and validator are
  unexported; Params field removed
- New constructors: New, NewWithJWKs, NewWithOIDC, NewWithOAuth2
- UnsafeVerify(tokenStr string) (*JWS, error): Decode + sig verify + iss
  check; "unsafe" means exp/aud/etc. are NOT checked
- VerifyAndValidate(tokenStr, claims, now): full pipeline requiring non-nil
  Validator; fails loudly with nil Validator
- FetchJWKs(ctx, url), FetchJWKsFromOIDC(ctx, base),
  FetchJWKsFromOAuth2(ctx, base): standalone fetch functions with context
- PublicJWK.Thumbprint(): RFC 7638 SHA-256 thumbprint, canonical field
  ordering per spec (EC: crv/kty/x/y, RSA: e/kty/n, OKP: crv/kty/x)
- DecodePublicJWKsJSON: auto-populates KID from Thumbprint when absent
- Tests: 14 pass, covering VerifyAndValidate, UnsafeVerify, nil-validator
  error, all alg round trips, tampered alg, Thumbprint, auto-KID
This commit is contained in:
AJ ONeal 2026-03-13 10:28:47 -06:00
parent 1f0b36fc6d
commit 3f7985317f
No known key found for this signature in database
3 changed files with 554 additions and 223 deletions

View File

@ -12,25 +12,33 @@
// - [JWS] is a parsed structure only — no Claims interface, no Verified flag. // - [JWS] is a parsed structure only — no Claims interface, no Verified flag.
// - [Issuer] owns key management and signature verification, centralizing // - [Issuer] owns key management and signature verification, centralizing
// the key lookup → sig verify → iss check sequence. // the key lookup → sig verify → iss check sequence.
// - [ValidateParams] is a stable config object; time is passed at the call // - [Validator] is a stable config object; time is passed at the call site
// site so the same params can be reused across requests. // so the same validator can be reused across requests.
// - [StandardClaimsSource] is implemented for free by embedding [StandardClaims].
// - [JWS.UnmarshalClaims] accepts any type — no interface to implement. // - [JWS.UnmarshalClaims] accepts any type — no interface to implement.
// - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384), // - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384),
// ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037). // ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037).
// //
// Typical usage: // Typical usage with VerifyAndValidate:
// //
// // At startup: // // At startup:
// iss := ajwt.NewIssuer("https://accounts.example.com") // iss, err := ajwt.NewWithOIDC(ctx, "https://accounts.example.com",
// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true} // &ajwt.Validator{Aud: "my-app", IgnoreIss: true})
// if err := iss.FetchKeys(ctx); err != nil { ... }
// //
// // Per request: // // Per request:
// jws, err := ajwt.Decode(tokenStr)
// if err := iss.Verify(jws); err != nil { ... } // sig + iss check
// var claims AppClaims // var claims AppClaims
// if err := jws.UnmarshalClaims(&claims); err != nil { ... } // jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now())
// if errs, err := iss.Params.Validate(claims.StandardClaims, time.Now()); err != nil { ... } // if err != nil { /* hard error: bad sig, expired, etc. */ }
// if len(errs) > 0 { /* soft errors: wrong aud, missing amr, etc. */ }
//
// Typical usage with UnsafeVerify (custom validation only):
//
// iss := ajwt.New("https://example.com", keys, nil)
// jws, err := iss.UnsafeVerify(tokenStr)
// var claims AppClaims
// jws.UnmarshalClaims(&claims)
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims,
// ajwt.Validator{Aud: "myapp"}, time.Now())
package ajwt package ajwt
import ( import (
@ -48,7 +56,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/big" "math/big"
"net/http"
"slices" "slices"
"strings" "strings"
"time" "time"
@ -58,10 +65,10 @@ import (
// //
// It holds only the parsed structure — header, raw base64url fields, and // It holds only the parsed structure — header, raw base64url fields, and
// decoded signature bytes. It carries no Claims interface and no Verified flag; // decoded signature bytes. It carries no Claims interface and no Verified flag;
// use [Issuer.Verify] to authenticate the token and [JWS.UnmarshalClaims] to // use [Issuer.UnsafeVerify] or [Issuer.VerifyAndValidate] to authenticate the
// decode the payload into a typed struct. // token and [JWS.UnmarshalClaims] to decode the payload into a typed struct.
type JWS struct { type JWS struct {
Protected string // base64url-encoded header Protected string // base64url-encoded header
Header StandardHeader Header StandardHeader
Payload string // base64url-encoded claims Payload string // base64url-encoded claims
Signature []byte Signature []byte
@ -77,12 +84,15 @@ type StandardHeader struct {
// StandardClaims holds the registered JWT claim names defined in RFC 7519 // StandardClaims holds the registered JWT claim names defined in RFC 7519
// and extended by OpenID Connect Core. // and extended by OpenID Connect Core.
// //
// Embed StandardClaims in your own claims struct: // Embed StandardClaims in your own claims struct to satisfy [StandardClaimsSource]
// for free via Go's method promotion — zero boilerplate:
// //
// type AppClaims struct { // type AppClaims struct {
// ajwt.StandardClaims // ajwt.StandardClaims // promotes GetStandardClaims()
// Email string `json:"email"` // Email string `json:"email"`
// Roles []string `json:"roles"`
// } // }
// // AppClaims now satisfies StandardClaimsSource automatically.
type StandardClaims struct { type StandardClaims struct {
Iss string `json:"iss"` Iss string `json:"iss"`
Sub string `json:"sub"` Sub string `json:"sub"`
@ -96,10 +106,24 @@ type StandardClaims struct {
Jti string `json:"jti"` Jti string `json:"jti"`
} }
// GetStandardClaims implements [StandardClaimsSource].
// Any struct embedding StandardClaims gets this method for free via promotion.
func (sc StandardClaims) GetStandardClaims() StandardClaims { return sc }
// StandardClaimsSource is implemented for free by any struct that embeds [StandardClaims].
//
// type AppClaims struct {
// ajwt.StandardClaims // promotes GetStandardClaims() — zero boilerplate
// Email string `json:"email"`
// }
type StandardClaimsSource interface {
GetStandardClaims() StandardClaims
}
// Decode parses a compact JWT string (header.payload.signature) into a JWS. // Decode parses a compact JWT string (header.payload.signature) into a JWS.
// //
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after // It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
// [Issuer.Verify] to populate a typed claims struct. // [Issuer.UnsafeVerify] to safely populate a typed claims struct.
func Decode(tokenStr string) (*JWS, error) { func Decode(tokenStr string) (*JWS, error) {
parts := strings.Split(tokenStr, ".") parts := strings.Split(tokenStr, ".")
if len(parts) != 3 { if len(parts) != 3 {
@ -128,7 +152,7 @@ func Decode(tokenStr string) (*JWS, error) {
// UnmarshalClaims decodes the JWT payload into v. // UnmarshalClaims decodes the JWT payload into v.
// //
// v must be a pointer to a struct (e.g. *AppClaims). Always call // v must be a pointer to a struct (e.g. *AppClaims). Always call
// [Issuer.Verify] before UnmarshalClaims to ensure the signature is // [Issuer.UnsafeVerify] before UnmarshalClaims to ensure the signature is
// authenticated before trusting the payload. // authenticated before trusting the payload.
func (jws *JWS) UnmarshalClaims(v any) error { func (jws *JWS) UnmarshalClaims(v any) error {
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
@ -232,16 +256,16 @@ func (jws *JWS) Encode() string {
return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature) return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature)
} }
// ValidateParams holds claim validation configuration. // Validator holds claim validation configuration.
// //
// Configure once at startup; call [ValidateParams.Validate] per request, // Configure once at startup; call [Validator.Validate] per request, passing
// passing the current time. This keeps the config stable and makes the // the current time. This keeps the config stable and makes the time dependency
// time dependency explicit at the call site. // explicit at the call site.
// //
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken // https://openid.net/specs/openid-connect-core-1_0.html#IDToken
type ValidateParams struct { type Validator struct {
IgnoreIss bool IgnoreIss bool
Iss string Iss string // rarely needed — Issuer.UnsafeVerify already checks iss
IgnoreSub bool IgnoreSub bool
Sub string Sub string
IgnoreAud bool IgnoreAud bool
@ -260,53 +284,53 @@ type ValidateParams struct {
Azp string Azp string
} }
// Validate checks the standard JWT/OIDC claim fields against this config. // Validate checks the standard JWT/OIDC claim fields in claims against this config.
// //
// now is typically time.Now() — passing it explicitly keeps the config stable // now is typically time.Now() — passing it explicitly keeps the config stable
// across requests and avoids hidden time dependencies in the params struct. // across requests and avoids hidden time dependencies in the validator struct.
func (p ValidateParams) Validate(claims StandardClaims, now time.Time) ([]string, error) { func (v Validator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) {
return ValidateStandardClaims(claims, p, now) return ValidateStandardClaims(claims.GetStandardClaims(), v, now)
} }
// ValidateStandardClaims checks the registered JWT/OIDC claim fields against params. // ValidateStandardClaims checks the registered JWT/OIDC claim fields against v.
// //
// Exported so callers can use it directly without a [ValidateParams] receiver: // Exported so callers can use it directly without a [Validator] receiver:
// //
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, params, time.Now()) // errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, v, time.Now())
func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now time.Time) ([]string, error) { func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ([]string, error) {
var errs []string var errs []string
// Required to exist and match // Required to exist and match
if len(params.Iss) > 0 || !params.IgnoreIss { if len(v.Iss) > 0 || !v.IgnoreIss {
if len(claims.Iss) == 0 { if len(claims.Iss) == 0 {
errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)") errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
} else if claims.Iss != params.Iss { } else if claims.Iss != v.Iss {
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss)) errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, v.Iss))
} }
} }
// Required to exist, optional match // Required to exist, optional match
if len(claims.Sub) == 0 { if len(claims.Sub) == 0 {
if !params.IgnoreSub { if !v.IgnoreSub {
errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)") errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)")
} }
} else if len(params.Sub) > 0 { } else if len(v.Sub) > 0 {
if params.Sub != claims.Sub { if v.Sub != claims.Sub {
errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, params.Sub)) errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, v.Sub))
} }
} }
// Required to exist and match // Required to exist and match
if len(params.Aud) > 0 || !params.IgnoreAud { if len(v.Aud) > 0 || !v.IgnoreAud {
if len(claims.Aud) == 0 { if len(claims.Aud) == 0 {
errs = append(errs, "missing or malformed 'aud' (audience receiving token)") errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
} else if claims.Aud != params.Aud { } else if claims.Aud != v.Aud {
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, params.Aud)) errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud))
} }
} }
// Required to exist and not be in the past // Required to exist and not be in the past
if !params.IgnoreExp { if !v.IgnoreExp {
if claims.Exp <= 0 { if claims.Exp <= 0 {
errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)") errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)")
} else if claims.Exp < now.Unix() { } else if claims.Exp < now.Unix() {
@ -317,7 +341,7 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
} }
// Required to exist and not be in the future // Required to exist and not be in the future
if !params.IgnoreIat { if !v.IgnoreIat {
if claims.Iat <= 0 { if claims.Iat <= 0 {
errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)") errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)")
} else if claims.Iat > now.Unix() { } else if claims.Iat > now.Unix() {
@ -328,53 +352,53 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
} }
// Should exist, in the past, with optional max age // Should exist, in the past, with optional max age
if params.MaxAge > 0 || !params.IgnoreAuthTime { if v.MaxAge > 0 || !v.IgnoreAuthTime {
if claims.AuthTime == 0 { if claims.AuthTime == 0 {
errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)") errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
} else { } else {
authTime := time.Unix(claims.AuthTime, 0) authTime := time.Unix(claims.AuthTime, 0)
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST") authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
age := now.Sub(authTime) age := now.Sub(authTime)
diff := age - params.MaxAge diff := age - v.MaxAge
if claims.AuthTime > now.Unix() { if claims.AuthTime > now.Unix() {
fromNow := time.Unix(claims.AuthTime, 0).Sub(now) fromNow := time.Unix(claims.AuthTime, 0).Sub(now)
errs = append(errs, fmt.Sprintf( errs = append(errs, fmt.Sprintf(
"'auth_time' of %s is %s in the future (server time %s)", "'auth_time' of %s is %s in the future (server time %s)",
authTimeStr, formatDuration(fromNow), now.Format("2006-01-02 15:04:05 MST")), authTimeStr, formatDuration(fromNow), now.Format("2006-01-02 15:04:05 MST")),
) )
} else if params.MaxAge > 0 && age > params.MaxAge { } else if v.MaxAge > 0 && age > v.MaxAge {
errs = append(errs, fmt.Sprintf( errs = append(errs, fmt.Sprintf(
"'auth_time' of %s is %s old, exceeding max age %s by %s", "'auth_time' of %s is %s old, exceeding max age %s by %s",
authTimeStr, formatDuration(age), formatDuration(params.MaxAge), formatDuration(diff)), authTimeStr, formatDuration(age), formatDuration(v.MaxAge), formatDuration(diff)),
) )
} }
} }
} }
// Optional exact match // Optional exact match
if params.Jti != claims.Jti { if v.Jti != claims.Jti {
if len(params.Jti) > 0 { if len(v.Jti) > 0 {
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti)) errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti))
} else if !params.IgnoreJti { } else if !v.IgnoreJti {
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti)) errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
} }
} }
// Optional exact match // Optional exact match
if params.Nonce != claims.Nonce { if v.Nonce != claims.Nonce {
if len(params.Nonce) > 0 { if len(v.Nonce) > 0 {
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, params.Nonce)) errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce))
} else if !params.IgnoreNonce { } else if !v.IgnoreNonce {
errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce)) errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce))
} }
} }
// Should exist, optional required-set check // Should exist, optional required-set check
if !params.IgnoreAmr { if !v.IgnoreAmr {
if len(claims.Amr) == 0 { if len(claims.Amr) == 0 {
errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)") errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)")
} else if len(params.RequiredAmrs) > 0 { } else if len(v.RequiredAmrs) > 0 {
for _, required := range params.RequiredAmrs { for _, required := range v.RequiredAmrs {
if !slices.Contains(claims.Amr, required) { if !slices.Contains(claims.Amr, required) {
errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required)) errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required))
} }
@ -383,10 +407,10 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
} }
// Optional, match if present // Optional, match if present
if params.Azp != claims.Azp { if v.Azp != claims.Azp {
if len(params.Azp) > 0 { if len(v.Azp) > 0 {
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, params.Azp)) errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp))
} else if !params.IgnoreAzp { } else if !v.IgnoreAzp {
errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp)) errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp))
} }
} }
@ -402,111 +426,154 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
return nil, nil return nil, nil
} }
// Issuer holds public keys and validation config for a trusted token issuer. // Issuer holds public keys and optional validation config for a trusted token issuer.
// //
// [Issuer.FetchKeys] loads keys from the issuer's JWKS endpoint. // Create with [New], [NewWithJWKs], [NewWithOIDC], or [NewWithOAuth2].
// [Issuer.SetKeys] injects keys directly (useful in tests). // After construction, Issuer is immutable.
// [Issuer.Verify] authenticates the token: key lookup → sig verify → iss check.
// //
// Typical setup: // [Issuer.UnsafeVerify] authenticates the token: Decode + key lookup + sig verify + iss check.
// // [Issuer.VerifyAndValidate] additionally unmarshals claims and runs the Validator.
// iss := ajwt.NewIssuer("https://accounts.example.com")
// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true}
// if err := iss.FetchKeys(ctx); err != nil { ... }
type Issuer struct { type Issuer struct {
URL string URL string // issuer URL for iss claim enforcement; empty skips the check
JWKsURL string // optional; defaults to URL + "/.well-known/jwks.json" validator *Validator
Params ValidateParams keys map[string]crypto.PublicKey // kid → key
keys map[string]crypto.PublicKey // kid → key
} }
// NewIssuer creates an Issuer for the given base URL. // New creates an Issuer with explicit keys.
func NewIssuer(url string) *Issuer { //
return &Issuer{ // v is optional — pass nil to use [Issuer.UnsafeVerify] only.
URL: url, // [Issuer.VerifyAndValidate] requires a non-nil Validator.
keys: make(map[string]crypto.PublicKey), func New(issURL string, keys []PublicJWK, v *Validator) *Issuer {
}
}
// SetKeys stores public keys by their KID, replacing any previously stored keys.
// Useful for injecting keys in tests without an HTTP round-trip.
func (iss *Issuer) SetKeys(keys []PublicJWK) {
m := make(map[string]crypto.PublicKey, len(keys)) m := make(map[string]crypto.PublicKey, len(keys))
for _, k := range keys { for _, k := range keys {
m[k.KID] = k.Key m[k.KID] = k.Key
} }
iss.keys = m return &Issuer{
URL: issURL,
validator: v,
keys: m,
}
} }
// FetchKeys retrieves and stores the JWKS from the issuer's endpoint. // NewWithJWKs creates an Issuer by fetching keys from jwksURL.
// If JWKsURL is empty, it defaults to URL + "/.well-known/jwks.json".
func (iss *Issuer) FetchKeys(ctx context.Context) error {
url := iss.JWKsURL
if url == "" {
url = strings.TrimRight(iss.URL, "/") + "/.well-known/jwks.json"
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("fetch JWKS: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("fetch JWKS: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode)
}
keys, err := DecodePublicJWKs(resp.Body)
if err != nil {
return fmt.Errorf("parse JWKS: %w", err)
}
iss.SetKeys(keys)
return nil
}
// Verify authenticates jws against this issuer:
// 1. Looks up the signing key by jws.Header.Kid.
// 2. Verifies the signature before trusting any payload data.
// 3. Checks that the token's "iss" claim matches iss.URL.
// //
// Call [JWS.UnmarshalClaims] after Verify to safely decode the payload into a // The issuer URL (used for iss claim enforcement in [Issuer.UnsafeVerify]) is
// typed struct, then [ValidateParams.Validate] to check claim values. // not set; use [New] or [NewWithOIDC]/[NewWithOAuth2] if you need iss enforcement.
func (iss *Issuer) Verify(jws *JWS) error { //
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
func NewWithJWKs(ctx context.Context, jwksURL string, v *Validator) (*Issuer, error) {
keys, err := FetchJWKs(ctx, jwksURL)
if err != nil {
return nil, err
}
return New("", keys, v), nil
}
// NewWithOIDC creates an Issuer using OIDC discovery.
//
// It fetches {baseURL}/.well-known/openid-configuration and reads the
// jwks_uri and issuer fields. The Issuer URL is set from the discovery
// document's issuer field (not baseURL) because OIDC requires them to match.
//
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
func NewWithOIDC(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) {
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
if err != nil {
return nil, err
}
return New(issURL, keys, v), nil
}
// NewWithOAuth2 creates an Issuer using OAuth 2.0 authorization server metadata (RFC 8414).
//
// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the
// jwks_uri and issuer fields. The Issuer URL is set from the discovery
// document's issuer field.
//
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
func NewWithOAuth2(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) {
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
if err != nil {
return nil, err
}
return New(issURL, keys, v), nil
}
// UnsafeVerify decodes tokenStr, verifies the signature, and (if [Issuer.URL]
// is set) checks the iss claim.
//
// "Unsafe" means exp, aud, and other claim values are NOT checked — the token
// is forgery-safe but not semantically validated. Callers are responsible for
// validating claim values, or use [Issuer.VerifyAndValidate].
func (iss *Issuer) UnsafeVerify(tokenStr string) (*JWS, error) {
jws, err := Decode(tokenStr)
if err != nil {
return nil, err
}
if jws.Header.Kid == "" { if jws.Header.Kid == "" {
return fmt.Errorf("missing 'kid' header") return nil, fmt.Errorf("missing 'kid' header")
} }
key, ok := iss.keys[jws.Header.Kid] key, ok := iss.keys[jws.Header.Kid]
if !ok { if !ok {
return fmt.Errorf("unknown kid: %q", jws.Header.Kid) return nil, fmt.Errorf("unknown kid: %q", jws.Header.Kid)
} }
signingInput := jws.Protected + "." + jws.Payload signingInput := jws.Protected + "." + jws.Payload
if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil { if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil {
return fmt.Errorf("signature verification failed: %w", err) return nil, fmt.Errorf("signature verification failed: %w", err)
} }
// Signature verified — now safe to inspect the payload. // Signature verified — now safe to inspect the payload for iss check.
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) if iss.URL != "" {
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return nil, fmt.Errorf("invalid claims encoding: %w", err)
}
var partial struct {
Iss string `json:"iss"`
}
if err := json.Unmarshal(payload, &partial); err != nil {
return nil, fmt.Errorf("invalid claims JSON: %w", err)
}
if partial.Iss != iss.URL {
return nil, fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL)
}
}
return jws, nil
}
// VerifyAndValidate verifies the token signature and iss, unmarshals the claims
// into claims, and runs the [Validator].
//
// Returns a hard error (err != nil) for signature failures, decoding errors,
// and nil Validator. Returns soft errors (errs != nil) for claim validation
// failures (wrong aud, expired token, etc.).
//
// claims must be a pointer whose underlying type embeds [StandardClaims] (or
// otherwise implements [StandardClaimsSource]):
//
// var claims AppClaims
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now())
func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSource, now time.Time) (*JWS, []string, error) {
if iss.validator == nil {
return nil, nil, fmt.Errorf("VerifyAndValidate requires a non-nil Validator; use UnsafeVerify for signature-only verification")
}
jws, err := iss.UnsafeVerify(tokenStr)
if err != nil { if err != nil {
return fmt.Errorf("invalid claims encoding: %w", err) return nil, nil, err
}
var partial struct {
Iss string `json:"iss"`
}
if err := json.Unmarshal(payload, &partial); err != nil {
return fmt.Errorf("invalid claims JSON: %w", err)
}
if partial.Iss != iss.URL {
return fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL)
} }
return nil if err := jws.UnmarshalClaims(claims); err != nil {
return nil, nil, err
}
errs, err := iss.validator.Validate(claims, now)
return jws, errs, err
} }
// verifyWith checks a JWS signature using the given algorithm and public key. // verifyWith checks a JWS signature using the given algorithm and public key.

View File

@ -14,7 +14,9 @@ import (
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/base64"
"fmt" "fmt"
"strings"
"testing" "testing"
"time" "time"
@ -23,9 +25,8 @@ import (
// AppClaims embeds StandardClaims and adds application-specific fields. // AppClaims embeds StandardClaims and adds application-specific fields.
// //
// Unlike embeddedjwt and bestjwt, AppClaims does NOT implement a Validate // Because StandardClaims is embedded, AppClaims satisfies StandardClaimsSource
// interface — there is none. Validation is explicit: call // for free via Go's method promotion — no interface to implement.
// ValidateStandardClaims or ValidateParams.Validate at the call site.
type AppClaims struct { type AppClaims struct {
ajwt.StandardClaims ajwt.StandardClaims
Email string `json:"email"` Email string `json:"email"`
@ -33,9 +34,10 @@ type AppClaims struct {
} }
// validateAppClaims is a plain function — not a method satisfying an interface. // validateAppClaims is a plain function — not a method satisfying an interface.
// Custom validation logic lives here, calling ValidateStandardClaims directly. // It demonstrates the UnsafeVerify pattern: custom validation logic lives here,
func validateAppClaims(c AppClaims, params ajwt.ValidateParams, now time.Time) ([]string, error) { // calling ValidateStandardClaims directly and adding app-specific checks.
errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, params, now) func validateAppClaims(c AppClaims, v ajwt.Validator, now time.Time) ([]string, error) {
errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, v, now)
if c.Email == "" { if c.Email == "" {
errs = append(errs, "missing email claim") errs = append(errs, "missing email claim")
} }
@ -65,11 +67,11 @@ func goodClaims() AppClaims {
} }
} }
// goodParams configures the validator. Iss is omitted because Issuer.Verify // goodValidator configures the validator. IgnoreIss is true because
// already enforces the iss claim — no need to check it twice. // Issuer.UnsafeVerify already enforces the iss claim — no need to check twice.
func goodParams() ajwt.ValidateParams { func goodValidator() *ajwt.Validator {
return ajwt.ValidateParams{ return &ajwt.Validator{
IgnoreIss: true, // Issuer.Verify handles iss IgnoreIss: true, // UnsafeVerify handles iss
Sub: "user123", Sub: "user123",
Aud: "myapp", Aud: "myapp",
Jti: "abc123", Jti: "abc123",
@ -80,16 +82,13 @@ func goodParams() ajwt.ValidateParams {
} }
func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer { func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer {
iss := ajwt.NewIssuer("https://example.com") return ajwt.New("https://example.com", []ajwt.PublicJWK{pub}, goodValidator())
iss.Params = goodParams()
iss.SetKeys([]ajwt.PublicJWK{pub})
return iss
} }
// TestRoundTrip is the primary happy path using ES256. // TestRoundTrip is the primary happy path using ES256.
// It demonstrates the full Issuer-based flow: // It demonstrates the full VerifyAndValidate flow:
// //
// Decode → Issuer.Verify → UnmarshalClaims → Params.Validate // New → VerifyAndValidate → custom claim access
// //
// No Claims interface, no Verified flag, no type assertions on jws. // No Claims interface, no Verified flag, no type assertions on jws.
func TestRoundTrip(t *testing.T) { func TestRoundTrip(t *testing.T) {
@ -115,20 +114,16 @@ func TestRoundTrip(t *testing.T) {
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"}) iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
jws2, err := ajwt.Decode(token)
if err != nil {
t.Fatal(err)
}
if err = iss.Verify(jws2); err != nil {
t.Fatalf("Verify failed: %v", err)
}
var decoded AppClaims var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil { jws2, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
t.Fatal(err) if err != nil {
t.Fatalf("VerifyAndValidate failed: %v", err)
} }
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil { if len(errs) > 0 {
t.Fatalf("validation failed: %v", errs) t.Fatalf("claim validation failed: %v", errs)
}
if jws2.Header.Alg != "ES256" {
t.Errorf("expected ES256 alg in jws, got %s", jws2.Header.Alg)
} }
// Direct field access — no type assertion needed, no jws.Claims interface. // Direct field access — no type assertion needed, no jws.Claims interface.
if decoded.Email != claims.Email { if decoded.Email != claims.Email {
@ -160,20 +155,13 @@ func TestRoundTripRS256(t *testing.T) {
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"}) iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
jws2, err := ajwt.Decode(token)
if err != nil {
t.Fatal(err)
}
if err = iss.Verify(jws2); err != nil {
t.Fatalf("Verify failed: %v", err)
}
var decoded AppClaims var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil { _, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
t.Fatal(err) if err != nil {
t.Fatalf("VerifyAndValidate failed: %v", err)
} }
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil { if len(errs) > 0 {
t.Fatalf("validation failed: %v", errs) t.Fatalf("claim validation failed: %v", errs)
} }
} }
@ -201,25 +189,47 @@ func TestRoundTripEdDSA(t *testing.T) {
iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"}) iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"})
jws2, err := ajwt.Decode(token)
if err != nil {
t.Fatal(err)
}
if err = iss.Verify(jws2); err != nil {
t.Fatalf("Verify failed: %v", err)
}
var decoded AppClaims var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil { _, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
t.Fatal(err) if err != nil {
t.Fatalf("VerifyAndValidate failed: %v", err)
} }
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil { if len(errs) > 0 {
t.Fatalf("validation failed: %v", errs) t.Fatalf("claim validation failed: %v", errs)
} }
} }
// TestCustomValidation demonstrates custom claim validation without any interface. // TestUnsafeVerifyFlow demonstrates the UnsafeVerify + custom validation pattern.
// The caller owns the validation logic and calls ValidateStandardClaims directly. // The caller owns the full validation pipeline.
func TestUnsafeVerifyFlow(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
claims := goodClaims()
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(privKey)
token := jws.Encode()
// Create issuer without validator — UnsafeVerify only.
iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil)
jws2, err := iss.UnsafeVerify(token)
if err != nil {
t.Fatalf("UnsafeVerify failed: %v", err)
}
var decoded AppClaims
if err := jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatalf("UnmarshalClaims failed: %v", err)
}
errs, err := ajwt.ValidateStandardClaims(decoded.StandardClaims, *goodValidator(), time.Now())
if err != nil {
t.Fatalf("ValidateStandardClaims failed: %v — errs: %v", err, errs)
}
}
// TestCustomValidation demonstrates that ValidateStandardClaims is called
// explicitly and custom fields are validated without any Claims interface.
func TestCustomValidation(t *testing.T) { func TestCustomValidation(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -231,12 +241,15 @@ func TestCustomValidation(t *testing.T) {
token := jws.Encode() token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"}) iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token) jws2, err := iss.UnsafeVerify(token)
_ = iss.Verify(jws2) if err != nil {
t.Fatalf("UnsafeVerify failed unexpectedly: %v", err)
}
var decoded AppClaims var decoded AppClaims
_ = jws2.UnmarshalClaims(&decoded) _ = jws2.UnmarshalClaims(&decoded)
errs, err := validateAppClaims(decoded, goodParams(), time.Now()) errs, err := validateAppClaims(decoded, *goodValidator(), time.Now())
if err == nil { if err == nil {
t.Fatal("expected validation to fail: email is empty") t.Fatal("expected validation to fail: email is empty")
} }
@ -251,6 +264,23 @@ func TestCustomValidation(t *testing.T) {
} }
} }
// TestVerifyAndValidateNilValidator confirms that VerifyAndValidate fails loudly
// when no Validator was provided at construction time.
func TestVerifyAndValidateNilValidator(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
c := goodClaims()
jws, _ := ajwt.NewJWSFromClaims(&c, "k")
_, _ = jws.Sign(privKey)
token := jws.Encode()
iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil)
var claims AppClaims
if _, _, err := iss.VerifyAndValidate(token, &claims, time.Now()); err == nil {
t.Fatal("expected VerifyAndValidate to error with nil validator")
}
}
// TestIssuerWrongKey confirms that a different key's public key is rejected. // TestIssuerWrongKey confirms that a different key's public key is rejected.
func TestIssuerWrongKey(t *testing.T) { func TestIssuerWrongKey(t *testing.T) {
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -262,10 +292,9 @@ func TestIssuerWrongKey(t *testing.T) {
token := jws.Encode() token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"}) iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
if err := iss.Verify(jws2); err == nil { if _, err := iss.UnsafeVerify(token); err == nil {
t.Fatal("expected Verify to fail with wrong key") t.Fatal("expected UnsafeVerify to fail with wrong key")
} }
} }
@ -279,10 +308,9 @@ func TestIssuerUnknownKid(t *testing.T) {
token := jws.Encode() token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"}) iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"})
jws2, _ := ajwt.Decode(token)
if err := iss.Verify(jws2); err == nil { if _, err := iss.UnsafeVerify(token); err == nil {
t.Fatal("expected Verify to fail for unknown kid") t.Fatal("expected UnsafeVerify to fail for unknown kid")
} }
} }
@ -299,14 +327,15 @@ func TestIssuerIssMismatch(t *testing.T) {
// Issuer expects "https://example.com" but token says "https://evil.example.com" // Issuer expects "https://example.com" but token says "https://evil.example.com"
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"}) iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
if err := iss.Verify(jws2); err == nil { if _, err := iss.UnsafeVerify(token); err == nil {
t.Fatal("expected Verify to fail: iss mismatch") t.Fatal("expected UnsafeVerify to fail: iss mismatch")
} }
} }
// TestVerifyTamperedAlg confirms that a tampered alg header is rejected. // TestVerifyTamperedAlg confirms that a tampered alg header ("none") is rejected.
// The token is reconstructed with a replaced protected header; the original
// ES256 signature is kept, making the signing input mismatch detectable.
func TestVerifyTamperedAlg(t *testing.T) { func TestVerifyTamperedAlg(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -316,11 +345,15 @@ func TestVerifyTamperedAlg(t *testing.T) {
token := jws.Encode() token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"}) iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
jws2.Header.Alg = "none" // tamper
if err := iss.Verify(jws2); err == nil { // Replace the protected header with one that has alg:"none".
t.Fatal("expected Verify to fail for tampered alg") // The original ES256 signature stays — the signing input will mismatch.
noneHeader := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","kid":"k","typ":"JWT"}`))
parts := strings.SplitN(token, ".", 3)
tamperedToken := noneHeader + "." + parts[1] + "." + parts[2]
if _, err := iss.UnsafeVerify(tamperedToken); err == nil {
t.Fatal("expected UnsafeVerify to fail for tampered alg")
} }
} }
@ -408,3 +441,79 @@ func TestDecodePublicJWKJSON(t *testing.T) {
t.Errorf("expected 1 RSA key, got %d", rsaCount) t.Errorf("expected 1 RSA key, got %d", rsaCount)
} }
} }
// TestThumbprint verifies that Thumbprint returns a non-empty base64url string
// for EC, RSA, and Ed25519 keys, and that two equal keys produce the same thumbprint.
func TestThumbprint(t *testing.T) {
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
edPub, _, _ := ed25519.GenerateKey(rand.Reader)
tests := []struct {
name string
jwk ajwt.PublicJWK
}{
{"EC P-256", ajwt.PublicJWK{Key: &ecKey.PublicKey}},
{"RSA 2048", ajwt.PublicJWK{Key: &rsaKey.PublicKey}},
{"Ed25519", ajwt.PublicJWK{Key: edPub}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
thumb, err := tt.jwk.Thumbprint()
if err != nil {
t.Fatalf("Thumbprint() error: %v", err)
}
if thumb == "" {
t.Fatal("Thumbprint() returned empty string")
}
// Must be valid base64url (no padding, no +/)
if strings.Contains(thumb, "+") || strings.Contains(thumb, "/") || strings.Contains(thumb, "=") {
t.Errorf("Thumbprint() contains non-base64url characters: %s", thumb)
}
// Same key → same thumbprint
thumb2, _ := tt.jwk.Thumbprint()
if thumb != thumb2 {
t.Errorf("Thumbprint() not deterministic: %s != %s", thumb, thumb2)
}
})
}
}
// TestNoKidAutoThumbprint verifies that a JWKS key without a "kid" field gets
// its KID auto-populated from the RFC 7638 thumbprint.
func TestNoKidAutoThumbprint(t *testing.T) {
// EC key with no "kid" field in the JWKS
jwksJSON := []byte(`{"keys":[
{"kty":"EC","crv":"P-256",
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
"use":"sig"}
]}`)
keys, err := ajwt.UnmarshalPublicJWKs(jwksJSON)
if err != nil {
t.Fatal(err)
}
if len(keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(keys))
}
if keys[0].KID == "" {
t.Fatal("KID should be auto-populated from Thumbprint when absent in JWKS")
}
// The auto-KID should be a valid base64url string.
kid := keys[0].KID
if strings.Contains(kid, "+") || strings.Contains(kid, "/") || strings.Contains(kid, "=") {
t.Errorf("auto-KID contains non-base64url characters: %s", kid)
}
// Round-trip: compute Thumbprint directly and compare.
thumb, err := keys[0].Thumbprint()
if err != nil {
t.Fatalf("Thumbprint() error: %v", err)
}
if kid != thumb {
t.Errorf("auto-KID %q != direct Thumbprint %q", kid, thumb)
}
}

View File

@ -9,11 +9,13 @@
package ajwt package ajwt
import ( import (
"context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
"crypto/elliptic" "crypto/elliptic"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -21,6 +23,7 @@ import (
"math/big" "math/big"
"net/http" "net/http"
"os" "os"
"strings"
"time" "time"
) )
@ -54,6 +57,90 @@ func (k PublicJWK) EdDSA() (ed25519.PublicKey, bool) {
return key, ok return key, ok
} }
// Thumbprint computes the RFC 7638 JWK Thumbprint (SHA-256 of the canonical
// key JSON with fields in lexicographic order). The result is base64url-encoded.
//
// Canonical forms per RFC 7638:
// - EC: {"crv":…, "kty":"EC", "x":…, "y":…}
// - RSA: {"e":…, "kty":"RSA", "n":…}
// - OKP: {"crv":"Ed25519", "kty":"OKP", "x":…}
//
// Use Thumbprint as KID when none is provided in the JWKS source.
func (k PublicJWK) Thumbprint() (string, error) {
var canonical []byte
var err error
switch key := k.Key.(type) {
case *ecdsa.PublicKey:
byteLen := (key.Curve.Params().BitSize + 7) / 8
xBytes := make([]byte, byteLen)
yBytes := make([]byte, byteLen)
key.X.FillBytes(xBytes)
key.Y.FillBytes(yBytes)
var crv string
switch key.Curve {
case elliptic.P256():
crv = "P-256"
case elliptic.P384():
crv = "P-384"
case elliptic.P521():
crv = "P-521"
default:
return "", fmt.Errorf("Thumbprint: unsupported EC curve %s", key.Curve.Params().Name)
}
// Fields in lexicographic order: crv, kty, x, y
canonical, err = json.Marshal(struct {
Crv string `json:"crv"`
Kty string `json:"kty"`
X string `json:"x"`
Y string `json:"y"`
}{
Crv: crv,
Kty: "EC",
X: base64.RawURLEncoding.EncodeToString(xBytes),
Y: base64.RawURLEncoding.EncodeToString(yBytes),
})
case *rsa.PublicKey:
eInt := big.NewInt(int64(key.E))
// Fields in lexicographic order: e, kty, n
canonical, err = json.Marshal(struct {
E string `json:"e"`
Kty string `json:"kty"`
N string `json:"n"`
}{
E: base64.RawURLEncoding.EncodeToString(eInt.Bytes()),
Kty: "RSA",
N: base64.RawURLEncoding.EncodeToString(key.N.Bytes()),
})
case ed25519.PublicKey:
// Fields in lexicographic order: crv, kty, x
canonical, err = json.Marshal(struct {
Crv string `json:"crv"`
Kty string `json:"kty"`
X string `json:"x"`
}{
Crv: "Ed25519",
Kty: "OKP",
X: base64.RawURLEncoding.EncodeToString([]byte(key)),
})
default:
return "", fmt.Errorf("Thumbprint: unsupported key type %T", k.Key)
}
if err != nil {
return "", fmt.Errorf("Thumbprint: marshal canonical JSON: %w", err)
}
sum := sha256.Sum256(canonical)
return base64.RawURLEncoding.EncodeToString(sum[:]), nil
}
// PublicJWKJSON is the JSON representation of a single key in a JWKS document. // PublicJWKJSON is the JSON representation of a single key in a JWKS document.
type PublicJWKJSON struct { type PublicJWKJSON struct {
Kty string `json:"kty"` Kty string `json:"kty"`
@ -71,24 +158,83 @@ type JWKsJSON struct {
Keys []PublicJWKJSON `json:"keys"` Keys []PublicJWKJSON `json:"keys"`
} }
// FetchPublicJWKs retrieves and parses a JWKS document from url. // FetchJWKs retrieves and parses a JWKS document from jwksURL.
// //
// For issuer-scoped key management with context support, use // ctx is used for the HTTP request timeout and cancellation.
// [Issuer.FetchKeys] instead. func FetchJWKs(ctx context.Context, jwksURL string) ([]PublicJWK, error) {
func FetchPublicJWKs(url string) ([]PublicJWK, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(url)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %w", err) return nil, fmt.Errorf("fetch JWKS: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("fetch JWKS: %w", err)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) return nil, fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode)
} }
return DecodePublicJWKs(resp.Body) return DecodePublicJWKs(resp.Body)
} }
// FetchJWKsFromOIDC fetches JWKS via OIDC discovery from baseURL.
//
// It fetches {baseURL}/.well-known/openid-configuration and reads the jwks_uri field.
func FetchJWKsFromOIDC(ctx context.Context, baseURL string) ([]PublicJWK, error) {
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
keys, _, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
return keys, err
}
// FetchJWKsFromOAuth2 fetches JWKS via OAuth 2.0 authorization server metadata (RFC 8414)
// from baseURL.
//
// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the jwks_uri field.
func FetchJWKsFromOAuth2(ctx context.Context, baseURL string) ([]PublicJWK, error) {
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
keys, _, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
return keys, err
}
// fetchJWKsFromDiscovery fetches a discovery document from discoveryURL, then
// fetches the JWKS from the jwks_uri field. Returns the keys and the issuer
// URL from the discovery document's "issuer" field.
func fetchJWKsFromDiscovery(ctx context.Context, discoveryURL string) ([]PublicJWK, string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
if err != nil {
return nil, "", fmt.Errorf("fetch discovery: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, "", fmt.Errorf("fetch discovery: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("fetch discovery: unexpected status %d", resp.StatusCode)
}
var doc struct {
Issuer string `json:"issuer"`
JWKsURI string `json:"jwks_uri"`
}
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
return nil, "", fmt.Errorf("parse discovery doc: %w", err)
}
if doc.JWKsURI == "" {
return nil, "", fmt.Errorf("discovery doc missing jwks_uri field")
}
keys, err := FetchJWKs(ctx, doc.JWKsURI)
if err != nil {
return nil, "", err
}
return keys, doc.Issuer, nil
}
// ReadPublicJWKs reads and parses a JWKS document from a file path. // ReadPublicJWKs reads and parses a JWKS document from a file path.
func ReadPublicJWKs(filePath string) ([]PublicJWK, error) { func ReadPublicJWKs(filePath string) ([]PublicJWK, error) {
file, err := os.Open(filePath) file, err := os.Open(filePath)
@ -118,6 +264,9 @@ func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) {
} }
// DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys. // DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys.
//
// If a key has no kid field in the source document, the KID is auto-populated
// from [PublicJWK.Thumbprint] per RFC 7638.
func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) { func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
var keys []PublicJWK var keys []PublicJWK
for _, jwk := range jwks.Keys { for _, jwk := range jwks.Keys {
@ -125,6 +274,12 @@ func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse public jwk %q: %w", jwk.KID, err) return nil, fmt.Errorf("failed to parse public jwk %q: %w", jwk.KID, err)
} }
if key.KID == "" {
key.KID, err = key.Thumbprint()
if err != nil {
return nil, fmt.Errorf("compute thumbprint for kid-less key: %w", err)
}
}
keys = append(keys, *key) keys = append(keys, *key)
} }
if len(keys) == 0 { if len(keys) == 0 {