mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 12:48:00 +00:00
ajwt: implement redesigned API from REDESIGN.md
Rename ValidateParams → Validator, make Issuer immutable after construction. Key changes: - StandardClaims.GetStandardClaims() + StandardClaimsSource interface: any struct embedding StandardClaims satisfies the interface for free via Go's method promotion — zero boilerplate for callers - Issuer is now immutable after construction; keys and validator are unexported; Params field removed - New constructors: New, NewWithJWKs, NewWithOIDC, NewWithOAuth2 - UnsafeVerify(tokenStr string) (*JWS, error): Decode + sig verify + iss check; "unsafe" means exp/aud/etc. are NOT checked - VerifyAndValidate(tokenStr, claims, now): full pipeline requiring non-nil Validator; fails loudly with nil Validator - FetchJWKs(ctx, url), FetchJWKsFromOIDC(ctx, base), FetchJWKsFromOAuth2(ctx, base): standalone fetch functions with context - PublicJWK.Thumbprint(): RFC 7638 SHA-256 thumbprint, canonical field ordering per spec (EC: crv/kty/x/y, RSA: e/kty/n, OKP: crv/kty/x) - DecodePublicJWKsJSON: auto-populates KID from Thumbprint when absent - Tests: 14 pass, covering VerifyAndValidate, UnsafeVerify, nil-validator error, all alg round trips, tampered alg, Thumbprint, auto-KID
This commit is contained in:
parent
1f0b36fc6d
commit
3f7985317f
333
auth/ajwt/jwt.go
333
auth/ajwt/jwt.go
@ -12,25 +12,33 @@
|
||||
// - [JWS] is a parsed structure only — no Claims interface, no Verified flag.
|
||||
// - [Issuer] owns key management and signature verification, centralizing
|
||||
// the key lookup → sig verify → iss check sequence.
|
||||
// - [ValidateParams] is a stable config object; time is passed at the call
|
||||
// site so the same params can be reused across requests.
|
||||
// - [Validator] is a stable config object; time is passed at the call site
|
||||
// so the same validator can be reused across requests.
|
||||
// - [StandardClaimsSource] is implemented for free by embedding [StandardClaims].
|
||||
// - [JWS.UnmarshalClaims] accepts any type — no interface to implement.
|
||||
// - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384),
|
||||
// ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037).
|
||||
//
|
||||
// Typical usage:
|
||||
// Typical usage with VerifyAndValidate:
|
||||
//
|
||||
// // At startup:
|
||||
// iss := ajwt.NewIssuer("https://accounts.example.com")
|
||||
// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true}
|
||||
// if err := iss.FetchKeys(ctx); err != nil { ... }
|
||||
// iss, err := ajwt.NewWithOIDC(ctx, "https://accounts.example.com",
|
||||
// &ajwt.Validator{Aud: "my-app", IgnoreIss: true})
|
||||
//
|
||||
// // Per request:
|
||||
// jws, err := ajwt.Decode(tokenStr)
|
||||
// if err := iss.Verify(jws); err != nil { ... } // sig + iss check
|
||||
// var claims AppClaims
|
||||
// if err := jws.UnmarshalClaims(&claims); err != nil { ... }
|
||||
// if errs, err := iss.Params.Validate(claims.StandardClaims, time.Now()); err != nil { ... }
|
||||
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now())
|
||||
// if err != nil { /* hard error: bad sig, expired, etc. */ }
|
||||
// if len(errs) > 0 { /* soft errors: wrong aud, missing amr, etc. */ }
|
||||
//
|
||||
// Typical usage with UnsafeVerify (custom validation only):
|
||||
//
|
||||
// iss := ajwt.New("https://example.com", keys, nil)
|
||||
// jws, err := iss.UnsafeVerify(tokenStr)
|
||||
// var claims AppClaims
|
||||
// jws.UnmarshalClaims(&claims)
|
||||
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims,
|
||||
// ajwt.Validator{Aud: "myapp"}, time.Now())
|
||||
package ajwt
|
||||
|
||||
import (
|
||||
@ -48,7 +56,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
@ -58,8 +65,8 @@ import (
|
||||
//
|
||||
// It holds only the parsed structure — header, raw base64url fields, and
|
||||
// decoded signature bytes. It carries no Claims interface and no Verified flag;
|
||||
// use [Issuer.Verify] to authenticate the token and [JWS.UnmarshalClaims] to
|
||||
// decode the payload into a typed struct.
|
||||
// use [Issuer.UnsafeVerify] or [Issuer.VerifyAndValidate] to authenticate the
|
||||
// token and [JWS.UnmarshalClaims] to decode the payload into a typed struct.
|
||||
type JWS struct {
|
||||
Protected string // base64url-encoded header
|
||||
Header StandardHeader
|
||||
@ -77,12 +84,15 @@ type StandardHeader struct {
|
||||
// StandardClaims holds the registered JWT claim names defined in RFC 7519
|
||||
// and extended by OpenID Connect Core.
|
||||
//
|
||||
// Embed StandardClaims in your own claims struct:
|
||||
// Embed StandardClaims in your own claims struct to satisfy [StandardClaimsSource]
|
||||
// for free via Go's method promotion — zero boilerplate:
|
||||
//
|
||||
// type AppClaims struct {
|
||||
// ajwt.StandardClaims
|
||||
// ajwt.StandardClaims // promotes GetStandardClaims()
|
||||
// Email string `json:"email"`
|
||||
// Roles []string `json:"roles"`
|
||||
// }
|
||||
// // AppClaims now satisfies StandardClaimsSource automatically.
|
||||
type StandardClaims struct {
|
||||
Iss string `json:"iss"`
|
||||
Sub string `json:"sub"`
|
||||
@ -96,10 +106,24 @@ type StandardClaims struct {
|
||||
Jti string `json:"jti"`
|
||||
}
|
||||
|
||||
// GetStandardClaims implements [StandardClaimsSource].
|
||||
// Any struct embedding StandardClaims gets this method for free via promotion.
|
||||
func (sc StandardClaims) GetStandardClaims() StandardClaims { return sc }
|
||||
|
||||
// StandardClaimsSource is implemented for free by any struct that embeds [StandardClaims].
|
||||
//
|
||||
// type AppClaims struct {
|
||||
// ajwt.StandardClaims // promotes GetStandardClaims() — zero boilerplate
|
||||
// Email string `json:"email"`
|
||||
// }
|
||||
type StandardClaimsSource interface {
|
||||
GetStandardClaims() StandardClaims
|
||||
}
|
||||
|
||||
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
|
||||
//
|
||||
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
|
||||
// [Issuer.Verify] to populate a typed claims struct.
|
||||
// [Issuer.UnsafeVerify] to safely populate a typed claims struct.
|
||||
func Decode(tokenStr string) (*JWS, error) {
|
||||
parts := strings.Split(tokenStr, ".")
|
||||
if len(parts) != 3 {
|
||||
@ -128,7 +152,7 @@ func Decode(tokenStr string) (*JWS, error) {
|
||||
// UnmarshalClaims decodes the JWT payload into v.
|
||||
//
|
||||
// v must be a pointer to a struct (e.g. *AppClaims). Always call
|
||||
// [Issuer.Verify] before UnmarshalClaims to ensure the signature is
|
||||
// [Issuer.UnsafeVerify] before UnmarshalClaims to ensure the signature is
|
||||
// authenticated before trusting the payload.
|
||||
func (jws *JWS) UnmarshalClaims(v any) error {
|
||||
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||
@ -232,16 +256,16 @@ func (jws *JWS) Encode() string {
|
||||
return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature)
|
||||
}
|
||||
|
||||
// ValidateParams holds claim validation configuration.
|
||||
// Validator holds claim validation configuration.
|
||||
//
|
||||
// Configure once at startup; call [ValidateParams.Validate] per request,
|
||||
// passing the current time. This keeps the config stable and makes the
|
||||
// time dependency explicit at the call site.
|
||||
// Configure once at startup; call [Validator.Validate] per request, passing
|
||||
// the current time. This keeps the config stable and makes the time dependency
|
||||
// explicit at the call site.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
|
||||
type ValidateParams struct {
|
||||
type Validator struct {
|
||||
IgnoreIss bool
|
||||
Iss string
|
||||
Iss string // rarely needed — Issuer.UnsafeVerify already checks iss
|
||||
IgnoreSub bool
|
||||
Sub string
|
||||
IgnoreAud bool
|
||||
@ -260,53 +284,53 @@ type ValidateParams struct {
|
||||
Azp string
|
||||
}
|
||||
|
||||
// Validate checks the standard JWT/OIDC claim fields against this config.
|
||||
// Validate checks the standard JWT/OIDC claim fields in claims against this config.
|
||||
//
|
||||
// now is typically time.Now() — passing it explicitly keeps the config stable
|
||||
// across requests and avoids hidden time dependencies in the params struct.
|
||||
func (p ValidateParams) Validate(claims StandardClaims, now time.Time) ([]string, error) {
|
||||
return ValidateStandardClaims(claims, p, now)
|
||||
// across requests and avoids hidden time dependencies in the validator struct.
|
||||
func (v Validator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) {
|
||||
return ValidateStandardClaims(claims.GetStandardClaims(), v, now)
|
||||
}
|
||||
|
||||
// ValidateStandardClaims checks the registered JWT/OIDC claim fields against params.
|
||||
// ValidateStandardClaims checks the registered JWT/OIDC claim fields against v.
|
||||
//
|
||||
// Exported so callers can use it directly without a [ValidateParams] receiver:
|
||||
// Exported so callers can use it directly without a [Validator] receiver:
|
||||
//
|
||||
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, params, time.Now())
|
||||
func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now time.Time) ([]string, error) {
|
||||
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, v, time.Now())
|
||||
func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ([]string, error) {
|
||||
var errs []string
|
||||
|
||||
// Required to exist and match
|
||||
if len(params.Iss) > 0 || !params.IgnoreIss {
|
||||
if len(v.Iss) > 0 || !v.IgnoreIss {
|
||||
if len(claims.Iss) == 0 {
|
||||
errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
|
||||
} else if claims.Iss != params.Iss {
|
||||
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss))
|
||||
} else if claims.Iss != v.Iss {
|
||||
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, v.Iss))
|
||||
}
|
||||
}
|
||||
|
||||
// Required to exist, optional match
|
||||
if len(claims.Sub) == 0 {
|
||||
if !params.IgnoreSub {
|
||||
if !v.IgnoreSub {
|
||||
errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)")
|
||||
}
|
||||
} else if len(params.Sub) > 0 {
|
||||
if params.Sub != claims.Sub {
|
||||
errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, params.Sub))
|
||||
} else if len(v.Sub) > 0 {
|
||||
if v.Sub != claims.Sub {
|
||||
errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, v.Sub))
|
||||
}
|
||||
}
|
||||
|
||||
// Required to exist and match
|
||||
if len(params.Aud) > 0 || !params.IgnoreAud {
|
||||
if len(v.Aud) > 0 || !v.IgnoreAud {
|
||||
if len(claims.Aud) == 0 {
|
||||
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
|
||||
} else if claims.Aud != params.Aud {
|
||||
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, params.Aud))
|
||||
} else if claims.Aud != v.Aud {
|
||||
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud))
|
||||
}
|
||||
}
|
||||
|
||||
// Required to exist and not be in the past
|
||||
if !params.IgnoreExp {
|
||||
if !v.IgnoreExp {
|
||||
if claims.Exp <= 0 {
|
||||
errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)")
|
||||
} else if claims.Exp < now.Unix() {
|
||||
@ -317,7 +341,7 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
|
||||
}
|
||||
|
||||
// Required to exist and not be in the future
|
||||
if !params.IgnoreIat {
|
||||
if !v.IgnoreIat {
|
||||
if claims.Iat <= 0 {
|
||||
errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)")
|
||||
} else if claims.Iat > now.Unix() {
|
||||
@ -328,53 +352,53 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
|
||||
}
|
||||
|
||||
// Should exist, in the past, with optional max age
|
||||
if params.MaxAge > 0 || !params.IgnoreAuthTime {
|
||||
if v.MaxAge > 0 || !v.IgnoreAuthTime {
|
||||
if claims.AuthTime == 0 {
|
||||
errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
|
||||
} else {
|
||||
authTime := time.Unix(claims.AuthTime, 0)
|
||||
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
|
||||
age := now.Sub(authTime)
|
||||
diff := age - params.MaxAge
|
||||
diff := age - v.MaxAge
|
||||
if claims.AuthTime > now.Unix() {
|
||||
fromNow := time.Unix(claims.AuthTime, 0).Sub(now)
|
||||
errs = append(errs, fmt.Sprintf(
|
||||
"'auth_time' of %s is %s in the future (server time %s)",
|
||||
authTimeStr, formatDuration(fromNow), now.Format("2006-01-02 15:04:05 MST")),
|
||||
)
|
||||
} else if params.MaxAge > 0 && age > params.MaxAge {
|
||||
} else if v.MaxAge > 0 && age > v.MaxAge {
|
||||
errs = append(errs, fmt.Sprintf(
|
||||
"'auth_time' of %s is %s old, exceeding max age %s by %s",
|
||||
authTimeStr, formatDuration(age), formatDuration(params.MaxAge), formatDuration(diff)),
|
||||
authTimeStr, formatDuration(age), formatDuration(v.MaxAge), formatDuration(diff)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Optional exact match
|
||||
if params.Jti != claims.Jti {
|
||||
if len(params.Jti) > 0 {
|
||||
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti))
|
||||
} else if !params.IgnoreJti {
|
||||
if v.Jti != claims.Jti {
|
||||
if len(v.Jti) > 0 {
|
||||
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti))
|
||||
} else if !v.IgnoreJti {
|
||||
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
|
||||
}
|
||||
}
|
||||
|
||||
// Optional exact match
|
||||
if params.Nonce != claims.Nonce {
|
||||
if len(params.Nonce) > 0 {
|
||||
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, params.Nonce))
|
||||
} else if !params.IgnoreNonce {
|
||||
if v.Nonce != claims.Nonce {
|
||||
if len(v.Nonce) > 0 {
|
||||
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce))
|
||||
} else if !v.IgnoreNonce {
|
||||
errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce))
|
||||
}
|
||||
}
|
||||
|
||||
// Should exist, optional required-set check
|
||||
if !params.IgnoreAmr {
|
||||
if !v.IgnoreAmr {
|
||||
if len(claims.Amr) == 0 {
|
||||
errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)")
|
||||
} else if len(params.RequiredAmrs) > 0 {
|
||||
for _, required := range params.RequiredAmrs {
|
||||
} else if len(v.RequiredAmrs) > 0 {
|
||||
for _, required := range v.RequiredAmrs {
|
||||
if !slices.Contains(claims.Amr, required) {
|
||||
errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required))
|
||||
}
|
||||
@ -383,10 +407,10 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
|
||||
}
|
||||
|
||||
// Optional, match if present
|
||||
if params.Azp != claims.Azp {
|
||||
if len(params.Azp) > 0 {
|
||||
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, params.Azp))
|
||||
} else if !params.IgnoreAzp {
|
||||
if v.Azp != claims.Azp {
|
||||
if len(v.Azp) > 0 {
|
||||
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp))
|
||||
} else if !v.IgnoreAzp {
|
||||
errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp))
|
||||
}
|
||||
}
|
||||
@ -402,111 +426,154 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Issuer holds public keys and validation config for a trusted token issuer.
|
||||
// Issuer holds public keys and optional validation config for a trusted token issuer.
|
||||
//
|
||||
// [Issuer.FetchKeys] loads keys from the issuer's JWKS endpoint.
|
||||
// [Issuer.SetKeys] injects keys directly (useful in tests).
|
||||
// [Issuer.Verify] authenticates the token: key lookup → sig verify → iss check.
|
||||
// Create with [New], [NewWithJWKs], [NewWithOIDC], or [NewWithOAuth2].
|
||||
// After construction, Issuer is immutable.
|
||||
//
|
||||
// Typical setup:
|
||||
//
|
||||
// iss := ajwt.NewIssuer("https://accounts.example.com")
|
||||
// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true}
|
||||
// if err := iss.FetchKeys(ctx); err != nil { ... }
|
||||
// [Issuer.UnsafeVerify] authenticates the token: Decode + key lookup + sig verify + iss check.
|
||||
// [Issuer.VerifyAndValidate] additionally unmarshals claims and runs the Validator.
|
||||
type Issuer struct {
|
||||
URL string
|
||||
JWKsURL string // optional; defaults to URL + "/.well-known/jwks.json"
|
||||
Params ValidateParams
|
||||
URL string // issuer URL for iss claim enforcement; empty skips the check
|
||||
validator *Validator
|
||||
keys map[string]crypto.PublicKey // kid → key
|
||||
}
|
||||
|
||||
// NewIssuer creates an Issuer for the given base URL.
|
||||
func NewIssuer(url string) *Issuer {
|
||||
return &Issuer{
|
||||
URL: url,
|
||||
keys: make(map[string]crypto.PublicKey),
|
||||
}
|
||||
}
|
||||
|
||||
// SetKeys stores public keys by their KID, replacing any previously stored keys.
|
||||
// Useful for injecting keys in tests without an HTTP round-trip.
|
||||
func (iss *Issuer) SetKeys(keys []PublicJWK) {
|
||||
// New creates an Issuer with explicit keys.
|
||||
//
|
||||
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
||||
// [Issuer.VerifyAndValidate] requires a non-nil Validator.
|
||||
func New(issURL string, keys []PublicJWK, v *Validator) *Issuer {
|
||||
m := make(map[string]crypto.PublicKey, len(keys))
|
||||
for _, k := range keys {
|
||||
m[k.KID] = k.Key
|
||||
}
|
||||
iss.keys = m
|
||||
return &Issuer{
|
||||
URL: issURL,
|
||||
validator: v,
|
||||
keys: m,
|
||||
}
|
||||
}
|
||||
|
||||
// FetchKeys retrieves and stores the JWKS from the issuer's endpoint.
|
||||
// If JWKsURL is empty, it defaults to URL + "/.well-known/jwks.json".
|
||||
func (iss *Issuer) FetchKeys(ctx context.Context) error {
|
||||
url := iss.JWKsURL
|
||||
if url == "" {
|
||||
url = strings.TrimRight(iss.URL, "/") + "/.well-known/jwks.json"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch JWKS: %w", err)
|
||||
}
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch JWKS: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
keys, err := DecodePublicJWKs(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse JWKS: %w", err)
|
||||
}
|
||||
|
||||
iss.SetKeys(keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify authenticates jws against this issuer:
|
||||
// 1. Looks up the signing key by jws.Header.Kid.
|
||||
// 2. Verifies the signature before trusting any payload data.
|
||||
// 3. Checks that the token's "iss" claim matches iss.URL.
|
||||
// NewWithJWKs creates an Issuer by fetching keys from jwksURL.
|
||||
//
|
||||
// Call [JWS.UnmarshalClaims] after Verify to safely decode the payload into a
|
||||
// typed struct, then [ValidateParams.Validate] to check claim values.
|
||||
func (iss *Issuer) Verify(jws *JWS) error {
|
||||
// The issuer URL (used for iss claim enforcement in [Issuer.UnsafeVerify]) is
|
||||
// not set; use [New] or [NewWithOIDC]/[NewWithOAuth2] if you need iss enforcement.
|
||||
//
|
||||
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
||||
func NewWithJWKs(ctx context.Context, jwksURL string, v *Validator) (*Issuer, error) {
|
||||
keys, err := FetchJWKs(ctx, jwksURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return New("", keys, v), nil
|
||||
}
|
||||
|
||||
// NewWithOIDC creates an Issuer using OIDC discovery.
|
||||
//
|
||||
// It fetches {baseURL}/.well-known/openid-configuration and reads the
|
||||
// jwks_uri and issuer fields. The Issuer URL is set from the discovery
|
||||
// document's issuer field (not baseURL) because OIDC requires them to match.
|
||||
//
|
||||
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
||||
func NewWithOIDC(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) {
|
||||
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
|
||||
keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return New(issURL, keys, v), nil
|
||||
}
|
||||
|
||||
// NewWithOAuth2 creates an Issuer using OAuth 2.0 authorization server metadata (RFC 8414).
|
||||
//
|
||||
// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the
|
||||
// jwks_uri and issuer fields. The Issuer URL is set from the discovery
|
||||
// document's issuer field.
|
||||
//
|
||||
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
||||
func NewWithOAuth2(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) {
|
||||
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
|
||||
keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return New(issURL, keys, v), nil
|
||||
}
|
||||
|
||||
// UnsafeVerify decodes tokenStr, verifies the signature, and (if [Issuer.URL]
|
||||
// is set) checks the iss claim.
|
||||
//
|
||||
// "Unsafe" means exp, aud, and other claim values are NOT checked — the token
|
||||
// is forgery-safe but not semantically validated. Callers are responsible for
|
||||
// validating claim values, or use [Issuer.VerifyAndValidate].
|
||||
func (iss *Issuer) UnsafeVerify(tokenStr string) (*JWS, error) {
|
||||
jws, err := Decode(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if jws.Header.Kid == "" {
|
||||
return fmt.Errorf("missing 'kid' header")
|
||||
return nil, fmt.Errorf("missing 'kid' header")
|
||||
}
|
||||
key, ok := iss.keys[jws.Header.Kid]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown kid: %q", jws.Header.Kid)
|
||||
return nil, fmt.Errorf("unknown kid: %q", jws.Header.Kid)
|
||||
}
|
||||
|
||||
signingInput := jws.Protected + "." + jws.Payload
|
||||
if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil {
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
return nil, fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
// Signature verified — now safe to inspect the payload.
|
||||
// Signature verified — now safe to inspect the payload for iss check.
|
||||
if iss.URL != "" {
|
||||
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid claims encoding: %w", err)
|
||||
return nil, fmt.Errorf("invalid claims encoding: %w", err)
|
||||
}
|
||||
var partial struct {
|
||||
Iss string `json:"iss"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &partial); err != nil {
|
||||
return fmt.Errorf("invalid claims JSON: %w", err)
|
||||
return nil, fmt.Errorf("invalid claims JSON: %w", err)
|
||||
}
|
||||
if partial.Iss != iss.URL {
|
||||
return fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL)
|
||||
return nil, fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return jws, nil
|
||||
}
|
||||
|
||||
// VerifyAndValidate verifies the token signature and iss, unmarshals the claims
|
||||
// into claims, and runs the [Validator].
|
||||
//
|
||||
// Returns a hard error (err != nil) for signature failures, decoding errors,
|
||||
// and nil Validator. Returns soft errors (errs != nil) for claim validation
|
||||
// failures (wrong aud, expired token, etc.).
|
||||
//
|
||||
// claims must be a pointer whose underlying type embeds [StandardClaims] (or
|
||||
// otherwise implements [StandardClaimsSource]):
|
||||
//
|
||||
// var claims AppClaims
|
||||
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now())
|
||||
func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSource, now time.Time) (*JWS, []string, error) {
|
||||
if iss.validator == nil {
|
||||
return nil, nil, fmt.Errorf("VerifyAndValidate requires a non-nil Validator; use UnsafeVerify for signature-only verification")
|
||||
}
|
||||
|
||||
jws, err := iss.UnsafeVerify(tokenStr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := jws.UnmarshalClaims(claims); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
errs, err := iss.validator.Validate(claims, now)
|
||||
return jws, errs, err
|
||||
}
|
||||
|
||||
// verifyWith checks a JWS signature using the given algorithm and public key.
|
||||
|
||||
@ -14,7 +14,9 @@ import (
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -23,9 +25,8 @@ import (
|
||||
|
||||
// AppClaims embeds StandardClaims and adds application-specific fields.
|
||||
//
|
||||
// Unlike embeddedjwt and bestjwt, AppClaims does NOT implement a Validate
|
||||
// interface — there is none. Validation is explicit: call
|
||||
// ValidateStandardClaims or ValidateParams.Validate at the call site.
|
||||
// Because StandardClaims is embedded, AppClaims satisfies StandardClaimsSource
|
||||
// for free via Go's method promotion — no interface to implement.
|
||||
type AppClaims struct {
|
||||
ajwt.StandardClaims
|
||||
Email string `json:"email"`
|
||||
@ -33,9 +34,10 @@ type AppClaims struct {
|
||||
}
|
||||
|
||||
// validateAppClaims is a plain function — not a method satisfying an interface.
|
||||
// Custom validation logic lives here, calling ValidateStandardClaims directly.
|
||||
func validateAppClaims(c AppClaims, params ajwt.ValidateParams, now time.Time) ([]string, error) {
|
||||
errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, params, now)
|
||||
// It demonstrates the UnsafeVerify pattern: custom validation logic lives here,
|
||||
// calling ValidateStandardClaims directly and adding app-specific checks.
|
||||
func validateAppClaims(c AppClaims, v ajwt.Validator, now time.Time) ([]string, error) {
|
||||
errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, v, now)
|
||||
if c.Email == "" {
|
||||
errs = append(errs, "missing email claim")
|
||||
}
|
||||
@ -65,11 +67,11 @@ func goodClaims() AppClaims {
|
||||
}
|
||||
}
|
||||
|
||||
// goodParams configures the validator. Iss is omitted because Issuer.Verify
|
||||
// already enforces the iss claim — no need to check it twice.
|
||||
func goodParams() ajwt.ValidateParams {
|
||||
return ajwt.ValidateParams{
|
||||
IgnoreIss: true, // Issuer.Verify handles iss
|
||||
// goodValidator configures the validator. IgnoreIss is true because
|
||||
// Issuer.UnsafeVerify already enforces the iss claim — no need to check twice.
|
||||
func goodValidator() *ajwt.Validator {
|
||||
return &ajwt.Validator{
|
||||
IgnoreIss: true, // UnsafeVerify handles iss
|
||||
Sub: "user123",
|
||||
Aud: "myapp",
|
||||
Jti: "abc123",
|
||||
@ -80,16 +82,13 @@ func goodParams() ajwt.ValidateParams {
|
||||
}
|
||||
|
||||
func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer {
|
||||
iss := ajwt.NewIssuer("https://example.com")
|
||||
iss.Params = goodParams()
|
||||
iss.SetKeys([]ajwt.PublicJWK{pub})
|
||||
return iss
|
||||
return ajwt.New("https://example.com", []ajwt.PublicJWK{pub}, goodValidator())
|
||||
}
|
||||
|
||||
// TestRoundTrip is the primary happy path using ES256.
|
||||
// It demonstrates the full Issuer-based flow:
|
||||
// It demonstrates the full VerifyAndValidate flow:
|
||||
//
|
||||
// Decode → Issuer.Verify → UnmarshalClaims → Params.Validate
|
||||
// New → VerifyAndValidate → custom claim access
|
||||
//
|
||||
// No Claims interface, no Verified flag, no type assertions on jws.
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
@ -115,20 +114,16 @@ func TestRoundTrip(t *testing.T) {
|
||||
|
||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
|
||||
|
||||
jws2, err := ajwt.Decode(token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = iss.Verify(jws2); err != nil {
|
||||
t.Fatalf("Verify failed: %v", err)
|
||||
}
|
||||
|
||||
var decoded AppClaims
|
||||
if err = jws2.UnmarshalClaims(&decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
jws2, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyAndValidate failed: %v", err)
|
||||
}
|
||||
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
|
||||
t.Fatalf("validation failed: %v", errs)
|
||||
if len(errs) > 0 {
|
||||
t.Fatalf("claim validation failed: %v", errs)
|
||||
}
|
||||
if jws2.Header.Alg != "ES256" {
|
||||
t.Errorf("expected ES256 alg in jws, got %s", jws2.Header.Alg)
|
||||
}
|
||||
// Direct field access — no type assertion needed, no jws.Claims interface.
|
||||
if decoded.Email != claims.Email {
|
||||
@ -160,20 +155,13 @@ func TestRoundTripRS256(t *testing.T) {
|
||||
|
||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
|
||||
|
||||
jws2, err := ajwt.Decode(token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = iss.Verify(jws2); err != nil {
|
||||
t.Fatalf("Verify failed: %v", err)
|
||||
}
|
||||
|
||||
var decoded AppClaims
|
||||
if err = jws2.UnmarshalClaims(&decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
_, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyAndValidate failed: %v", err)
|
||||
}
|
||||
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
|
||||
t.Fatalf("validation failed: %v", errs)
|
||||
if len(errs) > 0 {
|
||||
t.Fatalf("claim validation failed: %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
@ -201,25 +189,47 @@ func TestRoundTripEdDSA(t *testing.T) {
|
||||
|
||||
iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"})
|
||||
|
||||
jws2, err := ajwt.Decode(token)
|
||||
var decoded AppClaims
|
||||
_, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("VerifyAndValidate failed: %v", err)
|
||||
}
|
||||
if err = iss.Verify(jws2); err != nil {
|
||||
t.Fatalf("Verify failed: %v", err)
|
||||
if len(errs) > 0 {
|
||||
t.Fatalf("claim validation failed: %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnsafeVerifyFlow demonstrates the UnsafeVerify + custom validation pattern.
|
||||
// The caller owns the full validation pipeline.
|
||||
func TestUnsafeVerifyFlow(t *testing.T) {
|
||||
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
|
||||
claims := goodClaims()
|
||||
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
|
||||
_, _ = jws.Sign(privKey)
|
||||
token := jws.Encode()
|
||||
|
||||
// Create issuer without validator — UnsafeVerify only.
|
||||
iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil)
|
||||
|
||||
jws2, err := iss.UnsafeVerify(token)
|
||||
if err != nil {
|
||||
t.Fatalf("UnsafeVerify failed: %v", err)
|
||||
}
|
||||
|
||||
var decoded AppClaims
|
||||
if err = jws2.UnmarshalClaims(&decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
if err := jws2.UnmarshalClaims(&decoded); err != nil {
|
||||
t.Fatalf("UnmarshalClaims failed: %v", err)
|
||||
}
|
||||
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
|
||||
t.Fatalf("validation failed: %v", errs)
|
||||
|
||||
errs, err := ajwt.ValidateStandardClaims(decoded.StandardClaims, *goodValidator(), time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateStandardClaims failed: %v — errs: %v", err, errs)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomValidation demonstrates custom claim validation without any interface.
|
||||
// The caller owns the validation logic and calls ValidateStandardClaims directly.
|
||||
// TestCustomValidation demonstrates that ValidateStandardClaims is called
|
||||
// explicitly and custom fields are validated without any Claims interface.
|
||||
func TestCustomValidation(t *testing.T) {
|
||||
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
|
||||
@ -231,12 +241,15 @@ func TestCustomValidation(t *testing.T) {
|
||||
token := jws.Encode()
|
||||
|
||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
||||
jws2, _ := ajwt.Decode(token)
|
||||
_ = iss.Verify(jws2)
|
||||
jws2, err := iss.UnsafeVerify(token)
|
||||
if err != nil {
|
||||
t.Fatalf("UnsafeVerify failed unexpectedly: %v", err)
|
||||
}
|
||||
|
||||
var decoded AppClaims
|
||||
_ = jws2.UnmarshalClaims(&decoded)
|
||||
|
||||
errs, err := validateAppClaims(decoded, goodParams(), time.Now())
|
||||
errs, err := validateAppClaims(decoded, *goodValidator(), time.Now())
|
||||
if err == nil {
|
||||
t.Fatal("expected validation to fail: email is empty")
|
||||
}
|
||||
@ -251,6 +264,23 @@ func TestCustomValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyAndValidateNilValidator confirms that VerifyAndValidate fails loudly
|
||||
// when no Validator was provided at construction time.
|
||||
func TestVerifyAndValidateNilValidator(t *testing.T) {
|
||||
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
c := goodClaims()
|
||||
jws, _ := ajwt.NewJWSFromClaims(&c, "k")
|
||||
_, _ = jws.Sign(privKey)
|
||||
token := jws.Encode()
|
||||
|
||||
iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil)
|
||||
|
||||
var claims AppClaims
|
||||
if _, _, err := iss.VerifyAndValidate(token, &claims, time.Now()); err == nil {
|
||||
t.Fatal("expected VerifyAndValidate to error with nil validator")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssuerWrongKey confirms that a different key's public key is rejected.
|
||||
func TestIssuerWrongKey(t *testing.T) {
|
||||
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
@ -262,10 +292,9 @@ func TestIssuerWrongKey(t *testing.T) {
|
||||
token := jws.Encode()
|
||||
|
||||
iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"})
|
||||
jws2, _ := ajwt.Decode(token)
|
||||
|
||||
if err := iss.Verify(jws2); err == nil {
|
||||
t.Fatal("expected Verify to fail with wrong key")
|
||||
if _, err := iss.UnsafeVerify(token); err == nil {
|
||||
t.Fatal("expected UnsafeVerify to fail with wrong key")
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,10 +308,9 @@ func TestIssuerUnknownKid(t *testing.T) {
|
||||
token := jws.Encode()
|
||||
|
||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"})
|
||||
jws2, _ := ajwt.Decode(token)
|
||||
|
||||
if err := iss.Verify(jws2); err == nil {
|
||||
t.Fatal("expected Verify to fail for unknown kid")
|
||||
if _, err := iss.UnsafeVerify(token); err == nil {
|
||||
t.Fatal("expected UnsafeVerify to fail for unknown kid")
|
||||
}
|
||||
}
|
||||
|
||||
@ -299,14 +327,15 @@ func TestIssuerIssMismatch(t *testing.T) {
|
||||
|
||||
// Issuer expects "https://example.com" but token says "https://evil.example.com"
|
||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
||||
jws2, _ := ajwt.Decode(token)
|
||||
|
||||
if err := iss.Verify(jws2); err == nil {
|
||||
t.Fatal("expected Verify to fail: iss mismatch")
|
||||
if _, err := iss.UnsafeVerify(token); err == nil {
|
||||
t.Fatal("expected UnsafeVerify to fail: iss mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyTamperedAlg confirms that a tampered alg header is rejected.
|
||||
// TestVerifyTamperedAlg confirms that a tampered alg header ("none") is rejected.
|
||||
// The token is reconstructed with a replaced protected header; the original
|
||||
// ES256 signature is kept, making the signing input mismatch detectable.
|
||||
func TestVerifyTamperedAlg(t *testing.T) {
|
||||
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
|
||||
@ -316,11 +345,15 @@ func TestVerifyTamperedAlg(t *testing.T) {
|
||||
token := jws.Encode()
|
||||
|
||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
||||
jws2, _ := ajwt.Decode(token)
|
||||
jws2.Header.Alg = "none" // tamper
|
||||
|
||||
if err := iss.Verify(jws2); err == nil {
|
||||
t.Fatal("expected Verify to fail for tampered alg")
|
||||
// Replace the protected header with one that has alg:"none".
|
||||
// The original ES256 signature stays — the signing input will mismatch.
|
||||
noneHeader := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","kid":"k","typ":"JWT"}`))
|
||||
parts := strings.SplitN(token, ".", 3)
|
||||
tamperedToken := noneHeader + "." + parts[1] + "." + parts[2]
|
||||
|
||||
if _, err := iss.UnsafeVerify(tamperedToken); err == nil {
|
||||
t.Fatal("expected UnsafeVerify to fail for tampered alg")
|
||||
}
|
||||
}
|
||||
|
||||
@ -408,3 +441,79 @@ func TestDecodePublicJWKJSON(t *testing.T) {
|
||||
t.Errorf("expected 1 RSA key, got %d", rsaCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestThumbprint verifies that Thumbprint returns a non-empty base64url string
|
||||
// for EC, RSA, and Ed25519 keys, and that two equal keys produce the same thumbprint.
|
||||
func TestThumbprint(t *testing.T) {
|
||||
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
edPub, _, _ := ed25519.GenerateKey(rand.Reader)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
jwk ajwt.PublicJWK
|
||||
}{
|
||||
{"EC P-256", ajwt.PublicJWK{Key: &ecKey.PublicKey}},
|
||||
{"RSA 2048", ajwt.PublicJWK{Key: &rsaKey.PublicKey}},
|
||||
{"Ed25519", ajwt.PublicJWK{Key: edPub}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
thumb, err := tt.jwk.Thumbprint()
|
||||
if err != nil {
|
||||
t.Fatalf("Thumbprint() error: %v", err)
|
||||
}
|
||||
if thumb == "" {
|
||||
t.Fatal("Thumbprint() returned empty string")
|
||||
}
|
||||
// Must be valid base64url (no padding, no +/)
|
||||
if strings.Contains(thumb, "+") || strings.Contains(thumb, "/") || strings.Contains(thumb, "=") {
|
||||
t.Errorf("Thumbprint() contains non-base64url characters: %s", thumb)
|
||||
}
|
||||
// Same key → same thumbprint
|
||||
thumb2, _ := tt.jwk.Thumbprint()
|
||||
if thumb != thumb2 {
|
||||
t.Errorf("Thumbprint() not deterministic: %s != %s", thumb, thumb2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNoKidAutoThumbprint verifies that a JWKS key without a "kid" field gets
|
||||
// its KID auto-populated from the RFC 7638 thumbprint.
|
||||
func TestNoKidAutoThumbprint(t *testing.T) {
|
||||
// EC key with no "kid" field in the JWKS
|
||||
jwksJSON := []byte(`{"keys":[
|
||||
{"kty":"EC","crv":"P-256",
|
||||
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
|
||||
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
|
||||
"use":"sig"}
|
||||
]}`)
|
||||
|
||||
keys, err := ajwt.UnmarshalPublicJWKs(jwksJSON)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(keys) != 1 {
|
||||
t.Fatalf("expected 1 key, got %d", len(keys))
|
||||
}
|
||||
if keys[0].KID == "" {
|
||||
t.Fatal("KID should be auto-populated from Thumbprint when absent in JWKS")
|
||||
}
|
||||
|
||||
// The auto-KID should be a valid base64url string.
|
||||
kid := keys[0].KID
|
||||
if strings.Contains(kid, "+") || strings.Contains(kid, "/") || strings.Contains(kid, "=") {
|
||||
t.Errorf("auto-KID contains non-base64url characters: %s", kid)
|
||||
}
|
||||
|
||||
// Round-trip: compute Thumbprint directly and compare.
|
||||
thumb, err := keys[0].Thumbprint()
|
||||
if err != nil {
|
||||
t.Fatalf("Thumbprint() error: %v", err)
|
||||
}
|
||||
if kid != thumb {
|
||||
t.Errorf("auto-KID %q != direct Thumbprint %q", kid, thumb)
|
||||
}
|
||||
}
|
||||
|
||||
171
auth/ajwt/pub.go
171
auth/ajwt/pub.go
@ -9,11 +9,13 @@
|
||||
package ajwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -21,6 +23,7 @@ import (
|
||||
"math/big"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -54,6 +57,90 @@ func (k PublicJWK) EdDSA() (ed25519.PublicKey, bool) {
|
||||
return key, ok
|
||||
}
|
||||
|
||||
// Thumbprint computes the RFC 7638 JWK Thumbprint (SHA-256 of the canonical
|
||||
// key JSON with fields in lexicographic order). The result is base64url-encoded.
|
||||
//
|
||||
// Canonical forms per RFC 7638:
|
||||
// - EC: {"crv":…, "kty":"EC", "x":…, "y":…}
|
||||
// - RSA: {"e":…, "kty":"RSA", "n":…}
|
||||
// - OKP: {"crv":"Ed25519", "kty":"OKP", "x":…}
|
||||
//
|
||||
// Use Thumbprint as KID when none is provided in the JWKS source.
|
||||
func (k PublicJWK) Thumbprint() (string, error) {
|
||||
var canonical []byte
|
||||
var err error
|
||||
|
||||
switch key := k.Key.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
byteLen := (key.Curve.Params().BitSize + 7) / 8
|
||||
xBytes := make([]byte, byteLen)
|
||||
yBytes := make([]byte, byteLen)
|
||||
key.X.FillBytes(xBytes)
|
||||
key.Y.FillBytes(yBytes)
|
||||
|
||||
var crv string
|
||||
switch key.Curve {
|
||||
case elliptic.P256():
|
||||
crv = "P-256"
|
||||
case elliptic.P384():
|
||||
crv = "P-384"
|
||||
case elliptic.P521():
|
||||
crv = "P-521"
|
||||
default:
|
||||
return "", fmt.Errorf("Thumbprint: unsupported EC curve %s", key.Curve.Params().Name)
|
||||
}
|
||||
|
||||
// Fields in lexicographic order: crv, kty, x, y
|
||||
canonical, err = json.Marshal(struct {
|
||||
Crv string `json:"crv"`
|
||||
Kty string `json:"kty"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
}{
|
||||
Crv: crv,
|
||||
Kty: "EC",
|
||||
X: base64.RawURLEncoding.EncodeToString(xBytes),
|
||||
Y: base64.RawURLEncoding.EncodeToString(yBytes),
|
||||
})
|
||||
|
||||
case *rsa.PublicKey:
|
||||
eInt := big.NewInt(int64(key.E))
|
||||
|
||||
// Fields in lexicographic order: e, kty, n
|
||||
canonical, err = json.Marshal(struct {
|
||||
E string `json:"e"`
|
||||
Kty string `json:"kty"`
|
||||
N string `json:"n"`
|
||||
}{
|
||||
E: base64.RawURLEncoding.EncodeToString(eInt.Bytes()),
|
||||
Kty: "RSA",
|
||||
N: base64.RawURLEncoding.EncodeToString(key.N.Bytes()),
|
||||
})
|
||||
|
||||
case ed25519.PublicKey:
|
||||
// Fields in lexicographic order: crv, kty, x
|
||||
canonical, err = json.Marshal(struct {
|
||||
Crv string `json:"crv"`
|
||||
Kty string `json:"kty"`
|
||||
X string `json:"x"`
|
||||
}{
|
||||
Crv: "Ed25519",
|
||||
Kty: "OKP",
|
||||
X: base64.RawURLEncoding.EncodeToString([]byte(key)),
|
||||
})
|
||||
|
||||
default:
|
||||
return "", fmt.Errorf("Thumbprint: unsupported key type %T", k.Key)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Thumbprint: marshal canonical JSON: %w", err)
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(canonical)
|
||||
return base64.RawURLEncoding.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
// PublicJWKJSON is the JSON representation of a single key in a JWKS document.
|
||||
type PublicJWKJSON struct {
|
||||
Kty string `json:"kty"`
|
||||
@ -71,24 +158,83 @@ type JWKsJSON struct {
|
||||
Keys []PublicJWKJSON `json:"keys"`
|
||||
}
|
||||
|
||||
// FetchPublicJWKs retrieves and parses a JWKS document from url.
|
||||
// FetchJWKs retrieves and parses a JWKS document from jwksURL.
|
||||
//
|
||||
// For issuer-scoped key management with context support, use
|
||||
// [Issuer.FetchKeys] instead.
|
||||
func FetchPublicJWKs(url string) ([]PublicJWK, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Get(url)
|
||||
// ctx is used for the HTTP request timeout and cancellation.
|
||||
func FetchJWKs(ctx context.Context, jwksURL string) ([]PublicJWK, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
|
||||
return nil, fmt.Errorf("fetch JWKS: %w", err)
|
||||
}
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch JWKS: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
return nil, fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
return DecodePublicJWKs(resp.Body)
|
||||
}
|
||||
|
||||
// FetchJWKsFromOIDC fetches JWKS via OIDC discovery from baseURL.
|
||||
//
|
||||
// It fetches {baseURL}/.well-known/openid-configuration and reads the jwks_uri field.
|
||||
func FetchJWKsFromOIDC(ctx context.Context, baseURL string) ([]PublicJWK, error) {
|
||||
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
|
||||
keys, _, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
|
||||
return keys, err
|
||||
}
|
||||
|
||||
// FetchJWKsFromOAuth2 fetches JWKS via OAuth 2.0 authorization server metadata (RFC 8414)
|
||||
// from baseURL.
|
||||
//
|
||||
// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the jwks_uri field.
|
||||
func FetchJWKsFromOAuth2(ctx context.Context, baseURL string) ([]PublicJWK, error) {
|
||||
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
|
||||
keys, _, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
|
||||
return keys, err
|
||||
}
|
||||
|
||||
// fetchJWKsFromDiscovery fetches a discovery document from discoveryURL, then
|
||||
// fetches the JWKS from the jwks_uri field. Returns the keys and the issuer
|
||||
// URL from the discovery document's "issuer" field.
|
||||
func fetchJWKsFromDiscovery(ctx context.Context, discoveryURL string) ([]PublicJWK, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("fetch discovery: %w", err)
|
||||
}
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("fetch discovery: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("fetch discovery: unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var doc struct {
|
||||
Issuer string `json:"issuer"`
|
||||
JWKsURI string `json:"jwks_uri"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
|
||||
return nil, "", fmt.Errorf("parse discovery doc: %w", err)
|
||||
}
|
||||
if doc.JWKsURI == "" {
|
||||
return nil, "", fmt.Errorf("discovery doc missing jwks_uri field")
|
||||
}
|
||||
|
||||
keys, err := FetchJWKs(ctx, doc.JWKsURI)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return keys, doc.Issuer, nil
|
||||
}
|
||||
|
||||
// ReadPublicJWKs reads and parses a JWKS document from a file path.
|
||||
func ReadPublicJWKs(filePath string) ([]PublicJWK, error) {
|
||||
file, err := os.Open(filePath)
|
||||
@ -118,6 +264,9 @@ func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) {
|
||||
}
|
||||
|
||||
// DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys.
|
||||
//
|
||||
// If a key has no kid field in the source document, the KID is auto-populated
|
||||
// from [PublicJWK.Thumbprint] per RFC 7638.
|
||||
func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
|
||||
var keys []PublicJWK
|
||||
for _, jwk := range jwks.Keys {
|
||||
@ -125,6 +274,12 @@ func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse public jwk %q: %w", jwk.KID, err)
|
||||
}
|
||||
if key.KID == "" {
|
||||
key.KID, err = key.Thumbprint()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compute thumbprint for kid-less key: %w", err)
|
||||
}
|
||||
}
|
||||
keys = append(keys, *key)
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user