From ac25aa2ee5e706bdd484722f2d8b189c7a1239cf Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Fri, 13 Mar 2026 12:04:15 -0600 Subject: [PATCH] ajwt: fix Audience type, digestFor panic, validator bugs, fetcher staleness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add Audience type (RFC 7519 §4.1.3): unmarshals string or []string, marshals to string for single value and array for multiple - Fix digestFor panic: return ([]byte, error) instead of panicking on unsupported hash; plumb error through Sign and verifyWith callers - Fix headerJSON marshal error: propagate instead of discarding in NewJWSFromClaims and JWS.Sign (all three key-type branches) - Fix MaxAge/IgnoreAuthTime interaction: IgnoreAuthTime: true now correctly skips auth_time checks even when MaxAge > 0 - Fix "unchecked" warnings for Jti/Nonce/Azp: invert to opt-in — these fields are only validated when the Validator has them set - Fix MultiValidator.Aud for Audience type: checks if any token audience value is in the allowed list (set intersection) - Fix stale now in JWKsFetcher slow path: recapture time.Now() after acquiring the mutex so stale-window checks use a current timestamp - Remove RespectHeaders no-op field from JWKsFetcher - Simplify RSA exponent decode: use big.Int.IsInt64() instead of platform-dependent int size check --- auth/ajwt/fetcher.go | 8 +-- auth/ajwt/jwt.go | 136 ++++++++++++++++++++++++++++-------------- auth/ajwt/jwt_test.go | 2 +- auth/ajwt/pub.go | 12 ++-- 4 files changed, 105 insertions(+), 53 deletions(-) diff --git a/auth/ajwt/fetcher.go b/auth/ajwt/fetcher.go index f8a00e9..e91ddfd 100644 --- a/auth/ajwt/fetcher.go +++ b/auth/ajwt/fetcher.go @@ -64,10 +64,6 @@ type JWKsFetcher struct { // (nil, err). KeepOnError bool - // RespectHeaders is reserved for future use (honor Cache-Control max-age - // from the JWKS response, capped at MaxAge). - RespectHeaders bool - mu sync.Mutex cached atomic.Pointer[cachedIssuer] } @@ -90,6 +86,10 @@ func (f *JWKsFetcher) Issuer(ctx context.Context) (*Issuer, error) { f.mu.Lock() defer f.mu.Unlock() + // Recapture time after acquiring lock — the fast-path timestamp may be stale + // if there was contention and another goroutine held the lock for a while. + now = time.Now() + // Re-check after acquiring lock — another goroutine may have refreshed. if ci := f.cached.Load(); ci != nil && now.Before(ci.expiresAt) { return ci.iss, nil diff --git a/auth/ajwt/jwt.go b/auth/ajwt/jwt.go index ad60a18..4c96ec1 100644 --- a/auth/ajwt/jwt.go +++ b/auth/ajwt/jwt.go @@ -93,6 +93,42 @@ type StandardHeader struct { Typ string `json:"typ"` } +// Audience represents the "aud" JWT claim (RFC 7519 §4.1.3). +// +// It unmarshals from both a single string ("foo") and an array of strings +// (["foo","bar"]). It marshals to a plain string for a single value and to +// an array for multiple values, per the RFC. +type Audience []string + +// Contains reports whether s appears in the audience list. +func (a Audience) Contains(s string) bool { + return slices.Contains([]string(a), s) +} + +// UnmarshalJSON decodes both the string and []string forms of the "aud" claim. +func (a *Audience) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err == nil { + *a = Audience{s} + return nil + } + var ss []string + if err := json.Unmarshal(data, &ss); err != nil { + return fmt.Errorf("'aud' must be a string or array of strings: %w", err) + } + *a = ss + return nil +} + +// MarshalJSON encodes the audience as a plain string when there is one value, +// or as a JSON array for multiple values, per RFC 7519 §4.1.3. +func (a Audience) MarshalJSON() ([]byte, error) { + if len(a) == 1 { + return json.Marshal(a[0]) + } + return json.Marshal([]string(a)) +} + // StandardClaims holds the registered JWT claim names defined in RFC 7519 // and extended by OpenID Connect Core. // @@ -108,7 +144,7 @@ type StandardHeader struct { type StandardClaims struct { Iss string `json:"iss"` Sub string `json:"sub"` - Aud string `json:"aud"` + Aud Audience `json:"aud,omitempty"` Exp int64 `json:"exp"` Iat int64 `json:"iat"` AuthTime int64 `json:"auth_time"` @@ -196,7 +232,10 @@ func NewJWSFromClaims(claims any, kid string) (*JWS, error) { Kid: kid, Typ: "JWT", } - headerJSON, _ := json.Marshal(jws.Header) + headerJSON, err := json.Marshal(jws.Header) + if err != nil { + return nil, fmt.Errorf("marshal header: %w", err) + } jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) claimsJSON, err := json.Marshal(claims) @@ -227,10 +266,16 @@ func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) { return nil, err } jws.Header.Alg = alg - headerJSON, _ := json.Marshal(jws.Header) + headerJSON, err := json.Marshal(jws.Header) + if err != nil { + return nil, fmt.Errorf("marshal header: %w", err) + } jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) - digest := digestFor(h, jws.Protected+"."+jws.Payload) + digest, err := digestFor(h, jws.Protected+"."+jws.Payload) + if err != nil { + return nil, err + } // crypto.Signer returns ASN.1 DER for ECDSA; convert to raw r||s for JWS. derSig, err := key.Sign(rand.Reader, digest, h) if err != nil { @@ -241,23 +286,30 @@ func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) { case *rsa.PublicKey: jws.Header.Alg = "RS256" - headerJSON, _ := json.Marshal(jws.Header) + headerJSON, err := json.Marshal(jws.Header) + if err != nil { + return nil, fmt.Errorf("marshal header: %w", err) + } jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) - digest := digestFor(crypto.SHA256, jws.Protected+"."+jws.Payload) + digest, err := digestFor(crypto.SHA256, jws.Protected+"."+jws.Payload) + if err != nil { + return nil, err + } // crypto.Signer returns raw PKCS#1 v1.5 bytes for RSA; use directly. - var err error jws.Signature, err = key.Sign(rand.Reader, digest, crypto.SHA256) return jws.Signature, err case ed25519.PublicKey: jws.Header.Alg = "EdDSA" - headerJSON, _ := json.Marshal(jws.Header) + headerJSON, err := json.Marshal(jws.Header) + if err != nil { + return nil, fmt.Errorf("marshal header: %w", err) + } jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) // Ed25519 signs the raw message with no pre-hashing; pass crypto.Hash(0). signingInput := jws.Protected + "." + jws.Payload - var err error jws.Signature, err = key.Sign(rand.Reader, []byte(signingInput), crypto.Hash(0)) return jws.Signature, err @@ -340,10 +392,12 @@ func (v *MultiValidator) Validate(claims StandardClaimsSource, now time.Time) ([ } if !v.IgnoreAud { - if sc.Aud == "" { + if len(sc.Aud) == 0 { errs = append(errs, "missing or malformed 'aud' (audience)") - } else if len(v.Aud) > 0 && !slices.Contains(v.Aud, sc.Aud) { - errs = append(errs, fmt.Sprintf("'aud' %q not in allowed list", sc.Aud)) + } else if len(v.Aud) > 0 && !slices.ContainsFunc([]string(sc.Aud), func(a string) bool { + return slices.Contains(v.Aud, a) + }) { + errs = append(errs, fmt.Sprintf("'aud' not in allowed list: %v", sc.Aud)) } } @@ -364,7 +418,7 @@ func (v *MultiValidator) Validate(claims StandardClaimsSource, now time.Time) ([ } } - if v.MaxAge > 0 || !v.IgnoreAuthTime { + if !v.IgnoreAuthTime { if sc.AuthTime == 0 { errs = append(errs, "missing or malformed 'auth_time'") } else if sc.AuthTime > now.Unix() { @@ -433,8 +487,8 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ( if len(v.Aud) > 0 || !v.IgnoreAud { if len(claims.Aud) == 0 { errs = append(errs, "missing or malformed 'aud' (audience receiving token)") - } else if len(v.Aud) > 0 && claims.Aud != v.Aud { - errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud)) + } else if len(v.Aud) > 0 && !claims.Aud.Contains(v.Aud) { + errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %v, expected %s", claims.Aud, v.Aud)) } } @@ -461,7 +515,7 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ( } // Should exist, in the past, with optional max age - if v.MaxAge > 0 || !v.IgnoreAuthTime { + if !v.IgnoreAuthTime { if claims.AuthTime == 0 { errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)") } else { @@ -484,22 +538,14 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ( } } - // Optional exact match - if v.Jti != claims.Jti { - if len(v.Jti) > 0 { - errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti)) - } else if !v.IgnoreJti { - errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti)) - } + // Optional exact match (only checked when v.Jti is set) + if len(v.Jti) > 0 && v.Jti != claims.Jti { + errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti)) } - // Optional exact match - if v.Nonce != claims.Nonce { - if len(v.Nonce) > 0 { - errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce)) - } else if !v.IgnoreNonce { - errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce)) - } + // Optional exact match (only checked when v.Nonce is set) + if len(v.Nonce) > 0 && v.Nonce != claims.Nonce { + errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce)) } // Should exist, optional required-set check @@ -515,13 +561,9 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) ( } } - // Optional, match if present - if v.Azp != claims.Azp { - if len(v.Azp) > 0 { - errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp)) - } else if !v.IgnoreAzp { - errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp)) - } + // Optional, match if present (only checked when v.Azp is set) + if len(v.Azp) > 0 && v.Azp != claims.Azp { + errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp)) } if len(errs) > 0 { @@ -668,7 +710,10 @@ func verifyWith(signingInput string, sig []byte, alg string, key crypto.PublicKe if len(sig) != 2*byteLen { return fmt.Errorf("invalid %s signature length: got %d, want %d", alg, len(sig), 2*byteLen) } - digest := digestFor(h, signingInput) + digest, err := digestFor(h, signingInput) + if err != nil { + return err + } r := new(big.Int).SetBytes(sig[:byteLen]) s := new(big.Int).SetBytes(sig[byteLen:]) if !ecdsa.Verify(k, digest, r, s) { @@ -681,7 +726,10 @@ func verifyWith(signingInput string, sig []byte, alg string, key crypto.PublicKe if !ok { return fmt.Errorf("alg RS256 requires *rsa.PublicKey, got %T", key) } - digest := digestFor(crypto.SHA256, signingInput) + digest, err := digestFor(crypto.SHA256, signingInput) + if err != nil { + return err + } if err := rsa.VerifyPKCS1v15(k, crypto.SHA256, digest, sig); err != nil { return fmt.Errorf("RS256 signature invalid: %w", err) } @@ -717,19 +765,19 @@ func algForECKey(pub *ecdsa.PublicKey) (alg string, h crypto.Hash, err error) { } } -func digestFor(h crypto.Hash, data string) []byte { +func digestFor(h crypto.Hash, data string) ([]byte, error) { switch h { case crypto.SHA256: d := sha256.Sum256([]byte(data)) - return d[:] + return d[:], nil case crypto.SHA384: d := sha512.Sum384([]byte(data)) - return d[:] + return d[:], nil case crypto.SHA512: d := sha512.Sum512([]byte(data)) - return d[:] + return d[:], nil default: - panic(fmt.Sprintf("ajwt: unsupported hash %v", h)) + return nil, fmt.Errorf("ajwt: unsupported hash %v", h) } } diff --git a/auth/ajwt/jwt_test.go b/auth/ajwt/jwt_test.go index e7b6c0b..51f8240 100644 --- a/auth/ajwt/jwt_test.go +++ b/auth/ajwt/jwt_test.go @@ -53,7 +53,7 @@ func goodClaims() AppClaims { StandardClaims: ajwt.StandardClaims{ Iss: "https://example.com", Sub: "user123", - Aud: "myapp", + Aud: ajwt.Audience{"myapp"}, Exp: now.Add(time.Hour).Unix(), Iat: now.Unix(), AuthTime: now.Unix(), diff --git a/auth/ajwt/pub.go b/auth/ajwt/pub.go index 59dd690..5bfa144 100644 --- a/auth/ajwt/pub.go +++ b/auth/ajwt/pub.go @@ -412,14 +412,18 @@ func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) { return nil, fmt.Errorf("invalid RSA exponent: %w", err) } - eInt := new(big.Int).SetBytes(e).Int64() - if eInt > int64(^uint(0)>>1) || eInt < 0 { - return nil, fmt.Errorf("RSA exponent too large or negative") + eInt := new(big.Int).SetBytes(e) + if !eInt.IsInt64() { + return nil, fmt.Errorf("RSA exponent too large") + } + eVal := eInt.Int64() + if eVal <= 0 { + return nil, fmt.Errorf("RSA exponent must be positive") } return &rsa.PublicKey{ N: new(big.Int).SetBytes(n), - E: int(eInt), + E: int(eVal), }, nil }