mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 12:48:00 +00:00
ajwt: implement redesigned API from REDESIGN.md
Rename ValidateParams → Validator, make Issuer immutable after construction. Key changes: - StandardClaims.GetStandardClaims() + StandardClaimsSource interface: any struct embedding StandardClaims satisfies the interface for free via Go's method promotion — zero boilerplate for callers - Issuer is now immutable after construction; keys and validator are unexported; Params field removed - New constructors: New, NewWithJWKs, NewWithOIDC, NewWithOAuth2 - UnsafeVerify(tokenStr string) (*JWS, error): Decode + sig verify + iss check; "unsafe" means exp/aud/etc. are NOT checked - VerifyAndValidate(tokenStr, claims, now): full pipeline requiring non-nil Validator; fails loudly with nil Validator - FetchJWKs(ctx, url), FetchJWKsFromOIDC(ctx, base), FetchJWKsFromOAuth2(ctx, base): standalone fetch functions with context - PublicJWK.Thumbprint(): RFC 7638 SHA-256 thumbprint, canonical field ordering per spec (EC: crv/kty/x/y, RSA: e/kty/n, OKP: crv/kty/x) - DecodePublicJWKsJSON: auto-populates KID from Thumbprint when absent - Tests: 14 pass, covering VerifyAndValidate, UnsafeVerify, nil-validator error, all alg round trips, tampered alg, Thumbprint, auto-KID
This commit is contained in:
parent
1f0b36fc6d
commit
3f7985317f
353
auth/ajwt/jwt.go
353
auth/ajwt/jwt.go
@ -12,25 +12,33 @@
|
|||||||
// - [JWS] is a parsed structure only — no Claims interface, no Verified flag.
|
// - [JWS] is a parsed structure only — no Claims interface, no Verified flag.
|
||||||
// - [Issuer] owns key management and signature verification, centralizing
|
// - [Issuer] owns key management and signature verification, centralizing
|
||||||
// the key lookup → sig verify → iss check sequence.
|
// the key lookup → sig verify → iss check sequence.
|
||||||
// - [ValidateParams] is a stable config object; time is passed at the call
|
// - [Validator] is a stable config object; time is passed at the call site
|
||||||
// site so the same params can be reused across requests.
|
// so the same validator can be reused across requests.
|
||||||
|
// - [StandardClaimsSource] is implemented for free by embedding [StandardClaims].
|
||||||
// - [JWS.UnmarshalClaims] accepts any type — no interface to implement.
|
// - [JWS.UnmarshalClaims] accepts any type — no interface to implement.
|
||||||
// - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384),
|
// - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384),
|
||||||
// ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037).
|
// ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037).
|
||||||
//
|
//
|
||||||
// Typical usage:
|
// Typical usage with VerifyAndValidate:
|
||||||
//
|
//
|
||||||
// // At startup:
|
// // At startup:
|
||||||
// iss := ajwt.NewIssuer("https://accounts.example.com")
|
// iss, err := ajwt.NewWithOIDC(ctx, "https://accounts.example.com",
|
||||||
// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true}
|
// &ajwt.Validator{Aud: "my-app", IgnoreIss: true})
|
||||||
// if err := iss.FetchKeys(ctx); err != nil { ... }
|
|
||||||
//
|
//
|
||||||
// // Per request:
|
// // Per request:
|
||||||
// jws, err := ajwt.Decode(tokenStr)
|
|
||||||
// if err := iss.Verify(jws); err != nil { ... } // sig + iss check
|
|
||||||
// var claims AppClaims
|
// var claims AppClaims
|
||||||
// if err := jws.UnmarshalClaims(&claims); err != nil { ... }
|
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now())
|
||||||
// if errs, err := iss.Params.Validate(claims.StandardClaims, time.Now()); err != nil { ... }
|
// if err != nil { /* hard error: bad sig, expired, etc. */ }
|
||||||
|
// if len(errs) > 0 { /* soft errors: wrong aud, missing amr, etc. */ }
|
||||||
|
//
|
||||||
|
// Typical usage with UnsafeVerify (custom validation only):
|
||||||
|
//
|
||||||
|
// iss := ajwt.New("https://example.com", keys, nil)
|
||||||
|
// jws, err := iss.UnsafeVerify(tokenStr)
|
||||||
|
// var claims AppClaims
|
||||||
|
// jws.UnmarshalClaims(&claims)
|
||||||
|
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims,
|
||||||
|
// ajwt.Validator{Aud: "myapp"}, time.Now())
|
||||||
package ajwt
|
package ajwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -48,7 +56,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -58,10 +65,10 @@ import (
|
|||||||
//
|
//
|
||||||
// It holds only the parsed structure — header, raw base64url fields, and
|
// It holds only the parsed structure — header, raw base64url fields, and
|
||||||
// decoded signature bytes. It carries no Claims interface and no Verified flag;
|
// decoded signature bytes. It carries no Claims interface and no Verified flag;
|
||||||
// use [Issuer.Verify] to authenticate the token and [JWS.UnmarshalClaims] to
|
// use [Issuer.UnsafeVerify] or [Issuer.VerifyAndValidate] to authenticate the
|
||||||
// decode the payload into a typed struct.
|
// token and [JWS.UnmarshalClaims] to decode the payload into a typed struct.
|
||||||
type JWS struct {
|
type JWS struct {
|
||||||
Protected string // base64url-encoded header
|
Protected string // base64url-encoded header
|
||||||
Header StandardHeader
|
Header StandardHeader
|
||||||
Payload string // base64url-encoded claims
|
Payload string // base64url-encoded claims
|
||||||
Signature []byte
|
Signature []byte
|
||||||
@ -77,12 +84,15 @@ type StandardHeader struct {
|
|||||||
// StandardClaims holds the registered JWT claim names defined in RFC 7519
|
// StandardClaims holds the registered JWT claim names defined in RFC 7519
|
||||||
// and extended by OpenID Connect Core.
|
// and extended by OpenID Connect Core.
|
||||||
//
|
//
|
||||||
// Embed StandardClaims in your own claims struct:
|
// Embed StandardClaims in your own claims struct to satisfy [StandardClaimsSource]
|
||||||
|
// for free via Go's method promotion — zero boilerplate:
|
||||||
//
|
//
|
||||||
// type AppClaims struct {
|
// type AppClaims struct {
|
||||||
// ajwt.StandardClaims
|
// ajwt.StandardClaims // promotes GetStandardClaims()
|
||||||
// Email string `json:"email"`
|
// Email string `json:"email"`
|
||||||
|
// Roles []string `json:"roles"`
|
||||||
// }
|
// }
|
||||||
|
// // AppClaims now satisfies StandardClaimsSource automatically.
|
||||||
type StandardClaims struct {
|
type StandardClaims struct {
|
||||||
Iss string `json:"iss"`
|
Iss string `json:"iss"`
|
||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
@ -96,10 +106,24 @@ type StandardClaims struct {
|
|||||||
Jti string `json:"jti"`
|
Jti string `json:"jti"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStandardClaims implements [StandardClaimsSource].
|
||||||
|
// Any struct embedding StandardClaims gets this method for free via promotion.
|
||||||
|
func (sc StandardClaims) GetStandardClaims() StandardClaims { return sc }
|
||||||
|
|
||||||
|
// StandardClaimsSource is implemented for free by any struct that embeds [StandardClaims].
|
||||||
|
//
|
||||||
|
// type AppClaims struct {
|
||||||
|
// ajwt.StandardClaims // promotes GetStandardClaims() — zero boilerplate
|
||||||
|
// Email string `json:"email"`
|
||||||
|
// }
|
||||||
|
type StandardClaimsSource interface {
|
||||||
|
GetStandardClaims() StandardClaims
|
||||||
|
}
|
||||||
|
|
||||||
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
|
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
|
||||||
//
|
//
|
||||||
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
|
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
|
||||||
// [Issuer.Verify] to populate a typed claims struct.
|
// [Issuer.UnsafeVerify] to safely populate a typed claims struct.
|
||||||
func Decode(tokenStr string) (*JWS, error) {
|
func Decode(tokenStr string) (*JWS, error) {
|
||||||
parts := strings.Split(tokenStr, ".")
|
parts := strings.Split(tokenStr, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
@ -128,7 +152,7 @@ func Decode(tokenStr string) (*JWS, error) {
|
|||||||
// UnmarshalClaims decodes the JWT payload into v.
|
// UnmarshalClaims decodes the JWT payload into v.
|
||||||
//
|
//
|
||||||
// v must be a pointer to a struct (e.g. *AppClaims). Always call
|
// v must be a pointer to a struct (e.g. *AppClaims). Always call
|
||||||
// [Issuer.Verify] before UnmarshalClaims to ensure the signature is
|
// [Issuer.UnsafeVerify] before UnmarshalClaims to ensure the signature is
|
||||||
// authenticated before trusting the payload.
|
// authenticated before trusting the payload.
|
||||||
func (jws *JWS) UnmarshalClaims(v any) error {
|
func (jws *JWS) UnmarshalClaims(v any) error {
|
||||||
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||||
@ -232,16 +256,16 @@ func (jws *JWS) Encode() string {
|
|||||||
return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature)
|
return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateParams holds claim validation configuration.
|
// Validator holds claim validation configuration.
|
||||||
//
|
//
|
||||||
// Configure once at startup; call [ValidateParams.Validate] per request,
|
// Configure once at startup; call [Validator.Validate] per request, passing
|
||||||
// passing the current time. This keeps the config stable and makes the
|
// the current time. This keeps the config stable and makes the time dependency
|
||||||
// time dependency explicit at the call site.
|
// explicit at the call site.
|
||||||
//
|
//
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
|
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
|
||||||
type ValidateParams struct {
|
type Validator struct {
|
||||||
IgnoreIss bool
|
IgnoreIss bool
|
||||||
Iss string
|
Iss string // rarely needed — Issuer.UnsafeVerify already checks iss
|
||||||
IgnoreSub bool
|
IgnoreSub bool
|
||||||
Sub string
|
Sub string
|
||||||
IgnoreAud bool
|
IgnoreAud bool
|
||||||
@ -260,53 +284,53 @@ type ValidateParams struct {
|
|||||||
Azp string
|
Azp string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks the standard JWT/OIDC claim fields against this config.
|
// Validate checks the standard JWT/OIDC claim fields in claims against this config.
|
||||||
//
|
//
|
||||||
// now is typically time.Now() — passing it explicitly keeps the config stable
|
// now is typically time.Now() — passing it explicitly keeps the config stable
|
||||||
// across requests and avoids hidden time dependencies in the params struct.
|
// across requests and avoids hidden time dependencies in the validator struct.
|
||||||
func (p ValidateParams) Validate(claims StandardClaims, now time.Time) ([]string, error) {
|
func (v Validator) Validate(claims StandardClaimsSource, now time.Time) ([]string, error) {
|
||||||
return ValidateStandardClaims(claims, p, now)
|
return ValidateStandardClaims(claims.GetStandardClaims(), v, now)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateStandardClaims checks the registered JWT/OIDC claim fields against params.
|
// ValidateStandardClaims checks the registered JWT/OIDC claim fields against v.
|
||||||
//
|
//
|
||||||
// Exported so callers can use it directly without a [ValidateParams] receiver:
|
// Exported so callers can use it directly without a [Validator] receiver:
|
||||||
//
|
//
|
||||||
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, params, time.Now())
|
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, v, time.Now())
|
||||||
func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now time.Time) ([]string, error) {
|
func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ([]string, error) {
|
||||||
var errs []string
|
var errs []string
|
||||||
|
|
||||||
// Required to exist and match
|
// Required to exist and match
|
||||||
if len(params.Iss) > 0 || !params.IgnoreIss {
|
if len(v.Iss) > 0 || !v.IgnoreIss {
|
||||||
if len(claims.Iss) == 0 {
|
if len(claims.Iss) == 0 {
|
||||||
errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
|
errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
|
||||||
} else if claims.Iss != params.Iss {
|
} else if claims.Iss != v.Iss {
|
||||||
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss))
|
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, v.Iss))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Required to exist, optional match
|
// Required to exist, optional match
|
||||||
if len(claims.Sub) == 0 {
|
if len(claims.Sub) == 0 {
|
||||||
if !params.IgnoreSub {
|
if !v.IgnoreSub {
|
||||||
errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)")
|
errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)")
|
||||||
}
|
}
|
||||||
} else if len(params.Sub) > 0 {
|
} else if len(v.Sub) > 0 {
|
||||||
if params.Sub != claims.Sub {
|
if v.Sub != claims.Sub {
|
||||||
errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, params.Sub))
|
errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, v.Sub))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Required to exist and match
|
// Required to exist and match
|
||||||
if len(params.Aud) > 0 || !params.IgnoreAud {
|
if len(v.Aud) > 0 || !v.IgnoreAud {
|
||||||
if len(claims.Aud) == 0 {
|
if len(claims.Aud) == 0 {
|
||||||
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
|
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
|
||||||
} else if claims.Aud != params.Aud {
|
} else if claims.Aud != v.Aud {
|
||||||
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, params.Aud))
|
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Required to exist and not be in the past
|
// Required to exist and not be in the past
|
||||||
if !params.IgnoreExp {
|
if !v.IgnoreExp {
|
||||||
if claims.Exp <= 0 {
|
if claims.Exp <= 0 {
|
||||||
errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)")
|
errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)")
|
||||||
} else if claims.Exp < now.Unix() {
|
} else if claims.Exp < now.Unix() {
|
||||||
@ -317,7 +341,7 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Required to exist and not be in the future
|
// Required to exist and not be in the future
|
||||||
if !params.IgnoreIat {
|
if !v.IgnoreIat {
|
||||||
if claims.Iat <= 0 {
|
if claims.Iat <= 0 {
|
||||||
errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)")
|
errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)")
|
||||||
} else if claims.Iat > now.Unix() {
|
} else if claims.Iat > now.Unix() {
|
||||||
@ -328,53 +352,53 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Should exist, in the past, with optional max age
|
// Should exist, in the past, with optional max age
|
||||||
if params.MaxAge > 0 || !params.IgnoreAuthTime {
|
if v.MaxAge > 0 || !v.IgnoreAuthTime {
|
||||||
if claims.AuthTime == 0 {
|
if claims.AuthTime == 0 {
|
||||||
errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
|
errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
|
||||||
} else {
|
} else {
|
||||||
authTime := time.Unix(claims.AuthTime, 0)
|
authTime := time.Unix(claims.AuthTime, 0)
|
||||||
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
|
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
|
||||||
age := now.Sub(authTime)
|
age := now.Sub(authTime)
|
||||||
diff := age - params.MaxAge
|
diff := age - v.MaxAge
|
||||||
if claims.AuthTime > now.Unix() {
|
if claims.AuthTime > now.Unix() {
|
||||||
fromNow := time.Unix(claims.AuthTime, 0).Sub(now)
|
fromNow := time.Unix(claims.AuthTime, 0).Sub(now)
|
||||||
errs = append(errs, fmt.Sprintf(
|
errs = append(errs, fmt.Sprintf(
|
||||||
"'auth_time' of %s is %s in the future (server time %s)",
|
"'auth_time' of %s is %s in the future (server time %s)",
|
||||||
authTimeStr, formatDuration(fromNow), now.Format("2006-01-02 15:04:05 MST")),
|
authTimeStr, formatDuration(fromNow), now.Format("2006-01-02 15:04:05 MST")),
|
||||||
)
|
)
|
||||||
} else if params.MaxAge > 0 && age > params.MaxAge {
|
} else if v.MaxAge > 0 && age > v.MaxAge {
|
||||||
errs = append(errs, fmt.Sprintf(
|
errs = append(errs, fmt.Sprintf(
|
||||||
"'auth_time' of %s is %s old, exceeding max age %s by %s",
|
"'auth_time' of %s is %s old, exceeding max age %s by %s",
|
||||||
authTimeStr, formatDuration(age), formatDuration(params.MaxAge), formatDuration(diff)),
|
authTimeStr, formatDuration(age), formatDuration(v.MaxAge), formatDuration(diff)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional exact match
|
// Optional exact match
|
||||||
if params.Jti != claims.Jti {
|
if v.Jti != claims.Jti {
|
||||||
if len(params.Jti) > 0 {
|
if len(v.Jti) > 0 {
|
||||||
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti))
|
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti))
|
||||||
} else if !params.IgnoreJti {
|
} else if !v.IgnoreJti {
|
||||||
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
|
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional exact match
|
// Optional exact match
|
||||||
if params.Nonce != claims.Nonce {
|
if v.Nonce != claims.Nonce {
|
||||||
if len(params.Nonce) > 0 {
|
if len(v.Nonce) > 0 {
|
||||||
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, params.Nonce))
|
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce))
|
||||||
} else if !params.IgnoreNonce {
|
} else if !v.IgnoreNonce {
|
||||||
errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce))
|
errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should exist, optional required-set check
|
// Should exist, optional required-set check
|
||||||
if !params.IgnoreAmr {
|
if !v.IgnoreAmr {
|
||||||
if len(claims.Amr) == 0 {
|
if len(claims.Amr) == 0 {
|
||||||
errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)")
|
errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)")
|
||||||
} else if len(params.RequiredAmrs) > 0 {
|
} else if len(v.RequiredAmrs) > 0 {
|
||||||
for _, required := range params.RequiredAmrs {
|
for _, required := range v.RequiredAmrs {
|
||||||
if !slices.Contains(claims.Amr, required) {
|
if !slices.Contains(claims.Amr, required) {
|
||||||
errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required))
|
errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required))
|
||||||
}
|
}
|
||||||
@ -383,10 +407,10 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Optional, match if present
|
// Optional, match if present
|
||||||
if params.Azp != claims.Azp {
|
if v.Azp != claims.Azp {
|
||||||
if len(params.Azp) > 0 {
|
if len(v.Azp) > 0 {
|
||||||
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, params.Azp))
|
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp))
|
||||||
} else if !params.IgnoreAzp {
|
} else if !v.IgnoreAzp {
|
||||||
errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp))
|
errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -402,111 +426,154 @@ func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now ti
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Issuer holds public keys and validation config for a trusted token issuer.
|
// Issuer holds public keys and optional validation config for a trusted token issuer.
|
||||||
//
|
//
|
||||||
// [Issuer.FetchKeys] loads keys from the issuer's JWKS endpoint.
|
// Create with [New], [NewWithJWKs], [NewWithOIDC], or [NewWithOAuth2].
|
||||||
// [Issuer.SetKeys] injects keys directly (useful in tests).
|
// After construction, Issuer is immutable.
|
||||||
// [Issuer.Verify] authenticates the token: key lookup → sig verify → iss check.
|
|
||||||
//
|
//
|
||||||
// Typical setup:
|
// [Issuer.UnsafeVerify] authenticates the token: Decode + key lookup + sig verify + iss check.
|
||||||
//
|
// [Issuer.VerifyAndValidate] additionally unmarshals claims and runs the Validator.
|
||||||
// iss := ajwt.NewIssuer("https://accounts.example.com")
|
|
||||||
// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true}
|
|
||||||
// if err := iss.FetchKeys(ctx); err != nil { ... }
|
|
||||||
type Issuer struct {
|
type Issuer struct {
|
||||||
URL string
|
URL string // issuer URL for iss claim enforcement; empty skips the check
|
||||||
JWKsURL string // optional; defaults to URL + "/.well-known/jwks.json"
|
validator *Validator
|
||||||
Params ValidateParams
|
keys map[string]crypto.PublicKey // kid → key
|
||||||
keys map[string]crypto.PublicKey // kid → key
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewIssuer creates an Issuer for the given base URL.
|
// New creates an Issuer with explicit keys.
|
||||||
func NewIssuer(url string) *Issuer {
|
//
|
||||||
return &Issuer{
|
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
||||||
URL: url,
|
// [Issuer.VerifyAndValidate] requires a non-nil Validator.
|
||||||
keys: make(map[string]crypto.PublicKey),
|
func New(issURL string, keys []PublicJWK, v *Validator) *Issuer {
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKeys stores public keys by their KID, replacing any previously stored keys.
|
|
||||||
// Useful for injecting keys in tests without an HTTP round-trip.
|
|
||||||
func (iss *Issuer) SetKeys(keys []PublicJWK) {
|
|
||||||
m := make(map[string]crypto.PublicKey, len(keys))
|
m := make(map[string]crypto.PublicKey, len(keys))
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
m[k.KID] = k.Key
|
m[k.KID] = k.Key
|
||||||
}
|
}
|
||||||
iss.keys = m
|
return &Issuer{
|
||||||
|
URL: issURL,
|
||||||
|
validator: v,
|
||||||
|
keys: m,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchKeys retrieves and stores the JWKS from the issuer's endpoint.
|
// NewWithJWKs creates an Issuer by fetching keys from jwksURL.
|
||||||
// If JWKsURL is empty, it defaults to URL + "/.well-known/jwks.json".
|
|
||||||
func (iss *Issuer) FetchKeys(ctx context.Context) error {
|
|
||||||
url := iss.JWKsURL
|
|
||||||
if url == "" {
|
|
||||||
url = strings.TrimRight(iss.URL, "/") + "/.well-known/jwks.json"
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("fetch JWKS: %w", err)
|
|
||||||
}
|
|
||||||
client := &http.Client{Timeout: 10 * time.Second}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("fetch JWKS: %w", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
keys, err := DecodePublicJWKs(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse JWKS: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
iss.SetKeys(keys)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify authenticates jws against this issuer:
|
|
||||||
// 1. Looks up the signing key by jws.Header.Kid.
|
|
||||||
// 2. Verifies the signature before trusting any payload data.
|
|
||||||
// 3. Checks that the token's "iss" claim matches iss.URL.
|
|
||||||
//
|
//
|
||||||
// Call [JWS.UnmarshalClaims] after Verify to safely decode the payload into a
|
// The issuer URL (used for iss claim enforcement in [Issuer.UnsafeVerify]) is
|
||||||
// typed struct, then [ValidateParams.Validate] to check claim values.
|
// not set; use [New] or [NewWithOIDC]/[NewWithOAuth2] if you need iss enforcement.
|
||||||
func (iss *Issuer) Verify(jws *JWS) error {
|
//
|
||||||
|
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
||||||
|
func NewWithJWKs(ctx context.Context, jwksURL string, v *Validator) (*Issuer, error) {
|
||||||
|
keys, err := FetchJWKs(ctx, jwksURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return New("", keys, v), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWithOIDC creates an Issuer using OIDC discovery.
|
||||||
|
//
|
||||||
|
// It fetches {baseURL}/.well-known/openid-configuration and reads the
|
||||||
|
// jwks_uri and issuer fields. The Issuer URL is set from the discovery
|
||||||
|
// document's issuer field (not baseURL) because OIDC requires them to match.
|
||||||
|
//
|
||||||
|
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
||||||
|
func NewWithOIDC(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) {
|
||||||
|
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
|
||||||
|
keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return New(issURL, keys, v), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWithOAuth2 creates an Issuer using OAuth 2.0 authorization server metadata (RFC 8414).
|
||||||
|
//
|
||||||
|
// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the
|
||||||
|
// jwks_uri and issuer fields. The Issuer URL is set from the discovery
|
||||||
|
// document's issuer field.
|
||||||
|
//
|
||||||
|
// v is optional — pass nil to use [Issuer.UnsafeVerify] only.
|
||||||
|
func NewWithOAuth2(ctx context.Context, baseURL string, v *Validator) (*Issuer, error) {
|
||||||
|
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
|
||||||
|
keys, issURL, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return New(issURL, keys, v), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnsafeVerify decodes tokenStr, verifies the signature, and (if [Issuer.URL]
|
||||||
|
// is set) checks the iss claim.
|
||||||
|
//
|
||||||
|
// "Unsafe" means exp, aud, and other claim values are NOT checked — the token
|
||||||
|
// is forgery-safe but not semantically validated. Callers are responsible for
|
||||||
|
// validating claim values, or use [Issuer.VerifyAndValidate].
|
||||||
|
func (iss *Issuer) UnsafeVerify(tokenStr string) (*JWS, error) {
|
||||||
|
jws, err := Decode(tokenStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if jws.Header.Kid == "" {
|
if jws.Header.Kid == "" {
|
||||||
return fmt.Errorf("missing 'kid' header")
|
return nil, fmt.Errorf("missing 'kid' header")
|
||||||
}
|
}
|
||||||
key, ok := iss.keys[jws.Header.Kid]
|
key, ok := iss.keys[jws.Header.Kid]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("unknown kid: %q", jws.Header.Kid)
|
return nil, fmt.Errorf("unknown kid: %q", jws.Header.Kid)
|
||||||
}
|
}
|
||||||
|
|
||||||
signingInput := jws.Protected + "." + jws.Payload
|
signingInput := jws.Protected + "." + jws.Payload
|
||||||
if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil {
|
if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil {
|
||||||
return fmt.Errorf("signature verification failed: %w", err)
|
return nil, fmt.Errorf("signature verification failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Signature verified — now safe to inspect the payload.
|
// Signature verified — now safe to inspect the payload for iss check.
|
||||||
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
if iss.URL != "" {
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid claims encoding: %w", err)
|
||||||
|
}
|
||||||
|
var partial struct {
|
||||||
|
Iss string `json:"iss"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(payload, &partial); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid claims JSON: %w", err)
|
||||||
|
}
|
||||||
|
if partial.Iss != iss.URL {
|
||||||
|
return nil, fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return jws, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyAndValidate verifies the token signature and iss, unmarshals the claims
|
||||||
|
// into claims, and runs the [Validator].
|
||||||
|
//
|
||||||
|
// Returns a hard error (err != nil) for signature failures, decoding errors,
|
||||||
|
// and nil Validator. Returns soft errors (errs != nil) for claim validation
|
||||||
|
// failures (wrong aud, expired token, etc.).
|
||||||
|
//
|
||||||
|
// claims must be a pointer whose underlying type embeds [StandardClaims] (or
|
||||||
|
// otherwise implements [StandardClaimsSource]):
|
||||||
|
//
|
||||||
|
// var claims AppClaims
|
||||||
|
// jws, errs, err := iss.VerifyAndValidate(tokenStr, &claims, time.Now())
|
||||||
|
func (iss *Issuer) VerifyAndValidate(tokenStr string, claims StandardClaimsSource, now time.Time) (*JWS, []string, error) {
|
||||||
|
if iss.validator == nil {
|
||||||
|
return nil, nil, fmt.Errorf("VerifyAndValidate requires a non-nil Validator; use UnsafeVerify for signature-only verification")
|
||||||
|
}
|
||||||
|
|
||||||
|
jws, err := iss.UnsafeVerify(tokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid claims encoding: %w", err)
|
return nil, nil, err
|
||||||
}
|
|
||||||
var partial struct {
|
|
||||||
Iss string `json:"iss"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(payload, &partial); err != nil {
|
|
||||||
return fmt.Errorf("invalid claims JSON: %w", err)
|
|
||||||
}
|
|
||||||
if partial.Iss != iss.URL {
|
|
||||||
return fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
if err := jws.UnmarshalClaims(claims); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
errs, err := iss.validator.Validate(claims, now)
|
||||||
|
return jws, errs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// verifyWith checks a JWS signature using the given algorithm and public key.
|
// verifyWith checks a JWS signature using the given algorithm and public key.
|
||||||
|
|||||||
@ -14,7 +14,9 @@ import (
|
|||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -23,9 +25,8 @@ import (
|
|||||||
|
|
||||||
// AppClaims embeds StandardClaims and adds application-specific fields.
|
// AppClaims embeds StandardClaims and adds application-specific fields.
|
||||||
//
|
//
|
||||||
// Unlike embeddedjwt and bestjwt, AppClaims does NOT implement a Validate
|
// Because StandardClaims is embedded, AppClaims satisfies StandardClaimsSource
|
||||||
// interface — there is none. Validation is explicit: call
|
// for free via Go's method promotion — no interface to implement.
|
||||||
// ValidateStandardClaims or ValidateParams.Validate at the call site.
|
|
||||||
type AppClaims struct {
|
type AppClaims struct {
|
||||||
ajwt.StandardClaims
|
ajwt.StandardClaims
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
@ -33,9 +34,10 @@ type AppClaims struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validateAppClaims is a plain function — not a method satisfying an interface.
|
// validateAppClaims is a plain function — not a method satisfying an interface.
|
||||||
// Custom validation logic lives here, calling ValidateStandardClaims directly.
|
// It demonstrates the UnsafeVerify pattern: custom validation logic lives here,
|
||||||
func validateAppClaims(c AppClaims, params ajwt.ValidateParams, now time.Time) ([]string, error) {
|
// calling ValidateStandardClaims directly and adding app-specific checks.
|
||||||
errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, params, now)
|
func validateAppClaims(c AppClaims, v ajwt.Validator, now time.Time) ([]string, error) {
|
||||||
|
errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, v, now)
|
||||||
if c.Email == "" {
|
if c.Email == "" {
|
||||||
errs = append(errs, "missing email claim")
|
errs = append(errs, "missing email claim")
|
||||||
}
|
}
|
||||||
@ -65,11 +67,11 @@ func goodClaims() AppClaims {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// goodParams configures the validator. Iss is omitted because Issuer.Verify
|
// goodValidator configures the validator. IgnoreIss is true because
|
||||||
// already enforces the iss claim — no need to check it twice.
|
// Issuer.UnsafeVerify already enforces the iss claim — no need to check twice.
|
||||||
func goodParams() ajwt.ValidateParams {
|
func goodValidator() *ajwt.Validator {
|
||||||
return ajwt.ValidateParams{
|
return &ajwt.Validator{
|
||||||
IgnoreIss: true, // Issuer.Verify handles iss
|
IgnoreIss: true, // UnsafeVerify handles iss
|
||||||
Sub: "user123",
|
Sub: "user123",
|
||||||
Aud: "myapp",
|
Aud: "myapp",
|
||||||
Jti: "abc123",
|
Jti: "abc123",
|
||||||
@ -80,16 +82,13 @@ func goodParams() ajwt.ValidateParams {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer {
|
func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer {
|
||||||
iss := ajwt.NewIssuer("https://example.com")
|
return ajwt.New("https://example.com", []ajwt.PublicJWK{pub}, goodValidator())
|
||||||
iss.Params = goodParams()
|
|
||||||
iss.SetKeys([]ajwt.PublicJWK{pub})
|
|
||||||
return iss
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRoundTrip is the primary happy path using ES256.
|
// TestRoundTrip is the primary happy path using ES256.
|
||||||
// It demonstrates the full Issuer-based flow:
|
// It demonstrates the full VerifyAndValidate flow:
|
||||||
//
|
//
|
||||||
// Decode → Issuer.Verify → UnmarshalClaims → Params.Validate
|
// New → VerifyAndValidate → custom claim access
|
||||||
//
|
//
|
||||||
// No Claims interface, no Verified flag, no type assertions on jws.
|
// No Claims interface, no Verified flag, no type assertions on jws.
|
||||||
func TestRoundTrip(t *testing.T) {
|
func TestRoundTrip(t *testing.T) {
|
||||||
@ -115,20 +114,16 @@ func TestRoundTrip(t *testing.T) {
|
|||||||
|
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
|
||||||
|
|
||||||
jws2, err := ajwt.Decode(token)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err = iss.Verify(jws2); err != nil {
|
|
||||||
t.Fatalf("Verify failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var decoded AppClaims
|
var decoded AppClaims
|
||||||
if err = jws2.UnmarshalClaims(&decoded); err != nil {
|
jws2, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
|
||||||
t.Fatal(err)
|
if err != nil {
|
||||||
|
t.Fatalf("VerifyAndValidate failed: %v", err)
|
||||||
}
|
}
|
||||||
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
|
if len(errs) > 0 {
|
||||||
t.Fatalf("validation failed: %v", errs)
|
t.Fatalf("claim validation failed: %v", errs)
|
||||||
|
}
|
||||||
|
if jws2.Header.Alg != "ES256" {
|
||||||
|
t.Errorf("expected ES256 alg in jws, got %s", jws2.Header.Alg)
|
||||||
}
|
}
|
||||||
// Direct field access — no type assertion needed, no jws.Claims interface.
|
// Direct field access — no type assertion needed, no jws.Claims interface.
|
||||||
if decoded.Email != claims.Email {
|
if decoded.Email != claims.Email {
|
||||||
@ -160,20 +155,13 @@ func TestRoundTripRS256(t *testing.T) {
|
|||||||
|
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
|
||||||
|
|
||||||
jws2, err := ajwt.Decode(token)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err = iss.Verify(jws2); err != nil {
|
|
||||||
t.Fatalf("Verify failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var decoded AppClaims
|
var decoded AppClaims
|
||||||
if err = jws2.UnmarshalClaims(&decoded); err != nil {
|
_, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
|
||||||
t.Fatal(err)
|
if err != nil {
|
||||||
|
t.Fatalf("VerifyAndValidate failed: %v", err)
|
||||||
}
|
}
|
||||||
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
|
if len(errs) > 0 {
|
||||||
t.Fatalf("validation failed: %v", errs)
|
t.Fatalf("claim validation failed: %v", errs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -201,25 +189,47 @@ func TestRoundTripEdDSA(t *testing.T) {
|
|||||||
|
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"})
|
||||||
|
|
||||||
jws2, err := ajwt.Decode(token)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err = iss.Verify(jws2); err != nil {
|
|
||||||
t.Fatalf("Verify failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var decoded AppClaims
|
var decoded AppClaims
|
||||||
if err = jws2.UnmarshalClaims(&decoded); err != nil {
|
_, errs, err := iss.VerifyAndValidate(token, &decoded, time.Now())
|
||||||
t.Fatal(err)
|
if err != nil {
|
||||||
|
t.Fatalf("VerifyAndValidate failed: %v", err)
|
||||||
}
|
}
|
||||||
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
|
if len(errs) > 0 {
|
||||||
t.Fatalf("validation failed: %v", errs)
|
t.Fatalf("claim validation failed: %v", errs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestCustomValidation demonstrates custom claim validation without any interface.
|
// TestUnsafeVerifyFlow demonstrates the UnsafeVerify + custom validation pattern.
|
||||||
// The caller owns the validation logic and calls ValidateStandardClaims directly.
|
// The caller owns the full validation pipeline.
|
||||||
|
func TestUnsafeVerifyFlow(t *testing.T) {
|
||||||
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
|
||||||
|
_, _ = jws.Sign(privKey)
|
||||||
|
token := jws.Encode()
|
||||||
|
|
||||||
|
// Create issuer without validator — UnsafeVerify only.
|
||||||
|
iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil)
|
||||||
|
|
||||||
|
jws2, err := iss.UnsafeVerify(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UnsafeVerify failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var decoded AppClaims
|
||||||
|
if err := jws2.UnmarshalClaims(&decoded); err != nil {
|
||||||
|
t.Fatalf("UnmarshalClaims failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errs, err := ajwt.ValidateStandardClaims(decoded.StandardClaims, *goodValidator(), time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ValidateStandardClaims failed: %v — errs: %v", err, errs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCustomValidation demonstrates that ValidateStandardClaims is called
|
||||||
|
// explicitly and custom fields are validated without any Claims interface.
|
||||||
func TestCustomValidation(t *testing.T) {
|
func TestCustomValidation(t *testing.T) {
|
||||||
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
@ -231,12 +241,15 @@ func TestCustomValidation(t *testing.T) {
|
|||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
||||||
jws2, _ := ajwt.Decode(token)
|
jws2, err := iss.UnsafeVerify(token)
|
||||||
_ = iss.Verify(jws2)
|
if err != nil {
|
||||||
|
t.Fatalf("UnsafeVerify failed unexpectedly: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
var decoded AppClaims
|
var decoded AppClaims
|
||||||
_ = jws2.UnmarshalClaims(&decoded)
|
_ = jws2.UnmarshalClaims(&decoded)
|
||||||
|
|
||||||
errs, err := validateAppClaims(decoded, goodParams(), time.Now())
|
errs, err := validateAppClaims(decoded, *goodValidator(), time.Now())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected validation to fail: email is empty")
|
t.Fatal("expected validation to fail: email is empty")
|
||||||
}
|
}
|
||||||
@ -251,6 +264,23 @@ func TestCustomValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestVerifyAndValidateNilValidator confirms that VerifyAndValidate fails loudly
|
||||||
|
// when no Validator was provided at construction time.
|
||||||
|
func TestVerifyAndValidateNilValidator(t *testing.T) {
|
||||||
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
c := goodClaims()
|
||||||
|
jws, _ := ajwt.NewJWSFromClaims(&c, "k")
|
||||||
|
_, _ = jws.Sign(privKey)
|
||||||
|
token := jws.Encode()
|
||||||
|
|
||||||
|
iss := ajwt.New("https://example.com", []ajwt.PublicJWK{{Key: &privKey.PublicKey, KID: "k"}}, nil)
|
||||||
|
|
||||||
|
var claims AppClaims
|
||||||
|
if _, _, err := iss.VerifyAndValidate(token, &claims, time.Now()); err == nil {
|
||||||
|
t.Fatal("expected VerifyAndValidate to error with nil validator")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestIssuerWrongKey confirms that a different key's public key is rejected.
|
// TestIssuerWrongKey confirms that a different key's public key is rejected.
|
||||||
func TestIssuerWrongKey(t *testing.T) {
|
func TestIssuerWrongKey(t *testing.T) {
|
||||||
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
@ -262,10 +292,9 @@ func TestIssuerWrongKey(t *testing.T) {
|
|||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"})
|
||||||
jws2, _ := ajwt.Decode(token)
|
|
||||||
|
|
||||||
if err := iss.Verify(jws2); err == nil {
|
if _, err := iss.UnsafeVerify(token); err == nil {
|
||||||
t.Fatal("expected Verify to fail with wrong key")
|
t.Fatal("expected UnsafeVerify to fail with wrong key")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -279,10 +308,9 @@ func TestIssuerUnknownKid(t *testing.T) {
|
|||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"})
|
||||||
jws2, _ := ajwt.Decode(token)
|
|
||||||
|
|
||||||
if err := iss.Verify(jws2); err == nil {
|
if _, err := iss.UnsafeVerify(token); err == nil {
|
||||||
t.Fatal("expected Verify to fail for unknown kid")
|
t.Fatal("expected UnsafeVerify to fail for unknown kid")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -299,14 +327,15 @@ func TestIssuerIssMismatch(t *testing.T) {
|
|||||||
|
|
||||||
// Issuer expects "https://example.com" but token says "https://evil.example.com"
|
// Issuer expects "https://example.com" but token says "https://evil.example.com"
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
||||||
jws2, _ := ajwt.Decode(token)
|
|
||||||
|
|
||||||
if err := iss.Verify(jws2); err == nil {
|
if _, err := iss.UnsafeVerify(token); err == nil {
|
||||||
t.Fatal("expected Verify to fail: iss mismatch")
|
t.Fatal("expected UnsafeVerify to fail: iss mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestVerifyTamperedAlg confirms that a tampered alg header is rejected.
|
// TestVerifyTamperedAlg confirms that a tampered alg header ("none") is rejected.
|
||||||
|
// The token is reconstructed with a replaced protected header; the original
|
||||||
|
// ES256 signature is kept, making the signing input mismatch detectable.
|
||||||
func TestVerifyTamperedAlg(t *testing.T) {
|
func TestVerifyTamperedAlg(t *testing.T) {
|
||||||
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
@ -316,11 +345,15 @@ func TestVerifyTamperedAlg(t *testing.T) {
|
|||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
|
||||||
jws2, _ := ajwt.Decode(token)
|
|
||||||
jws2.Header.Alg = "none" // tamper
|
|
||||||
|
|
||||||
if err := iss.Verify(jws2); err == nil {
|
// Replace the protected header with one that has alg:"none".
|
||||||
t.Fatal("expected Verify to fail for tampered alg")
|
// The original ES256 signature stays — the signing input will mismatch.
|
||||||
|
noneHeader := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","kid":"k","typ":"JWT"}`))
|
||||||
|
parts := strings.SplitN(token, ".", 3)
|
||||||
|
tamperedToken := noneHeader + "." + parts[1] + "." + parts[2]
|
||||||
|
|
||||||
|
if _, err := iss.UnsafeVerify(tamperedToken); err == nil {
|
||||||
|
t.Fatal("expected UnsafeVerify to fail for tampered alg")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -408,3 +441,79 @@ func TestDecodePublicJWKJSON(t *testing.T) {
|
|||||||
t.Errorf("expected 1 RSA key, got %d", rsaCount)
|
t.Errorf("expected 1 RSA key, got %d", rsaCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestThumbprint verifies that Thumbprint returns a non-empty base64url string
|
||||||
|
// for EC, RSA, and Ed25519 keys, and that two equal keys produce the same thumbprint.
|
||||||
|
func TestThumbprint(t *testing.T) {
|
||||||
|
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
edPub, _, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
jwk ajwt.PublicJWK
|
||||||
|
}{
|
||||||
|
{"EC P-256", ajwt.PublicJWK{Key: &ecKey.PublicKey}},
|
||||||
|
{"RSA 2048", ajwt.PublicJWK{Key: &rsaKey.PublicKey}},
|
||||||
|
{"Ed25519", ajwt.PublicJWK{Key: edPub}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
thumb, err := tt.jwk.Thumbprint()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Thumbprint() error: %v", err)
|
||||||
|
}
|
||||||
|
if thumb == "" {
|
||||||
|
t.Fatal("Thumbprint() returned empty string")
|
||||||
|
}
|
||||||
|
// Must be valid base64url (no padding, no +/)
|
||||||
|
if strings.Contains(thumb, "+") || strings.Contains(thumb, "/") || strings.Contains(thumb, "=") {
|
||||||
|
t.Errorf("Thumbprint() contains non-base64url characters: %s", thumb)
|
||||||
|
}
|
||||||
|
// Same key → same thumbprint
|
||||||
|
thumb2, _ := tt.jwk.Thumbprint()
|
||||||
|
if thumb != thumb2 {
|
||||||
|
t.Errorf("Thumbprint() not deterministic: %s != %s", thumb, thumb2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNoKidAutoThumbprint verifies that a JWKS key without a "kid" field gets
|
||||||
|
// its KID auto-populated from the RFC 7638 thumbprint.
|
||||||
|
func TestNoKidAutoThumbprint(t *testing.T) {
|
||||||
|
// EC key with no "kid" field in the JWKS
|
||||||
|
jwksJSON := []byte(`{"keys":[
|
||||||
|
{"kty":"EC","crv":"P-256",
|
||||||
|
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
|
||||||
|
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
|
||||||
|
"use":"sig"}
|
||||||
|
]}`)
|
||||||
|
|
||||||
|
keys, err := ajwt.UnmarshalPublicJWKs(jwksJSON)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(keys) != 1 {
|
||||||
|
t.Fatalf("expected 1 key, got %d", len(keys))
|
||||||
|
}
|
||||||
|
if keys[0].KID == "" {
|
||||||
|
t.Fatal("KID should be auto-populated from Thumbprint when absent in JWKS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The auto-KID should be a valid base64url string.
|
||||||
|
kid := keys[0].KID
|
||||||
|
if strings.Contains(kid, "+") || strings.Contains(kid, "/") || strings.Contains(kid, "=") {
|
||||||
|
t.Errorf("auto-KID contains non-base64url characters: %s", kid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Round-trip: compute Thumbprint directly and compare.
|
||||||
|
thumb, err := keys[0].Thumbprint()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Thumbprint() error: %v", err)
|
||||||
|
}
|
||||||
|
if kid != thumb {
|
||||||
|
t.Errorf("auto-KID %q != direct Thumbprint %q", kid, thumb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
171
auth/ajwt/pub.go
171
auth/ajwt/pub.go
@ -9,11 +9,13 @@
|
|||||||
package ajwt
|
package ajwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -21,6 +23,7 @@ import (
|
|||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -54,6 +57,90 @@ func (k PublicJWK) EdDSA() (ed25519.PublicKey, bool) {
|
|||||||
return key, ok
|
return key, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Thumbprint computes the RFC 7638 JWK Thumbprint (SHA-256 of the canonical
|
||||||
|
// key JSON with fields in lexicographic order). The result is base64url-encoded.
|
||||||
|
//
|
||||||
|
// Canonical forms per RFC 7638:
|
||||||
|
// - EC: {"crv":…, "kty":"EC", "x":…, "y":…}
|
||||||
|
// - RSA: {"e":…, "kty":"RSA", "n":…}
|
||||||
|
// - OKP: {"crv":"Ed25519", "kty":"OKP", "x":…}
|
||||||
|
//
|
||||||
|
// Use Thumbprint as KID when none is provided in the JWKS source.
|
||||||
|
func (k PublicJWK) Thumbprint() (string, error) {
|
||||||
|
var canonical []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch key := k.Key.(type) {
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
byteLen := (key.Curve.Params().BitSize + 7) / 8
|
||||||
|
xBytes := make([]byte, byteLen)
|
||||||
|
yBytes := make([]byte, byteLen)
|
||||||
|
key.X.FillBytes(xBytes)
|
||||||
|
key.Y.FillBytes(yBytes)
|
||||||
|
|
||||||
|
var crv string
|
||||||
|
switch key.Curve {
|
||||||
|
case elliptic.P256():
|
||||||
|
crv = "P-256"
|
||||||
|
case elliptic.P384():
|
||||||
|
crv = "P-384"
|
||||||
|
case elliptic.P521():
|
||||||
|
crv = "P-521"
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("Thumbprint: unsupported EC curve %s", key.Curve.Params().Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields in lexicographic order: crv, kty, x, y
|
||||||
|
canonical, err = json.Marshal(struct {
|
||||||
|
Crv string `json:"crv"`
|
||||||
|
Kty string `json:"kty"`
|
||||||
|
X string `json:"x"`
|
||||||
|
Y string `json:"y"`
|
||||||
|
}{
|
||||||
|
Crv: crv,
|
||||||
|
Kty: "EC",
|
||||||
|
X: base64.RawURLEncoding.EncodeToString(xBytes),
|
||||||
|
Y: base64.RawURLEncoding.EncodeToString(yBytes),
|
||||||
|
})
|
||||||
|
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
eInt := big.NewInt(int64(key.E))
|
||||||
|
|
||||||
|
// Fields in lexicographic order: e, kty, n
|
||||||
|
canonical, err = json.Marshal(struct {
|
||||||
|
E string `json:"e"`
|
||||||
|
Kty string `json:"kty"`
|
||||||
|
N string `json:"n"`
|
||||||
|
}{
|
||||||
|
E: base64.RawURLEncoding.EncodeToString(eInt.Bytes()),
|
||||||
|
Kty: "RSA",
|
||||||
|
N: base64.RawURLEncoding.EncodeToString(key.N.Bytes()),
|
||||||
|
})
|
||||||
|
|
||||||
|
case ed25519.PublicKey:
|
||||||
|
// Fields in lexicographic order: crv, kty, x
|
||||||
|
canonical, err = json.Marshal(struct {
|
||||||
|
Crv string `json:"crv"`
|
||||||
|
Kty string `json:"kty"`
|
||||||
|
X string `json:"x"`
|
||||||
|
}{
|
||||||
|
Crv: "Ed25519",
|
||||||
|
Kty: "OKP",
|
||||||
|
X: base64.RawURLEncoding.EncodeToString([]byte(key)),
|
||||||
|
})
|
||||||
|
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("Thumbprint: unsupported key type %T", k.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Thumbprint: marshal canonical JSON: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sum := sha256.Sum256(canonical)
|
||||||
|
return base64.RawURLEncoding.EncodeToString(sum[:]), nil
|
||||||
|
}
|
||||||
|
|
||||||
// PublicJWKJSON is the JSON representation of a single key in a JWKS document.
|
// PublicJWKJSON is the JSON representation of a single key in a JWKS document.
|
||||||
type PublicJWKJSON struct {
|
type PublicJWKJSON struct {
|
||||||
Kty string `json:"kty"`
|
Kty string `json:"kty"`
|
||||||
@ -71,24 +158,83 @@ type JWKsJSON struct {
|
|||||||
Keys []PublicJWKJSON `json:"keys"`
|
Keys []PublicJWKJSON `json:"keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchPublicJWKs retrieves and parses a JWKS document from url.
|
// FetchJWKs retrieves and parses a JWKS document from jwksURL.
|
||||||
//
|
//
|
||||||
// For issuer-scoped key management with context support, use
|
// ctx is used for the HTTP request timeout and cancellation.
|
||||||
// [Issuer.FetchKeys] instead.
|
func FetchJWKs(ctx context.Context, jwksURL string) ([]PublicJWK, error) {
|
||||||
func FetchPublicJWKs(url string) ([]PublicJWK, error) {
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil)
|
||||||
client := &http.Client{Timeout: 10 * time.Second}
|
|
||||||
resp, err := client.Get(url)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
|
return nil, fmt.Errorf("fetch JWKS: %w", err)
|
||||||
|
}
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("fetch JWKS: %w", err)
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
return nil, fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
return DecodePublicJWKs(resp.Body)
|
return DecodePublicJWKs(resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FetchJWKsFromOIDC fetches JWKS via OIDC discovery from baseURL.
|
||||||
|
//
|
||||||
|
// It fetches {baseURL}/.well-known/openid-configuration and reads the jwks_uri field.
|
||||||
|
func FetchJWKsFromOIDC(ctx context.Context, baseURL string) ([]PublicJWK, error) {
|
||||||
|
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
|
||||||
|
keys, _, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
|
||||||
|
return keys, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchJWKsFromOAuth2 fetches JWKS via OAuth 2.0 authorization server metadata (RFC 8414)
|
||||||
|
// from baseURL.
|
||||||
|
//
|
||||||
|
// It fetches {baseURL}/.well-known/oauth-authorization-server and reads the jwks_uri field.
|
||||||
|
func FetchJWKsFromOAuth2(ctx context.Context, baseURL string) ([]PublicJWK, error) {
|
||||||
|
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
|
||||||
|
keys, _, err := fetchJWKsFromDiscovery(ctx, discoveryURL)
|
||||||
|
return keys, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchJWKsFromDiscovery fetches a discovery document from discoveryURL, then
|
||||||
|
// fetches the JWKS from the jwks_uri field. Returns the keys and the issuer
|
||||||
|
// URL from the discovery document's "issuer" field.
|
||||||
|
func fetchJWKsFromDiscovery(ctx context.Context, discoveryURL string) ([]PublicJWK, string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("fetch discovery: %w", err)
|
||||||
|
}
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("fetch discovery: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, "", fmt.Errorf("fetch discovery: unexpected status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var doc struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
JWKsURI string `json:"jwks_uri"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("parse discovery doc: %w", err)
|
||||||
|
}
|
||||||
|
if doc.JWKsURI == "" {
|
||||||
|
return nil, "", fmt.Errorf("discovery doc missing jwks_uri field")
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, err := FetchJWKs(ctx, doc.JWKsURI)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
return keys, doc.Issuer, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ReadPublicJWKs reads and parses a JWKS document from a file path.
|
// ReadPublicJWKs reads and parses a JWKS document from a file path.
|
||||||
func ReadPublicJWKs(filePath string) ([]PublicJWK, error) {
|
func ReadPublicJWKs(filePath string) ([]PublicJWK, error) {
|
||||||
file, err := os.Open(filePath)
|
file, err := os.Open(filePath)
|
||||||
@ -118,6 +264,9 @@ func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys.
|
// DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys.
|
||||||
|
//
|
||||||
|
// If a key has no kid field in the source document, the KID is auto-populated
|
||||||
|
// from [PublicJWK.Thumbprint] per RFC 7638.
|
||||||
func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
|
func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
|
||||||
var keys []PublicJWK
|
var keys []PublicJWK
|
||||||
for _, jwk := range jwks.Keys {
|
for _, jwk := range jwks.Keys {
|
||||||
@ -125,6 +274,12 @@ func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse public jwk %q: %w", jwk.KID, err)
|
return nil, fmt.Errorf("failed to parse public jwk %q: %w", jwk.KID, err)
|
||||||
}
|
}
|
||||||
|
if key.KID == "" {
|
||||||
|
key.KID, err = key.Thumbprint()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("compute thumbprint for kid-less key: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
keys = append(keys, *key)
|
keys = append(keys, *key)
|
||||||
}
|
}
|
||||||
if len(keys) == 0 {
|
if len(keys) == 0 {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user