ajwt: fix Audience type, digestFor panic, validator bugs, fetcher staleness

- Add Audience type (RFC 7519 §4.1.3): unmarshals string or []string,
  marshals to string for single value and array for multiple
- Fix digestFor panic: return ([]byte, error) instead of panicking on
  unsupported hash; plumb error through Sign and verifyWith callers
- Fix headerJSON marshal error: propagate instead of discarding in
  NewJWSFromClaims and JWS.Sign (all three key-type branches)
- Fix MaxAge/IgnoreAuthTime interaction: IgnoreAuthTime: true now
  correctly skips auth_time checks even when MaxAge > 0
- Fix "unchecked" warnings for Jti/Nonce/Azp: invert to opt-in —
  these fields are only validated when the Validator has them set
- Fix MultiValidator.Aud for Audience type: checks if any token
  audience value is in the allowed list (set intersection)
- Fix stale now in JWKsFetcher slow path: recapture time.Now() after
  acquiring the mutex so stale-window checks use a current timestamp
- Remove RespectHeaders no-op field from JWKsFetcher
- Simplify RSA exponent decode: use big.Int.IsInt64() instead of
  platform-dependent int size check
This commit is contained in:
AJ ONeal 2026-03-13 12:04:15 -06:00
parent 52ffecb5b3
commit ac25aa2ee5
No known key found for this signature in database
4 changed files with 105 additions and 53 deletions

View File

