mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 12:48:00 +00:00
ajwt: redesign API — immutable Issuer, Signer, JWKsFetcher
Key changes from previous design: - Issuer is now immutable after construction (no mutex, no SetKeys) - New(keys []PublicJWK) — no issURL or Validator baked in - Verify returns (nil, err) on any failure; UnsafeVerify returns (*JWS, err) even on sig failure so callers can inspect kid/iss for multi-issuer routing - VerifyAndValidate takes ClaimsValidator per-call instead of baking it into the Issuer; soft errors in errs, hard errors in err, nil sentinel discarded - ClaimsValidator interface implemented by *Validator and *MultiValidator - MultiValidator: []string for iss, aud, azp (multi-tenant) - Signer: round-robin across NamedSigner keys via atomic.Uint64; auto-KID from RFC 7638 thumbprint; Issuer() returns *Issuer with signer's public keys - JWKsFetcher: lazy, no background goroutine; Issuer(ctx) checks freshness per call and creates new *Issuer on cache miss; KeepOnError + StaleAge for serving stale keys on fetch failure - pub.go: add EncodePublicJWK and MarshalPublicJWKs (encode counterparts) - Remove NewWithJWKs, NewWithOIDC, NewWithOAuth2 constructors from Issuer
This commit is contained in:
parent
3f7985317f
commit
2f946d28b5
122
auth/ajwt/fetcher.go
Normal file
122
auth/ajwt/fetcher.go
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
package ajwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// cachedIssuer bundles an [*Issuer] with its freshness window.
|
||||||
|
// Stored atomically in [JWKsFetcher]; immutable after creation.
|
||||||
|
type cachedIssuer struct {
|
||||||
|
iss *Issuer
|
||||||
|
fetchedAt time.Time
|
||||||
|
expiresAt time.Time // fetchedAt + MaxAge; fresh until this point
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWKsFetcher lazily fetches and caches JWKS keys from a remote URL,
|
||||||
|
// returning a fresh [*Issuer] on demand.
|
||||||
|
//
|
||||||
|
// Each call to [JWKsFetcher.Issuer] checks freshness and either returns the
|
||||||
|
// cached Issuer immediately or fetches a new one. There is no background
|
||||||
|
// goroutine — refresh only happens when a caller requests an Issuer.
|
||||||
|
//
|
||||||
|
// Fields must be set before the first call to [JWKsFetcher.Issuer]; do not
|
||||||
|
// modify them concurrently.
|
||||||
|
//
|
||||||
|
// Typical usage:
|
||||||
|
//
|
||||||
|
// fetcher := &ajwt.JWKsFetcher{
|
||||||
|
// URL: "https://accounts.example.com/.well-known/jwks.json",
|
||||||
|
// MaxAge: time.Hour,
|
||||||
|
// StaleAge: 30 * time.Minute,
|
||||||
|
// KeepOnError: true,
|
||||||
|
// }
|
||||||
|
// iss, err := fetcher.Issuer(ctx)
|
||||||
|
type JWKsFetcher struct {
|
||||||
|
// URL is the JWKS endpoint to fetch keys from.
|
||||||
|
URL string
|
||||||
|
|
||||||
|
// MaxAge is how long fetched keys are considered fresh. After MaxAge,
|
||||||
|
// the next call to Issuer triggers a refresh. Defaults to 1 hour.
|
||||||
|
MaxAge time.Duration
|
||||||
|
|
||||||
|
// StaleAge is additional time beyond MaxAge during which the old Issuer
|
||||||
|
// may be returned when a refresh fails. For example, MaxAge=1h and
|
||||||
|
// StaleAge=30m means keys will be served up to 90 minutes after the last
|
||||||
|
// successful fetch, if KeepOnError is true and fetches keep failing.
|
||||||
|
// Defaults to 0 (no stale window).
|
||||||
|
StaleAge time.Duration
|
||||||
|
|
||||||
|
// KeepOnError causes the previous Issuer to be returned (with an error)
|
||||||
|
// when a refresh fails, as long as the result is within the stale window
|
||||||
|
// (expiresAt + StaleAge). If false, any fetch error after MaxAge returns
|
||||||
|
// (nil, err).
|
||||||
|
KeepOnError bool
|
||||||
|
|
||||||
|
// RespectHeaders is reserved for future use (honor Cache-Control max-age
|
||||||
|
// from the JWKS response, capped at MaxAge).
|
||||||
|
RespectHeaders bool
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
cached atomic.Pointer[cachedIssuer]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Issuer returns a current [*Issuer] for verifying tokens.
|
||||||
|
//
|
||||||
|
// If the cached Issuer is still fresh (within MaxAge), it is returned without
|
||||||
|
// a network call. If it has expired, a new fetch is performed. On fetch
|
||||||
|
// failure with KeepOnError=true and within StaleAge, the old Issuer is
|
||||||
|
// returned alongside a non-nil error; callers may choose to accept it.
|
||||||
|
func (f *JWKsFetcher) Issuer(ctx context.Context) (*Issuer, error) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Fast path: check cached value without locking.
|
||||||
|
if ci := f.cached.Load(); ci != nil && now.Before(ci.expiresAt) {
|
||||||
|
return ci.iss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Slow path: refresh needed. Serialize to avoid stampeding.
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
|
// Re-check after acquiring lock — another goroutine may have refreshed.
|
||||||
|
if ci := f.cached.Load(); ci != nil && now.Before(ci.expiresAt) {
|
||||||
|
return ci.iss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, err := FetchJWKs(ctx, f.URL)
|
||||||
|
if err != nil {
|
||||||
|
// On error, serve stale keys within the stale window.
|
||||||
|
if ci := f.cached.Load(); ci != nil && f.KeepOnError {
|
||||||
|
staleDeadline := ci.expiresAt.Add(f.StaleAge)
|
||||||
|
if now.Before(staleDeadline) {
|
||||||
|
return ci.iss, fmt.Errorf("JWKS refresh failed (serving cached keys): %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("fetch JWKS from %s: %w", f.URL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxAge := f.MaxAge
|
||||||
|
if maxAge <= 0 {
|
||||||
|
maxAge = time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
|
ci := &cachedIssuer{
|
||||||
|
iss: New(keys),
|
||||||
|
fetchedAt: now,
|
||||||
|
expiresAt: now.Add(maxAge),
|
||||||
|
}
|
||||||
|
f.cached.Store(ci)
|
||||||
|
return ci.iss, nil
|
||||||
|
}
|
||||||
325
auth/ajwt/jwt.go
325
auth/ajwt/jwt.go
@ -9,40 +9,52 @@
|
|||||||
// Package ajwt is a lightweight JWT/JWS/JWK library designed from first
|
// Package ajwt is a lightweight JWT/JWS/JWK library designed from first
|
||||||
// principles:
|
// principles:
|
||||||
//
|
//
|
||||||
// - [JWS] is a parsed structure only — no Claims interface, no Verified flag.
|
// - [Issuer] is immutable — constructed with a fixed key set, safe for concurrent use.
|
||||||
// - [Issuer] owns key management and signature verification, centralizing
|
// - [Signer] manages private keys and returns [*Issuer] for verification.
|
||||||
// the key lookup → sig verify → iss check sequence.
|
// - [JWKsFetcher] lazily fetches and caches JWKS keys, returning a fresh [*Issuer] on demand.
|
||||||
// - [Validator] is a stable config object; time is passed at the call site
|
// - [Validator] and [MultiValidator] validate standard JWT/OIDC claims.
|
||||||
// so the same validator can be reused across requests.
|
// - [JWS] is a parsed structure — use [Issuer.Verify] or [Issuer.UnsafeVerify] to authenticate.
|
||||||
// - [StandardClaimsSource] is implemented for free by embedding [StandardClaims].
|
// - [JWS.UnmarshalClaims] accepts any type — no Claims interface to implement.
|
||||||
// - [JWS.UnmarshalClaims] accepts any type — no interface to implement.
|
// - [StandardClaimsSource] is satisfied for free by embedding [StandardClaims].
|
||||||
// - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384),
|
|
||||||
// ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037).
|
|
||||||
//
|
//
|
||||||
// Typical usage with VerifyAndValidate:
|
// Typical usage with VerifyAndValidate:
|
||||||
//
|
//
|
||||||
// // At startup:
|
// // At startup:
|
||||||
// iss, err := ajwt.NewWithOIDC(ctx, "https://accounts.example.com",
|
// signer, err := ajwt.NewSigner([]ajwt.NamedSigner{{Signer: privKey}})
|
||||||
// &ajwt.Validator{Aud: "my-app", IgnoreIss: true})
|
// iss := signer.Issuer()
|
||||||
|
// v := &ajwt.Validator{Iss: "https://example.com", Aud: "my-app"}
|
||||||
|
//
|
||||||
|
// // Sign a token:
|
||||||
|
// tokenStr, err := signer.Sign(claims)
|
||||||
//
|
//
|
||||||
// // Per request:
|
// // Per request:
|
||||||
// var claims AppClaims
|
// var claims AppClaims
|
||||||
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now())
|
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, v, time.Now())
|
||||||
// if err != nil { /* hard error: bad sig, expired, etc. */ }
|
// if err != nil { /* hard error: bad sig, malformed token */ }
|
||||||
// if len(errs) > 0 { /* soft errors: wrong aud, missing amr, etc. */ }
|
// if len(errs) > 0 { /* soft errors: wrong aud, expired, etc. */ }
|
||||||
//
|
//
|
||||||
// Typical usage with UnsafeVerify (custom validation only):
|
// Typical usage with UnsafeVerify (custom validation):
|
||||||
//
|
//
|
||||||
// iss := ajwt.New("https://example.com", keys, nil)
|
// iss := ajwt.New(keys)
|
||||||
// jws, err := iss.UnsafeVerify(tokenStr)
|
// jws, err := iss.UnsafeVerify(tokenStr)
|
||||||
// var claims AppClaims
|
// var claims AppClaims
|
||||||
// jws.UnmarshalClaims(&claims)
|
// jws.UnmarshalClaims(&claims)
|
||||||
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims,
|
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims,
|
||||||
// ajwt.Validator{Aud: "myapp"}, time.Now())
|
// ajwt.Validator{Aud: "myapp"}, time.Now())
|
||||||
|
//
|
||||||
|
// Typical usage with JWKsFetcher (dynamic keys from remote):
|
||||||
|
//
|
||||||
|
// fetcher := &ajwt.JWKsFetcher{
|
||||||
|
// URL: "https://accounts.example.com/.well-known/jwks.json",
|
||||||
|
// MaxAge: time.Hour,
|
||||||
|
// StaleAge: time.Hour,
|
||||||
|
// KeepOnError: true,
|
||||||
|
// }
|
||||||
|
// iss, err := fetcher.Issuer(ctx)
|
||||||
|
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, v, time.Now())
|
||||||
package ajwt
|
package ajwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
@ -65,8 +77,8 @@ import (
|
|||||||
//
|
//
|
||||||
// It holds only the parsed structure — header, raw base64url fields, and
|
// It holds only the parsed structure — header, raw base64url fields, and
|
||||||
// decoded signature bytes. It carries no Claims interface and no Verified flag;
|
// decoded signature bytes. It carries no Claims interface and no Verified flag;
|
||||||
// use [Issuer.UnsafeVerify] or [Issuer.VerifyAndValidate] to authenticate the
|
// use [Issuer.Verify] or [Issuer.UnsafeVerify] to authenticate the token and
|
||||||
// token and [JWS.UnmarshalClaims] to decode the payload into a typed struct.
|
// [JWS.UnmarshalClaims] to decode the payload into a typed struct.
|
||||||
type JWS struct {
|
type JWS struct {
|
||||||
Protected string // base64url-encoded header
|
Protected string // base64url-encoded header
|
||||||
Header StandardHeader
|
Header StandardHeader
|
||||||
@ -120,10 +132,16 @@ type StandardClaimsSource interface {
|
|||||||
GetStandardClaims() StandardClaims
|
GetStandardClaims() StandardClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClaimsValidator validates the standard JWT/OIDC claims in a token.
|
||||||
|
// Implemented by [*Validator] and [*MultiValidator].
|
||||||
|
type ClaimsValidator interface {
|
||||||
|
Validate(claims StandardClaimsSource, now time.Time) ([]string, error)
|
||||||
|
}
|
||||||
|
|
||||||
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
|
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
|
||||||
//
|
//
|
||||||
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
|
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
|
||||||
// [Issuer.UnsafeVerify] to safely populate a typed claims struct.
|
// [Issuer.Verify] or [Issuer.UnsafeVerify] to populate a typed claims struct.
|
||||||
func Decode(tokenStr string) (*JWS, error) {
|
func Decode(tokenStr string) (*JWS, error) {
|
||||||
parts := strings.Split(tokenStr, ".")
|
parts := strings.Split(tokenStr, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
@ -152,8 +170,8 @@ func Decode(tokenStr string) (*JWS, error) {
|
|||||||
// UnmarshalClaims decodes the JWT payload into v.
|
// UnmarshalClaims decodes the JWT payload into v.
|
||||||
//
|
//
|
||||||
// v must be a pointer to a struct (e.g. *AppClaims). Always call
|
// v must be a pointer to a struct (e.g. *AppClaims). Always call
|
||||||
// [Issuer.UnsafeVerify] before UnmarshalClaims to ensure the signature is
|
// [Issuer.Verify] or [Issuer.UnsafeVerify] before UnmarshalClaims to ensure
|
||||||
// authenticated before trusting the payload.
|
// the signature is authenticated before trusting the payload.
|
||||||
func (jws *JWS) UnmarshalClaims(v any) error {
|
func (jws *JWS) UnmarshalClaims(v any) error {
|
||||||
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -256,16 +274,15 @@ func (jws *JWS) Encode() string {
|
|||||||
return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature)
|
return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validator holds claim validation configuration.
|
// Validator holds claim validation configuration for single-tenant use.
|
||||||
//
|
//
|
||||||
// Configure once at startup; call [Validator.Validate] per request, passing
|
// Configure once at startup; pass to [Issuer.VerifyAndValidate] or call
|
||||||
// the current time. This keeps the config stable and makes the time dependency
|
// [Validator.Validate] directly per request.
|
||||||
// explicit at the call site.
|
|
||||||
//
|
//
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
|
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
|
||||||
type Validator struct {
|
type Validator struct {
|
||||||
IgnoreIss bool
|
IgnoreIss bool
|
||||||
Iss string // rarely needed — Issuer.UnsafeVerify already checks iss
|
Iss string
|
||||||
IgnoreSub bool
|
IgnoreSub bool
|
||||||
Sub string
|
Sub string
|
||||||
IgnoreAud bool
|
IgnoreAud bool
|
||||||
@ -284,12 +301,104 @@ type Validator struct {
|
|||||||
Azp string
|
Azp string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks the standard JWT/OIDC claim fields in claims against this config.
|
// Validate implements [ClaimsValidator].
|
||||||
//
|
func (v *Validator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) {
|
||||||
// now is typically time.Now() — passing it explicitly keeps the config stable
|
return ValidateStandardClaims(claims.GetStandardClaims(), *v, now)
|
||||||
// across requests and avoids hidden time dependencies in the validator struct.
|
}
|
||||||
func (v Validator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) {
|
|
||||||
return ValidateStandardClaims(claims.GetStandardClaims(), v, now)
|
// MultiValidator holds claim validation configuration for multi-tenant use.
|
||||||
|
// Iss, Aud, and Azp accept slices — the claim value must appear in the slice.
|
||||||
|
type MultiValidator struct {
|
||||||
|
Iss []string
|
||||||
|
IgnoreIss bool
|
||||||
|
IgnoreSub bool
|
||||||
|
Aud []string
|
||||||
|
IgnoreAud bool
|
||||||
|
IgnoreExp bool
|
||||||
|
IgnoreIat bool
|
||||||
|
IgnoreAuthTime bool
|
||||||
|
MaxAge time.Duration
|
||||||
|
IgnoreNonce bool
|
||||||
|
IgnoreAmr bool
|
||||||
|
RequiredAmrs []string
|
||||||
|
IgnoreAzp bool
|
||||||
|
Azp []string
|
||||||
|
IgnoreJti bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate implements [ClaimsValidator].
|
||||||
|
func (v *MultiValidator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) {
|
||||||
|
sc := claims.GetStandardClaims()
|
||||||
|
var errs []string
|
||||||
|
|
||||||
|
if !v.IgnoreIss {
|
||||||
|
if sc.Iss == "" {
|
||||||
|
errs = append(errs, "missing or malformed 'iss' (token issuer)")
|
||||||
|
} else if len(v.Iss) > 0 && !slices.Contains(v.Iss, sc.Iss) {
|
||||||
|
errs = append(errs, fmt.Sprintf("'iss' %q not in allowed list", sc.Iss))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !v.IgnoreAud {
|
||||||
|
if sc.Aud == "" {
|
||||||
|
errs = append(errs, "missing or malformed 'aud' (audience)")
|
||||||
|
} else if len(v.Aud) > 0 && !slices.Contains(v.Aud, sc.Aud) {
|
||||||
|
errs = append(errs, fmt.Sprintf("'aud' %q not in allowed list", sc.Aud))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !v.IgnoreExp {
|
||||||
|
if sc.Exp <= 0 {
|
||||||
|
errs = append(errs, "missing or malformed 'exp' (expiration)")
|
||||||
|
} else if sc.Exp < now.Unix() {
|
||||||
|
duration := now.Sub(time.Unix(sc.Exp, 0))
|
||||||
|
errs = append(errs, fmt.Sprintf("token expired %s ago", formatDuration(duration)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !v.IgnoreIat {
|
||||||
|
if sc.Iat <= 0 {
|
||||||
|
errs = append(errs, "missing or malformed 'iat' (issued at)")
|
||||||
|
} else if sc.Iat > now.Unix() {
|
||||||
|
errs = append(errs, "'iat' is in the future")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v.MaxAge > 0 || !v.IgnoreAuthTime {
|
||||||
|
if sc.AuthTime == 0 {
|
||||||
|
errs = append(errs, "missing or malformed 'auth_time'")
|
||||||
|
} else if sc.AuthTime > now.Unix() {
|
||||||
|
errs = append(errs, "'auth_time' is in the future")
|
||||||
|
} else if v.MaxAge > 0 {
|
||||||
|
age := now.Sub(time.Unix(sc.AuthTime, 0))
|
||||||
|
if age > v.MaxAge {
|
||||||
|
errs = append(errs, fmt.Sprintf("'auth_time' exceeds max age %s by %s", formatDuration(v.MaxAge), formatDuration(age-v.MaxAge)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !v.IgnoreAmr {
|
||||||
|
if len(sc.Amr) == 0 {
|
||||||
|
errs = append(errs, "missing or malformed 'amr'")
|
||||||
|
} else {
|
||||||
|
for _, req := range v.RequiredAmrs {
|
||||||
|
if !slices.Contains(sc.Amr, req) {
|
||||||
|
errs = append(errs, fmt.Sprintf("missing required %q from 'amr'", req))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !v.IgnoreAzp {
|
||||||
|
if len(v.Azp) > 0 && !slices.Contains(v.Azp, sc.Azp) {
|
||||||
|
errs = append(errs, fmt.Sprintf("'azp' %q not in allowed list", sc.Azp))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return errs, fmt.Errorf("has errors")
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateStandardClaims checks the registered JWT/OIDC claim fields against v.
|
// ValidateStandardClaims checks the registered JWT/OIDC claim fields against v.
|
||||||
@ -304,7 +413,7 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
|
|||||||
if len(v.Iss) > 0 || !v.IgnoreIss {
|
if len(v.Iss) > 0 || !v.IgnoreIss {
|
||||||
if len(claims.Iss) == 0 {
|
if len(claims.Iss) == 0 {
|
||||||
errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
|
errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
|
||||||
} else if claims.Iss != v.Iss {
|
} else if len(v.Iss) > 0 && claims.Iss != v.Iss {
|
||||||
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, v.Iss))
|
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, v.Iss))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -324,7 +433,7 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
|
|||||||
if len(v.Aud) > 0 || !v.IgnoreAud {
|
if len(v.Aud) > 0 || !v.IgnoreAud {
|
||||||
if len(claims.Aud) == 0 {
|
if len(claims.Aud) == 0 {
|
||||||
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
|
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
|
||||||
} else if claims.Aud != v.Aud {
|
} else if len(v.Aud) > 0 && claims.Aud != v.Aud {
|
||||||
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud))
|
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -426,87 +535,63 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Issuer holds public keys and optional validation config for a trusted token issuer.
|
// Issuer holds public keys for a trusted token issuer.
|
||||||
//
|
//
|
||||||
// Create with [New], [NewWithJWKs], [NewWithOIDC], or [NewWithOAuth2].
|
// Issuer is immutable after construction — safe for concurrent use with no locking.
|
||||||
// After construction, Issuer is immutable.
|
// Use [New] to construct with a fixed key set, or use [Signer.Issuer] or
|
||||||
//
|
// [JWKsFetcher.Issuer] to obtain one from a signer or remote JWKS endpoint.
|
||||||
// [Issuer.UnsafeVerify] authenticates the token: Decode + key lookup + sig verify + iss check.
|
|
||||||
// [Issuer.VerifyAndValidate] additionally unmarshals claims and runs the Validator.
|
|
||||||
type Issuer struct {
|
type Issuer struct {
|
||||||
URL string // issuer URL for iss claim enforcement; empty skips the check
|
pubKeys []PublicJWK
|
||||||
validator *Validator
|
|
||||||
keys map[string]crypto.PublicKey // kid → key
|
keys map[string]crypto.PublicKey // kid → key
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates an Issuer with explicit keys.
|
// New creates an Issuer with an explicit set of public keys.
|
||||||
//
|
//
|
||||||
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
// The returned Issuer is immutable — keys cannot be added or removed after
|
||||||
// [Issuer.VerifyAndValidate] requires a non-nil Validator.
|
// construction. For dynamic key rotation, see [JWKsFetcher].
|
||||||
func New(issURL string, keys []PublicJWK, v *Validator) *Issuer {
|
func New(keys []PublicJWK) *Issuer {
|
||||||
m := make(map[string]crypto.PublicKey, len(keys))
|
m := make(map[string]crypto.PublicKey, len(keys))
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
m[k.KID] = k.Key
|
m[k.KID] = k.Key
|
||||||
}
|
}
|
||||||
return &Issuer{
|
return &Issuer{
|
||||||
URL: issURL,
|
pubKeys: keys,
|
||||||
validator: v,
|
|
||||||
keys: m,
|
keys: m,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWithJWKs creates an Issuer by fetching keys from jwksURL.
|
// PublicKeys returns the public keys held by this Issuer.
|
||||||
|
func (iss *Issuer) PublicKeys() []PublicJWK {
|
||||||
|
return iss.pubKeys
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToJWKs serializes the Issuer's public keys as a JWKS JSON document.
|
||||||
|
func (iss *Issuer) ToJWKs() ([]byte, error) {
|
||||||
|
return MarshalPublicJWKs(iss.pubKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify decodes tokenStr and verifies its signature.
|
||||||
//
|
//
|
||||||
// The issuer URL (used for iss claim enforcement in [Issuer.UnsafeVerify]) is
|
// Returns (nil, err) on any failure — the caller never receives an
|
||||||
// not set; use [New] or [NewWithOIDC]/[NewWithOAuth2] if you need iss enforcement.
|
// unauthenticated JWS. For inspecting a JWS despite signature failure
|
||||||
//
|
// (e.g., for multi-issuer routing by kid/iss), use [Issuer.UnsafeVerify].
|
||||||
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
func (iss *Issuer) Verify(tokenStr string) (*JWS, error) {
|
||||||
func NewWithJWKs(ctx context.Context, jwksURL string, v *Validator) (*Issuer, error) {
|
jws, err := iss.UnsafeVerify(tokenStr)
|
||||||
keys, err := FetchJWKs(ctx, jwksURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return New("", keys, v), nil
|
return jws, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWithOIDC creates an Issuer using OIDC discovery.
|
// UnsafeVerify decodes tokenStr and verifies the signature.
|
||||||
//
|
//
|
||||||
// It fetches {baseURL}/.well-known/openid-configuration and reads the
|
// Unlike [Issuer.Verify], UnsafeVerify returns the parsed [*JWS] even when
|
||||||
// jwks_uri and issuer fields. The Issuer URL is set from the discovery
|
// signature verification fails — the error is non-nil but the JWS is
|
||||||
// document's issuer field (not baseURL) because OIDC requires them to match.
|
// available for inspection (e.g., to read the kid or iss for multi-issuer
|
||||||
|
// routing). Returns (nil, err) only when the token cannot be parsed at all.
|
||||||
//
|
//
|
||||||
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
// "Unsafe" means exp, aud, iss, and other claim values are NOT checked.
|
||||||
func NewWithOIDC(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) {
|
// Use [Issuer.VerifyAndValidate] for full validation.
|
||||||
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
|
|
||||||
keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return New(issURL, keys, v), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWithOAuth2 creates an Issuer using OAuth 2.0 authorization server metadata (RFC 8414).
|
|
||||||
//
|
|
||||||
// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the
|
|
||||||
// jwks_uri and issuer fields. The Issuer URL is set from the discovery
|
|
||||||
// document's issuer field.
|
|
||||||
//
|
|
||||||
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
|
||||||
func NewWithOAuth2(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) {
|
|
||||||
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
|
|
||||||
keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return New(issURL, keys, v), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnsafeVerify decodes tokenStr, verifies the signature, and (if [Issuer.URL]
|
|
||||||
// is set) checks the iss claim.
|
|
||||||
//
|
|
||||||
// "Unsafe" means exp, aud, and other claim values are NOT checked — the token
|
|
||||||
// is forgery-safe but not semantically validated. Callers are responsible for
|
|
||||||
// validating claim values, or use [Issuer.VerifyAndValidate].
|
|
||||||
func (iss *Issuer) UnsafeVerify(tokenStr string) (*JWS, error) {
|
func (iss *Issuer) UnsafeVerify(tokenStr string) (*JWS, error) {
|
||||||
jws, err := Decode(tokenStr)
|
jws, err := Decode(tokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -514,56 +599,34 @@ func (iss *Issuer) UnsafeVerify(tokenStr string) (*JWS, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if jws.Header.Kid == "" {
|
if jws.Header.Kid == "" {
|
||||||
return nil, fmt.Errorf("missing 'kid' header")
|
return jws, fmt.Errorf("missing 'kid' header")
|
||||||
}
|
}
|
||||||
key, ok := iss.keys[jws.Header.Kid]
|
key, ok := iss.keys[jws.Header.Kid]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unknown kid: %q", jws.Header.Kid)
|
return jws, fmt.Errorf("unknown kid: %q", jws.Header.Kid)
|
||||||
}
|
}
|
||||||
|
|
||||||
signingInput := jws.Protected + "." + jws.Payload
|
signingInput := jws.Protected + "." + jws.Payload
|
||||||
if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil {
|
if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil {
|
||||||
return nil, fmt.Errorf("signature verification failed: %w", err)
|
return jws, fmt.Errorf("signature verification failed: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
// Signature verified — now safe to inspect the payload for iss check.
|
|
||||||
if iss.URL != "" {
|
|
||||||
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid claims encoding: %w", err)
|
|
||||||
}
|
|
||||||
var partial struct {
|
|
||||||
Iss string `json:"iss"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(payload, &partial); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid claims JSON: %w", err)
|
|
||||||
}
|
|
||||||
if partial.Iss != iss.URL {
|
|
||||||
return nil, fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return jws, nil
|
return jws, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyAndValidate verifies the token signature and iss, unmarshals the claims
|
// VerifyAndValidate verifies the token signature, unmarshals the claims
|
||||||
// into claims, and runs the [Validator].
|
// into claims, and runs v.
|
||||||
//
|
//
|
||||||
// Returns a hard error (err != nil) for signature failures, decoding errors,
|
// Returns a hard error (err != nil) for signature failures and decoding errors.
|
||||||
// and nil Validator. Returns soft errors (errs != nil) for claim validation
|
// Returns soft errors (errs != nil) for claim validation failures (wrong aud,
|
||||||
// failures (wrong aud, expired token, etc.).
|
// expired token, etc.). If v is nil, claims are unmarshalled but not validated.
|
||||||
//
|
//
|
||||||
// claims must be a pointer whose underlying type embeds [StandardClaims] (or
|
// claims must be a pointer whose underlying type embeds [StandardClaims]:
|
||||||
// otherwise implements [StandardClaimsSource]):
|
|
||||||
//
|
//
|
||||||
// var claims AppClaims
|
// var claims AppClaims
|
||||||
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now())
|
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, v, time.Now())
|
||||||
func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSource, now time.Time) (*JWS, []string, error) {
|
func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSource, v ClaimsValidator, now time.Time) (*JWS, []string, error) {
|
||||||
if iss.validator == nil {
|
jws, err := iss.Verify(tokenStr)
|
||||||
return nil, nil, fmt.Errorf("VerifyAndValidate requires a non-nil Validator; use UnsafeVerify for signature-only verification")
|
|
||||||
}
|
|
||||||
|
|
||||||
jws, err := iss.UnsafeVerify(tokenStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@ -572,8 +635,12 @@ func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSourc
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
errs, err := iss.validator.Validate(claims, now)
|
if v == nil {
|
||||||
return jws, errs, err
|
return jws, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
errs, _ := v.Validate(claims, now) // discard sentinel; callers check len(errs) > 0
|
||||||
|
return jws, errs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// verifyWith checks a JWS signature using the given algorithm and public key.
|
// verifyWith checks a JWS signature using the given algorithm and public key.
|
||||||
|
|||||||
@ -67,11 +67,11 @@ func goodClaims() AppClaims {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// goodValidator configures the validator. IgnoreIss is true because
|
// goodValidator configures the validator with iss set to "https://example.com".
|
||||||
// Issuer.UnsafeVerify already enforces the iss claim — no need to check twice.
|
// Iss checking is now the Validator's responsibility, not the Issuer's.
|
||||||
func goodValidator() *ajwt.Validator {
|
func goodValidator() *ajwt.Validator {
|
||||||
return &ajwt.Validator{
|
return &ajwt.Validator{
|
||||||
IgnoreIss: true, // UnsafeVerify handles iss
|
Iss: "https://example.com",
|
||||||
Sub: "user123",
|
Sub: "user123",
|
||||||
Aud: "myapp",
|
Aud: "myapp",
|
||||||
Jti: "abc123",
|
Jti: "abc123",
|
||||||
@ -82,15 +82,13 @@ func goodValidator() *ajwt.Validator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer {
|
func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer {
|
||||||
return ajwt.New("https://example.com", []ajwt.PublicJWK{pub}, goodValidator())
|
return ajwt.New([]ajwt.PublicJWK{pub})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRoundTrip is the primary happy path using ES256.
|
// TestRoundTrip is the primary happy path using ES256.
|
||||||
// It demonstrates the full VerifyAndValidate flow:
|
// It demonstrates the full VerifyAndValidate flow:
|
||||||
//
|
//
|
||||||
// New → VerifyAndValidate → custom claim access
|
// New → VerifyAndValidate → custom claim access
|
||||||
//
|
|
||||||
// No Claims interface, no Verified flag, no type assertions on jws.
|
|
||||||
func TestRoundTrip(t *testing.T) {
|
func TestRoundTrip(t *testing.T) {
|
||||||
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -115,7 +113,7 @@ func TestRoundTrip(t *testing.T) {
|
|||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
|
||||||
|
|
||||||
var decoded AppClaims
|
var decoded AppClaims
|
||||||
jws2, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
|
jws2, errs, err := iss.VerifyAndValidate(token, &decoded, goodValidator(), time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("VerifyAndValidate failed: %v", err)
|
t.Fatalf("VerifyAndValidate failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -125,7 +123,7 @@ func TestRoundTrip(t *testing.T) {
|
|||||||
if jws2.Header.Alg != "ES256" {
|
if jws2.Header.Alg != "ES256" {
|
||||||
t.Errorf("expected ES256 alg in jws, got %s", jws2.Header.Alg)
|
t.Errorf("expected ES256 alg in jws, got %s", jws2.Header.Alg)
|
||||||
}
|
}
|
||||||
// Direct field access — no type assertion needed, no jws.Claims interface.
|
// Direct field access — no type assertion needed.
|
||||||
if decoded.Email != claims.Email {
|
if decoded.Email != claims.Email {
|
||||||
t.Errorf("email: got %s, want %s", decoded.Email, claims.Email)
|
t.Errorf("email: got %s, want %s", decoded.Email, claims.Email)
|
||||||
}
|
}
|
||||||
@ -156,7 +154,7 @@ func TestRoundTripRS256(t *testing.T) {
|
|||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
|
||||||
|
|
||||||
var decoded AppClaims
|
var decoded AppClaims
|
||||||
_, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
|
_, errs, err := iss.VerifyAndValidate(token, &decoded, goodValidator(), time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("VerifyAndValidate failed: %v", err)
|
t.Fatalf("VerifyAndValidate failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -190,7 +188,7 @@ func TestRoundTripEdDSA(t *testing.T) {
|
|||||||
iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"})
|
||||||
|
|
||||||
var decoded AppClaims
|
var decoded AppClaims
|
||||||
_, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
|
_, errs, err := iss.VerifyAndValidate(token, &decoded, goodValidator(), time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("VerifyAndValidate failed: %v", err)
|
t.Fatalf("VerifyAndValidate failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -209,8 +207,7 @@ func TestUnsafeVerifyFlow(t *testing.T) {
|
|||||||
_, _ = jws.Sign(privKey)
|
_, _ = jws.Sign(privKey)
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
// Create issuer without validator — UnsafeVerify only.
|
iss := ajwt.New([]ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}})
|
||||||
iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil)
|
|
||||||
|
|
||||||
jws2, err := iss.UnsafeVerify(token)
|
jws2, err := iss.UnsafeVerify(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -228,6 +225,34 @@ func TestUnsafeVerifyFlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestUnsafeVerifyReturnsJWSOnSigFailure verifies that UnsafeVerify returns a
|
||||||
|
// non-nil *JWS even when signature verification fails, so callers can inspect
|
||||||
|
// the header (kid, iss) for routing.
|
||||||
|
func TestUnsafeVerifyReturnsJWSOnSigFailure(t *testing.T) {
|
||||||
|
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
|
||||||
|
_, _ = jws.Sign(signingKey)
|
||||||
|
token := jws.Encode()
|
||||||
|
|
||||||
|
// Issuer has wrong public key — sig verification will fail.
|
||||||
|
iss := ajwt.New([]ajwt.PublicJWK{{Key: &wrongKey.PublicKey, KID: "k"}})
|
||||||
|
|
||||||
|
result, err := iss.UnsafeVerify(token)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for wrong key")
|
||||||
|
}
|
||||||
|
// UnsafeVerify must return the JWS despite sig failure.
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("UnsafeVerify should return non-nil JWS on sig failure")
|
||||||
|
}
|
||||||
|
if result.Header.Kid != "k" {
|
||||||
|
t.Errorf("expected kid %q, got %q", "k", result.Header.Kid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestCustomValidation demonstrates that ValidateStandardClaims is called
|
// TestCustomValidation demonstrates that ValidateStandardClaims is called
|
||||||
// explicitly and custom fields are validated without any Claims interface.
|
// explicitly and custom fields are validated without any Claims interface.
|
||||||
func TestCustomValidation(t *testing.T) {
|
func TestCustomValidation(t *testing.T) {
|
||||||
@ -264,8 +289,8 @@ func TestCustomValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestVerifyAndValidateNilValidator confirms that VerifyAndValidate fails loudly
|
// TestVerifyAndValidateNilValidator confirms that passing a nil ClaimsValidator
|
||||||
// when no Validator was provided at construction time.
|
// skips validation but still returns the verified JWS and unmarshalled claims.
|
||||||
func TestVerifyAndValidateNilValidator(t *testing.T) {
|
func TestVerifyAndValidateNilValidator(t *testing.T) {
|
||||||
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
c := goodClaims()
|
c := goodClaims()
|
||||||
@ -273,11 +298,21 @@ func TestVerifyAndValidateNilValidator(t *testing.T) {
|
|||||||
_, _ = jws.Sign(privKey)
|
_, _ = jws.Sign(privKey)
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil)
|
iss := ajwt.New([]ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}})
|
||||||
|
|
||||||
var claims AppClaims
|
var claims AppClaims
|
||||||
if _, _, err := iss.VerifyAndValidate(token, &claims, time.Now()); err == nil {
|
jws2, errs, err := iss.VerifyAndValidate(token, &claims, nil, time.Now())
|
||||||
t.Fatal("expected VerifyAndValidate to error with nil validator")
|
if err != nil {
|
||||||
|
t.Fatalf("expected success with nil validator: %v", err)
|
||||||
|
}
|
||||||
|
if len(errs) > 0 {
|
||||||
|
t.Fatalf("expected no validation errors with nil validator: %v", errs)
|
||||||
|
}
|
||||||
|
if jws2 == nil {
|
||||||
|
t.Fatal("expected non-nil JWS")
|
||||||
|
}
|
||||||
|
if claims.Email != c.Email {
|
||||||
|
t.Errorf("claims not unmarshalled: email got %q, want %q", claims.Email, c.Email)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -293,8 +328,8 @@ func TestIssuerWrongKey(t *testing.T) {
|
|||||||
|
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"})
|
||||||
|
|
||||||
if _, err := iss.UnsafeVerify(token); err == nil {
|
if _, err := iss.Verify(token); err == nil {
|
||||||
t.Fatal("expected UnsafeVerify to fail with wrong key")
|
t.Fatal("expected Verify to fail with wrong key")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -309,27 +344,47 @@ func TestIssuerUnknownKid(t *testing.T) {
|
|||||||
|
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"})
|
||||||
|
|
||||||
if _, err := iss.UnsafeVerify(token); err == nil {
|
if _, err := iss.Verify(token); err == nil {
|
||||||
t.Fatal("expected UnsafeVerify to fail for unknown kid")
|
t.Fatal("expected Verify to fail for unknown kid")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestIssuerIssMismatch confirms that a token with a mismatched iss is rejected
|
// TestIssuerIssMismatch confirms that a token with a mismatched iss is caught
|
||||||
// even if the signature is valid.
|
// by the Validator, not the Issuer. Signature verification succeeds; the iss
|
||||||
|
// mismatch appears as a soft validation error.
|
||||||
func TestIssuerIssMismatch(t *testing.T) {
|
func TestIssuerIssMismatch(t *testing.T) {
|
||||||
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
claims := goodClaims()
|
claims := goodClaims()
|
||||||
claims.Iss = "https://evil.example.com" // not the issuer URL
|
claims.Iss = "https://evil.example.com"
|
||||||
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
|
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
|
||||||
_, _ = jws.Sign(privKey)
|
_, _ = jws.Sign(privKey)
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
// Issuer expects "https://example.com" but token says "https://evil.example.com"
|
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
||||||
|
|
||||||
if _, err := iss.UnsafeVerify(token); err == nil {
|
// UnsafeVerify succeeds — iss is not checked at the Issuer level.
|
||||||
t.Fatal("expected UnsafeVerify to fail: iss mismatch")
|
if _, err := iss.UnsafeVerify(token); err != nil {
|
||||||
|
t.Fatalf("UnsafeVerify should succeed (no iss check): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyAndValidate with a Validator that enforces iss catches the mismatch.
|
||||||
|
var decoded AppClaims
|
||||||
|
_, errs, err := iss.VerifyAndValidate(token, &decoded, goodValidator(), time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected hard error: %v", err)
|
||||||
|
}
|
||||||
|
if len(errs) == 0 {
|
||||||
|
t.Fatal("expected validation errors for iss mismatch")
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, e := range errs {
|
||||||
|
if strings.Contains(e, "iss") {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected iss error in validation errors: %v", errs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -352,8 +407,133 @@ func TestVerifyTamperedAlg(t *testing.T) {
|
|||||||
parts := strings.SplitN(token, ".", 3)
|
parts := strings.SplitN(token, ".", 3)
|
||||||
tamperedToken := noneHeader + "." + parts[1] + "." + parts[2]
|
tamperedToken := noneHeader + "." + parts[1] + "." + parts[2]
|
||||||
|
|
||||||
if _, err := iss.UnsafeVerify(tamperedToken); err == nil {
|
if _, err := iss.Verify(tamperedToken); err == nil {
|
||||||
t.Fatal("expected UnsafeVerify to fail for tampered alg")
|
t.Fatal("expected Verify to fail for tampered alg")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSignerRoundTrip verifies the Signer → Sign → Issuer → VerifyAndValidate flow.
|
||||||
|
func TestSignerRoundTrip(t *testing.T) {
|
||||||
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := ajwt.NewSigner([]ajwt.NamedSigner{{KID: "k1", Signer: privKey}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
tokenStr, err := signer.Sign(&claims)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
iss := signer.Issuer()
|
||||||
|
var decoded AppClaims
|
||||||
|
_, errs, err := iss.VerifyAndValidate(tokenStr, &decoded, goodValidator(), time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("VerifyAndValidate failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(errs) > 0 {
|
||||||
|
t.Fatalf("claim validation failed: %v", errs)
|
||||||
|
}
|
||||||
|
if decoded.Email != claims.Email {
|
||||||
|
t.Errorf("email: got %s, want %s", decoded.Email, claims.Email)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSignerAutoKID verifies that KID is auto-computed from the key thumbprint
|
||||||
|
// when NamedSigner.KID is empty.
|
||||||
|
func TestSignerAutoKID(t *testing.T) {
|
||||||
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
|
signer, err := ajwt.NewSigner([]ajwt.NamedSigner{{Signer: privKey}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := signer.PublicKeys()
|
||||||
|
if len(keys) != 1 {
|
||||||
|
t.Fatalf("expected 1 key, got %d", len(keys))
|
||||||
|
}
|
||||||
|
if keys[0].KID == "" {
|
||||||
|
t.Fatal("KID should be auto-computed from thumbprint")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token should verify with the auto-KID issuer.
|
||||||
|
iss := signer.Issuer()
|
||||||
|
claims := goodClaims()
|
||||||
|
tokenStr, _ := signer.Sign(&claims)
|
||||||
|
|
||||||
|
if _, err := iss.Verify(tokenStr); err != nil {
|
||||||
|
t.Fatalf("Verify failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSignerRoundRobin verifies that signing round-robins across keys and that
|
||||||
|
// all resulting tokens verify with the combined Issuer.
|
||||||
|
func TestSignerRoundRobin(t *testing.T) {
|
||||||
|
key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
key2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
|
signer, err := ajwt.NewSigner([]ajwt.NamedSigner{
|
||||||
|
{KID: "k1", Signer: key1},
|
||||||
|
{KID: "k2", Signer: key2},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
iss := signer.Issuer()
|
||||||
|
v := goodValidator()
|
||||||
|
|
||||||
|
for i := range 4 {
|
||||||
|
claims := goodClaims()
|
||||||
|
tokenStr, err := signer.Sign(&claims)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Sign[%d] failed: %v", i, err)
|
||||||
|
}
|
||||||
|
var decoded AppClaims
|
||||||
|
if _, _, err := iss.VerifyAndValidate(tokenStr, &decoded, v, time.Now()); err != nil {
|
||||||
|
t.Fatalf("VerifyAndValidate[%d] failed: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIssuerToJWKs verifies JWKS serialization and round-trip parsing.
|
||||||
|
func TestIssuerToJWKs(t *testing.T) {
|
||||||
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
|
signer, err := ajwt.NewSigner([]ajwt.NamedSigner{{KID: "k1", Signer: privKey}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
iss := signer.Issuer()
|
||||||
|
jwksBytes, err := iss.ToJWKs()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Round-trip: parse the JWKS JSON and verify it produces a working Issuer.
|
||||||
|
keys, err := ajwt.UnmarshalPublicJWKs(jwksBytes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(keys) != 1 {
|
||||||
|
t.Fatalf("expected 1 key, got %d", len(keys))
|
||||||
|
}
|
||||||
|
if keys[0].KID != "k1" {
|
||||||
|
t.Errorf("expected kid 'k1', got %q", keys[0].KID)
|
||||||
|
}
|
||||||
|
|
||||||
|
iss2 := ajwt.New(keys)
|
||||||
|
claims := goodClaims()
|
||||||
|
tokenStr, _ := signer.Sign(&claims)
|
||||||
|
if _, err := iss2.Verify(tokenStr); err != nil {
|
||||||
|
t.Fatalf("Verify on round-tripped JWKS failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -158,6 +158,74 @@ type JWKsJSON struct {
|
|||||||
Keys []PublicJWKJSON `json:"keys"`
|
Keys []PublicJWKJSON `json:"keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EncodePublicJWK converts a [PublicJWK] to its JSON representation.
|
||||||
|
//
|
||||||
|
// Supported key types: *ecdsa.PublicKey (EC), *rsa.PublicKey (RSA), ed25519.PublicKey (OKP).
|
||||||
|
func EncodePublicJWK(k PublicJWK) (PublicJWKJSON, error) {
|
||||||
|
switch key := k.Key.(type) {
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
var crv string
|
||||||
|
switch key.Curve {
|
||||||
|
case elliptic.P256():
|
||||||
|
crv = "P-256"
|
||||||
|
case elliptic.P384():
|
||||||
|
crv = "P-384"
|
||||||
|
case elliptic.P521():
|
||||||
|
crv = "P-521"
|
||||||
|
default:
|
||||||
|
return PublicJWKJSON{}, fmt.Errorf("EncodePublicJWK: unsupported EC curve %s", key.Curve.Params().Name)
|
||||||
|
}
|
||||||
|
byteLen := (key.Curve.Params().BitSize + 7) / 8
|
||||||
|
xBytes := make([]byte, byteLen)
|
||||||
|
yBytes := make([]byte, byteLen)
|
||||||
|
key.X.FillBytes(xBytes)
|
||||||
|
key.Y.FillBytes(yBytes)
|
||||||
|
return PublicJWKJSON{
|
||||||
|
Kty: "EC",
|
||||||
|
KID: k.KID,
|
||||||
|
Crv: crv,
|
||||||
|
X: base64.RawURLEncoding.EncodeToString(xBytes),
|
||||||
|
Y: base64.RawURLEncoding.EncodeToString(yBytes),
|
||||||
|
Use: k.Use,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
eInt := big.NewInt(int64(key.E))
|
||||||
|
return PublicJWKJSON{
|
||||||
|
Kty: "RSA",
|
||||||
|
KID: k.KID,
|
||||||
|
N: base64.RawURLEncoding.EncodeToString(key.N.Bytes()),
|
||||||
|
E: base64.RawURLEncoding.EncodeToString(eInt.Bytes()),
|
||||||
|
Use: k.Use,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
case ed25519.PublicKey:
|
||||||
|
return PublicJWKJSON{
|
||||||
|
Kty: "OKP",
|
||||||
|
KID: k.KID,
|
||||||
|
Crv: "Ed25519",
|
||||||
|
X: base64.RawURLEncoding.EncodeToString([]byte(key)),
|
||||||
|
Use: k.Use,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return PublicJWKJSON{}, fmt.Errorf("EncodePublicJWK: unsupported key type %T", k.Key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalPublicJWKs serializes a slice of [PublicJWK] as a JWKS JSON document.
|
||||||
|
func MarshalPublicJWKs(keys []PublicJWK) ([]byte, error) {
|
||||||
|
jsonKeys := make([]PublicJWKJSON, 0, len(keys))
|
||||||
|
for _, k := range keys {
|
||||||
|
jk, err := EncodePublicJWK(k)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jsonKeys = append(jsonKeys, jk)
|
||||||
|
}
|
||||||
|
return json.Marshal(JWKsJSON{Keys: jsonKeys})
|
||||||
|
}
|
||||||
|
|
||||||
// FetchJWKs retrieves and parses a JWKS document from jwksURL.
|
// FetchJWKs retrieves and parses a JWKS document from jwksURL.
|
||||||
//
|
//
|
||||||
// ctx is used for the HTTP request timeout and cancellation.
|
// ctx is used for the HTTP request timeout and cancellation.
|
||||||
|
|||||||
98
auth/ajwt/sign.go
Normal file
98
auth/ajwt/sign.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
package ajwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NamedSigner pairs a [crypto.Signer] with a key ID (KID).
|
||||||
|
//
|
||||||
|
// If KID is empty, it is auto-computed from the RFC 7638 thumbprint of the
|
||||||
|
// public key when passed to [NewSigner].
|
||||||
|
type NamedSigner struct {
|
||||||
|
KID string
|
||||||
|
Signer crypto.Signer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signer manages one or more private signing keys and issues JWTs by
|
||||||
|
// round-robining across them.
|
||||||
|
//
|
||||||
|
// Do not copy a Signer after first use — it contains an atomic counter.
|
||||||
|
type Signer struct {
|
||||||
|
signers []NamedSigner
|
||||||
|
signerIdx atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSigner creates a Signer from the provided signing keys.
|
||||||
|
//
|
||||||
|
// If a NamedSigner's KID is empty, it is auto-computed from the RFC 7638
|
||||||
|
// thumbprint of the public key. Returns an error if the slice is empty or
|
||||||
|
// a thumbprint cannot be computed.
|
||||||
|
func NewSigner(signers []NamedSigner) (*Signer, error) {
|
||||||
|
if len(signers) == 0 {
|
||||||
|
return nil, fmt.Errorf("NewSigner: at least one signer is required")
|
||||||
|
}
|
||||||
|
// Copy so the caller can't mutate after construction.
|
||||||
|
ss := make([]NamedSigner, len(signers))
|
||||||
|
copy(ss, signers)
|
||||||
|
for i, ns := range ss {
|
||||||
|
if ns.KID == "" {
|
||||||
|
jwk := PublicJWK{Key: ns.Signer.Public()}
|
||||||
|
thumb, err := jwk.Thumbprint()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("NewSigner: compute thumbprint for signer[%d]: %w", i, err)
|
||||||
|
}
|
||||||
|
ss[i].KID = thumb
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Signer{signers: ss}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign creates and signs a compact JWT from claims, using the next signing key
|
||||||
|
// in round-robin order. The caller is responsible for setting the "iss" field
|
||||||
|
// in claims if issuer identification is needed.
|
||||||
|
func (s *Signer) Sign(claims any) (string, error) {
|
||||||
|
idx := s.signerIdx.Add(1) - 1
|
||||||
|
ns := s.signers[idx%uint64(len(s.signers))]
|
||||||
|
|
||||||
|
jws, err := NewJWSFromClaims(claims, ns.KID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if _, err := jws.Sign(ns.Signer); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return jws.Encode(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Issuer returns a new [*Issuer] containing the public keys of all signing keys.
|
||||||
|
//
|
||||||
|
// Use this to construct an Issuer for verifying tokens signed by this Signer.
|
||||||
|
// For key rotation, combine with old public keys:
|
||||||
|
//
|
||||||
|
// iss := ajwt.New(append(signer.PublicKeys(), oldKeys...))
|
||||||
|
func (s *Signer) Issuer() *Issuer {
|
||||||
|
return New(s.PublicKeys())
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublicKeys returns the public-key side of each signing key, in the same order
|
||||||
|
// as the signers were provided to [NewSigner].
|
||||||
|
func (s *Signer) PublicKeys() []PublicJWK {
|
||||||
|
keys := make([]PublicJWK, len(s.signers))
|
||||||
|
for i, ns := range s.signers {
|
||||||
|
keys[i] = PublicJWK{
|
||||||
|
Key: ns.Signer.Public(),
|
||||||
|
KID: ns.KID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user