@ -64,10 +64,6 @@ type JWKsFetcher struct {
// (nil, err).
KeepOnError bool
// RespectHeaders is reserved for future use (honor Cache-Control max-age
// from the JWKS response, capped at MaxAge).
RespectHeaders bool
mu sync.Mutex
cached atomic.Pointer[cachedIssuer]
}
@ -90,6 +86,10 @@ func (f *JWKsFetcher) Issuer(ctx context.Context) (*Issuer, error) {
f.mu.Lock()
defer f.mu.Unlock()
// Recapture time after acquiring lock — the fast-path timestamp may be stale
// if there was contention and another goroutine held the lock for a while.
now = time.Now()
// Re-check after acquiring lock — another goroutine may have refreshed.
if ci := f.cached.Load(); ci != nil && now.Before(ci.expiresAt) {
return ci.iss, nil

View File

@ -93,6 +93,42 @@ type StandardHeader struct {
Typ string `json:"typ"`
}
// Audience represents the "aud" JWT claim (RFC 7519 §4.1.3).
//
// It unmarshals from both a single string ("foo") and an array of strings
// (["foo","bar"]). It marshals to a plain string for a single value and to
// an array for multiple values, per the RFC.
type Audience []string
// Contains reports whether s appears in the audience list.
func (a Audience) Contains(s string) bool {
return slices.Contains([]string(a), s)
}
// UnmarshalJSON decodes both the string and []string forms of the "aud" claim.
func (a *Audience) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err == nil {
*a = Audience{s}
return nil
}
var ss []string
if err := json.Unmarshal(data, &ss); err != nil {
return fmt.Errorf("'aud' must be a string or array of strings: %w", err)
}
*a = ss
return nil
}
// MarshalJSON encodes the audience as a plain string when there is one value,
// or as a JSON array for multiple values, per RFC 7519 §4.1.3.
func (a Audience) MarshalJSON() ([]byte, error) {
if len(a) == 1 {
return json.Marshal(a[0])
}
return json.Marshal([]string(a))
}
// StandardClaims holds the registered JWT claim names defined in RFC 7519
// and extended by OpenID Connect Core.
//
@ -108,7 +144,7 @@ type StandardHeader struct {
type StandardClaims struct {
Iss string `json:"iss"`
Sub string `json:"sub"`
Aud string `json:"aud"`
Aud Audience `json:"aud,omitempty"`
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
AuthTime int64 `json:"auth_time"`
@ -196,7 +232,10 @@ func NewJWSFromClaims(claims any, kid string) (*JWS, error) {
Kid: kid,
Typ: "JWT",
}
headerJSON, _ := json.Marshal(jws.Header)
headerJSON, err := json.Marshal(jws.Header)
if err != nil {
return nil, fmt.Errorf("marshal header: %w", err)
}
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
claimsJSON, err := json.Marshal(claims)
@ -227,10 +266,16 @@ func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) {
return nil, err
}
jws.Header.Alg = alg
headerJSON, _ := json.Marshal(jws.Header)
headerJSON, err := json.Marshal(jws.Header)
if err != nil {
return nil, fmt.Errorf("marshal header: %w", err)
}
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
digest := digestFor(h, jws.Protected+"."+jws.Payload)
digest, err := digestFor(h, jws.Protected+"."+jws.Payload)
if err != nil {
return nil, err
}
// crypto.Signer returns ASN.1 DER for ECDSA; convert to raw r||s for JWS.
derSig, err := key.Sign(rand.Reader, digest, h)
if err != nil {
@ -241,23 +286,30 @@ func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) {
case *rsa.PublicKey:
jws.Header.Alg = "RS256"
headerJSON, _ := json.Marshal(jws.Header)
headerJSON, err := json.Marshal(jws.Header)
if err != nil {
return nil, fmt.Errorf("marshal header: %w", err)
}
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
digest := digestFor(crypto.SHA256, jws.Protected+"."+jws.Payload)
digest, err := digestFor(crypto.SHA256, jws.Protected+"."+jws.Payload)
if err != nil {
return nil, err
}
// crypto.Signer returns raw PKCS#1 v1.5 bytes for RSA; use directly.
var err error
jws.Signature, err = key.Sign(rand.Reader, digest, crypto.SHA256)
return jws.Signature, err
case ed25519.PublicKey:
jws.Header.Alg = "EdDSA"
headerJSON, _ := json.Marshal(jws.Header)
headerJSON, err := json.Marshal(jws.Header)
if err != nil {
return nil, fmt.Errorf("marshal header: %w", err)
}
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
// Ed25519 signs the raw message with no pre-hashing; pass crypto.Hash(0).
signingInput := jws.Protected + "." + jws.Payload
var err error
jws.Signature, err = key.Sign(rand.Reader, []byte(signingInput), crypto.Hash(0))
return jws.Signature, err
@ -340,10 +392,12 @@ func (v *MultiValidator) Validate(claims StandardClaimsSource, now time.Time) ([
}
if !v.IgnoreAud {
if sc.Aud == "" {
if len(sc.Aud) == 0 {
errs = append(errs, "missing or malformed 'aud' (audience)")
} else if len(v.Aud) > 0 && !slices.Contains(v.Aud, sc.Aud) {
errs = append(errs, fmt.Sprintf("'aud' %q not in allowed list", sc.Aud))
} else if len(v.Aud) > 0 && !slices.ContainsFunc([]string(sc.Aud), func(a string) bool {
return slices.Contains(v.Aud, a)
}) {
errs = append(errs, fmt.Sprintf("'aud' not in allowed list: %v", sc.Aud))
}
}
@ -364,7 +418,7 @@ func (v *MultiValidator) Validate(claims StandardClaimsSource, now time.Time) ([
}
}
if v.MaxAge > 0 || !v.IgnoreAuthTime {
if !v.IgnoreAuthTime {
if sc.AuthTime == 0 {
errs = append(errs, "missing or malformed 'auth_time'")
} else if sc.AuthTime > now.Unix() {
@ -433,8 +487,8 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
if len(v.Aud) > 0 || !v.IgnoreAud {
if len(claims.Aud) == 0 {
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
} else if len(v.Aud) > 0 && claims.Aud != v.Aud {
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud))
} else if len(v.Aud) > 0 && !claims.Aud.Contains(v.Aud) {
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %v, expected %s", claims.Aud, v.Aud))
}
}
@ -461,7 +515,7 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
}
// Should exist, in the past, with optional max age
if v.MaxAge > 0 || !v.IgnoreAuthTime {
if !v.IgnoreAuthTime {
if claims.AuthTime == 0 {
errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
} else {
@ -484,22 +538,14 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
}
}
// Optional exact match
if v.Jti != claims.Jti {
if len(v.Jti) > 0 {
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti))
} else if !v.IgnoreJti {
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
}
// Optional exact match (only checked when v.Jti is set)
if len(v.Jti) > 0 && v.Jti != claims.Jti {
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti))
}
// Optional exact match
if v.Nonce != claims.Nonce {
if len(v.Nonce) > 0 {
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce))
} else if !v.IgnoreNonce {
errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce))
}
// Optional exact match (only checked when v.Nonce is set)
if len(v.Nonce) > 0 && v.Nonce != claims.Nonce {
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce))
}
// Should exist, optional required-set check
@ -515,13 +561,9 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
}
}
// Optional, match if present
if v.Azp != claims.Azp {
if len(v.Azp) > 0 {
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp))
} else if !v.IgnoreAzp {
errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp))
}
// Optional, match if present (only checked when v.Azp is set)
if len(v.Azp) > 0 && v.Azp != claims.Azp {
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp))
}
if len(errs) > 0 {
@ -668,7 +710,10 @@ func verifyWith(signingInput string, sig []byte, alg string, key crypto.PublicKe
if len(sig) != 2*byteLen {
return fmt.Errorf("invalid %s signature length: got %d, want %d", alg, len(sig), 2*byteLen)
}
digest := digestFor(h, signingInput)
digest, err := digestFor(h, signingInput)
if err != nil {
return err
}
r := new(big.Int).SetBytes(sig[:byteLen])
s := new(big.Int).SetBytes(sig[byteLen:])
if !ecdsa.Verify(k, digest, r, s) {
@ -681,7 +726,10 @@ func verifyWith(signingInput string, sig []byte, alg string, key crypto.PublicKe
if !ok {
return fmt.Errorf("alg RS256 requires *rsa.PublicKey, got %T", key)
}
digest := digestFor(crypto.SHA256, signingInput)
digest, err := digestFor(crypto.SHA256, signingInput)
if err != nil {
return err
}
if err := rsa.VerifyPKCS1v15(k, crypto.SHA256, digest, sig); err != nil {
return fmt.Errorf("RS256 signature invalid: %w", err)
}
@ -717,19 +765,19 @@ func algForECKey(pub *ecdsa.PublicKey) (alg string, h crypto.Hash, err error) {
}
}
func digestFor(h crypto.Hash, data string) []byte {
func digestFor(h crypto.Hash, data string) ([]byte, error) {
switch h {
case crypto.SHA256:
d := sha256.Sum256([]byte(data))
return d[:]
return d[:], nil
case crypto.SHA384:
d := sha512.Sum384([]byte(data))
return d[:]
return d[:], nil
case crypto.SHA512:
d := sha512.Sum512([]byte(data))
return d[:]
return d[:], nil
default:
panic(fmt.Sprintf("ajwt: unsupported hash %v", h))
return nil, fmt.Errorf("ajwt: unsupported hash %v", h)
}
}

View File

@ -53,7 +53,7 @@ func goodClaims() AppClaims {
StandardClaims: ajwt.StandardClaims{
Iss: "https://example.com",
Sub: "user123",
Aud: "myapp",
Aud: ajwt.Audience{"myapp"},
Exp: now.Add(time.Hour).Unix(),
Iat: now.Unix(),
AuthTime: now.Unix(),

View File

@ -412,14 +412,18 @@ func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) {
return nil, fmt.Errorf("invalid RSA exponent: %w", err)
}
eInt := new(big.Int).SetBytes(e).Int64()
if eInt > int64(^uint(0)>>1) || eInt < 0 {
return nil, fmt.Errorf("RSA exponent too large or negative")
eInt := new(big.Int).SetBytes(e)
if !eInt.IsInt64() {
return nil, fmt.Errorf("RSA exponent too large")
}
eVal := eInt.Int64()
if eVal <= 0 {
return nil, fmt.Errorf("RSA exponent must be positive")
}
return &rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: int(eInt),
E: int(eVal),
}, nil
}