mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-20 02:57:58 +00:00
ajwt: fix Audience type, digestFor panic, validator bugs, fetcher staleness
- Add Audience type (RFC 7519 §4.1.3): unmarshals string or []string, marshals to string for single value and array for multiple - Fix digestFor panic: return ([]byte, error) instead of panicking on unsupported hash; plumb error through Sign and verifyWith callers - Fix headerJSON marshal error: propagate instead of discarding in NewJWSFromClaims and JWS.Sign (all three key-type branches) - Fix MaxAge/IgnoreAuthTime interaction: IgnoreAuthTime: true now correctly skips auth_time checks even when MaxAge > 0 - Fix "unchecked" warnings for Jti/Nonce/Azp: invert to opt-in — these fields are only validated when the Validator has them set - Fix MultiValidator.Aud for Audience type: checks if any token audience value is in the allowed list (set intersection) - Fix stale now in JWKsFetcher slow path: recapture time.Now() after acquiring the mutex so stale-window checks use a current timestamp - Remove RespectHeaders no-op field from JWKsFetcher - Simplify RSA exponent decode: use big.Int.IsInt64() instead of platform-dependent int size check
This commit is contained in:
parent
52ffecb5b3
commit
ac25aa2ee5
@ -64,10 +64,6 @@ type JWKsFetcher struct {
|
||||
// (nil, err).
|
||||
KeepOnError bool
|
||||
|
||||
// RespectHeaders is reserved for future use (honor Cache-Control max-age
|
||||
// from the JWKS response, capped at MaxAge).
|
||||
RespectHeaders bool
|
||||
|
||||
mu sync.Mutex
|
||||
cached atomic.Pointer[cachedIssuer]
|
||||
}
|
||||
@ -90,6 +86,10 @@ func (f *JWKsFetcher) Issuer(ctx context.Context) (*Issuer, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
// Recapture time after acquiring lock — the fast-path timestamp may be stale
|
||||
// if there was contention and another goroutine held the lock for a while.
|
||||
now = time.Now()
|
||||
|
||||
// Re-check after acquiring lock — another goroutine may have refreshed.
|
||||
if ci := f.cached.Load(); ci != nil && now.Before(ci.expiresAt) {
|
||||
return ci.iss, nil
|
||||
|
||||
136
auth/ajwt/jwt.go
136
auth/ajwt/jwt.go
@ -93,6 +93,42 @@ type StandardHeader struct {
|
||||
Typ string `json:"typ"`
|
||||
}
|
||||
|
||||
// Audience represents the "aud" JWT claim (RFC 7519 §4.1.3).
|
||||
//
|
||||
// It unmarshals from both a single string ("foo") and an array of strings
|
||||
// (["foo","bar"]). It marshals to a plain string for a single value and to
|
||||
// an array for multiple values, per the RFC.
|
||||
type Audience []string
|
||||
|
||||
// Contains reports whether s appears in the audience list.
|
||||
func (a Audience) Contains(s string) bool {
|
||||
return slices.Contains([]string(a), s)
|
||||
}
|
||||
|
||||
// UnmarshalJSON decodes both the string and []string forms of the "aud" claim.
|
||||
func (a *Audience) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err == nil {
|
||||
*a = Audience{s}
|
||||
return nil
|
||||
}
|
||||
var ss []string
|
||||
if err := json.Unmarshal(data, &ss); err != nil {
|
||||
return fmt.Errorf("'aud' must be a string or array of strings: %w", err)
|
||||
}
|
||||
*a = ss
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON encodes the audience as a plain string when there is one value,
|
||||
// or as a JSON array for multiple values, per RFC 7519 §4.1.3.
|
||||
func (a Audience) MarshalJSON() ([]byte, error) {
|
||||
if len(a) == 1 {
|
||||
return json.Marshal(a[0])
|
||||
}
|
||||
return json.Marshal([]string(a))
|
||||
}
|
||||
|
||||
// StandardClaims holds the registered JWT claim names defined in RFC 7519
|
||||
// and extended by OpenID Connect Core.
|
||||
//
|
||||
@ -108,7 +144,7 @@ type StandardHeader struct {
|
||||
type StandardClaims struct {
|
||||
Iss string `json:"iss"`
|
||||
Sub string `json:"sub"`
|
||||
Aud string `json:"aud"`
|
||||
Aud Audience `json:"aud,omitempty"`
|
||||
Exp int64 `json:"exp"`
|
||||
Iat int64 `json:"iat"`
|
||||
AuthTime int64 `json:"auth_time"`
|
||||
@ -196,7 +232,10 @@ func NewJWSFromClaims(claims any, kid string) (*JWS, error) {
|
||||
Kid: kid,
|
||||
Typ: "JWT",
|
||||
}
|
||||
headerJSON, _ := json.Marshal(jws.Header)
|
||||
headerJSON, err := json.Marshal(jws.Header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal header: %w", err)
|
||||
}
|
||||
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
@ -227,10 +266,16 @@ func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
jws.Header.Alg = alg
|
||||
headerJSON, _ := json.Marshal(jws.Header)
|
||||
headerJSON, err := json.Marshal(jws.Header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal header: %w", err)
|
||||
}
|
||||
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
|
||||
digest := digestFor(h, jws.Protected+"."+jws.Payload)
|
||||
digest, err := digestFor(h, jws.Protected+"."+jws.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// crypto.Signer returns ASN.1 DER for ECDSA; convert to raw r||s for JWS.
|
||||
derSig, err := key.Sign(rand.Reader, digest, h)
|
||||
if err != nil {
|
||||
@ -241,23 +286,30 @@ func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) {
|
||||
|
||||
case *rsa.PublicKey:
|
||||
jws.Header.Alg = "RS256"
|
||||
headerJSON, _ := json.Marshal(jws.Header)
|
||||
headerJSON, err := json.Marshal(jws.Header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal header: %w", err)
|
||||
}
|
||||
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
|
||||
digest := digestFor(crypto.SHA256, jws.Protected+"."+jws.Payload)
|
||||
digest, err := digestFor(crypto.SHA256, jws.Protected+"."+jws.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// crypto.Signer returns raw PKCS#1 v1.5 bytes for RSA; use directly.
|
||||
var err error
|
||||
jws.Signature, err = key.Sign(rand.Reader, digest, crypto.SHA256)
|
||||
return jws.Signature, err
|
||||
|
||||
case ed25519.PublicKey:
|
||||
jws.Header.Alg = "EdDSA"
|
||||
headerJSON, _ := json.Marshal(jws.Header)
|
||||
headerJSON, err := json.Marshal(jws.Header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal header: %w", err)
|
||||
}
|
||||
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
|
||||
// Ed25519 signs the raw message with no pre-hashing; pass crypto.Hash(0).
|
||||
signingInput := jws.Protected + "." + jws.Payload
|
||||
var err error
|
||||
jws.Signature, err = key.Sign(rand.Reader, []byte(signingInput), crypto.Hash(0))
|
||||
return jws.Signature, err
|
||||
|
||||
@ -340,10 +392,12 @@ func (v *MultiValidator) Validate(claims StandardClaimsSource, now time.Time) ([
|
||||
}
|
||||
|
||||
if !v.IgnoreAud {
|
||||
if sc.Aud == "" {
|
||||
if len(sc.Aud) == 0 {
|
||||
errs = append(errs, "missing or malformed 'aud' (audience)")
|
||||
} else if len(v.Aud) > 0 && !slices.Contains(v.Aud, sc.Aud) {
|
||||
errs = append(errs, fmt.Sprintf("'aud' %q not in allowed list", sc.Aud))
|
||||
} else if len(v.Aud) > 0 && !slices.ContainsFunc([]string(sc.Aud), func(a string) bool {
|
||||
return slices.Contains(v.Aud, a)
|
||||
}) {
|
||||
errs = append(errs, fmt.Sprintf("'aud' not in allowed list: %v", sc.Aud))
|
||||
}
|
||||
}
|
||||
|
||||
@ -364,7 +418,7 @@ func (v *MultiValidator) Validate(claims StandardClaimsSource, now time.Time) ([
|
||||
}
|
||||
}
|
||||
|
||||
if v.MaxAge > 0 || !v.IgnoreAuthTime {
|
||||
if !v.IgnoreAuthTime {
|
||||
if sc.AuthTime == 0 {
|
||||
errs = append(errs, "missing or malformed 'auth_time'")
|
||||
} else if sc.AuthTime > now.Unix() {
|
||||
@ -433,8 +487,8 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
|
||||
if len(v.Aud) > 0 || !v.IgnoreAud {
|
||||
if len(claims.Aud) == 0 {
|
||||
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
|
||||
} else if len(v.Aud) > 0 && claims.Aud != v.Aud {
|
||||
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, v.Aud))
|
||||
} else if len(v.Aud) > 0 && !claims.Aud.Contains(v.Aud) {
|
||||
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %v, expected %s", claims.Aud, v.Aud))
|
||||
}
|
||||
}
|
||||
|
||||
@ -461,7 +515,7 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
|
||||
}
|
||||
|
||||
// Should exist, in the past, with optional max age
|
||||
if v.MaxAge > 0 || !v.IgnoreAuthTime {
|
||||
if !v.IgnoreAuthTime {
|
||||
if claims.AuthTime == 0 {
|
||||
errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
|
||||
} else {
|
||||
@ -484,22 +538,14 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
|
||||
}
|
||||
}
|
||||
|
||||
// Optional exact match
|
||||
if v.Jti != claims.Jti {
|
||||
if len(v.Jti) > 0 {
|
||||
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti))
|
||||
} else if !v.IgnoreJti {
|
||||
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
|
||||
}
|
||||
// Optional exact match (only checked when v.Jti is set)
|
||||
if len(v.Jti) > 0 && v.Jti != claims.Jti {
|
||||
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, v.Jti))
|
||||
}
|
||||
|
||||
// Optional exact match
|
||||
if v.Nonce != claims.Nonce {
|
||||
if len(v.Nonce) > 0 {
|
||||
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce))
|
||||
} else if !v.IgnoreNonce {
|
||||
errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce))
|
||||
}
|
||||
// Optional exact match (only checked when v.Nonce is set)
|
||||
if len(v.Nonce) > 0 && v.Nonce != claims.Nonce {
|
||||
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, v.Nonce))
|
||||
}
|
||||
|
||||
// Should exist, optional required-set check
|
||||
@ -515,13 +561,9 @@ func ValidateStandardClaims(claims StandardClaims, v Validator, now time.Time) (
|
||||
}
|
||||
}
|
||||
|
||||
// Optional, match if present
|
||||
if v.Azp != claims.Azp {
|
||||
if len(v.Azp) > 0 {
|
||||
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp))
|
||||
} else if !v.IgnoreAzp {
|
||||
errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp))
|
||||
}
|
||||
// Optional, match if present (only checked when v.Azp is set)
|
||||
if len(v.Azp) > 0 && v.Azp != claims.Azp {
|
||||
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, v.Azp))
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
@ -668,7 +710,10 @@ func verifyWith(signingInput string, sig []byte, alg string, key crypto.PublicKe
|
||||
if len(sig) != 2*byteLen {
|
||||
return fmt.Errorf("invalid %s signature length: got %d, want %d", alg, len(sig), 2*byteLen)
|
||||
}
|
||||
digest := digestFor(h, signingInput)
|
||||
digest, err := digestFor(h, signingInput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r := new(big.Int).SetBytes(sig[:byteLen])
|
||||
s := new(big.Int).SetBytes(sig[byteLen:])
|
||||
if !ecdsa.Verify(k, digest, r, s) {
|
||||
@ -681,7 +726,10 @@ func verifyWith(signingInput string, sig []byte, alg string, key crypto.PublicKe
|
||||
if !ok {
|
||||
return fmt.Errorf("alg RS256 requires *rsa.PublicKey, got %T", key)
|
||||
}
|
||||
digest := digestFor(crypto.SHA256, signingInput)
|
||||
digest, err := digestFor(crypto.SHA256, signingInput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rsa.VerifyPKCS1v15(k, crypto.SHA256, digest, sig); err != nil {
|
||||
return fmt.Errorf("RS256 signature invalid: %w", err)
|
||||
}
|
||||
@ -717,19 +765,19 @@ func algForECKey(pub *ecdsa.PublicKey) (alg string, h crypto.Hash, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func digestFor(h crypto.Hash, data string) []byte {
|
||||
func digestFor(h crypto.Hash, data string) ([]byte, error) {
|
||||
switch h {
|
||||
case crypto.SHA256:
|
||||
d := sha256.Sum256([]byte(data))
|
||||
return d[:]
|
||||
return d[:], nil
|
||||
case crypto.SHA384:
|
||||
d := sha512.Sum384([]byte(data))
|
||||
return d[:]
|
||||
return d[:], nil
|
||||
case crypto.SHA512:
|
||||
d := sha512.Sum512([]byte(data))
|
||||
return d[:]
|
||||
return d[:], nil
|
||||
default:
|
||||
panic(fmt.Sprintf("ajwt: unsupported hash %v", h))
|
||||
return nil, fmt.Errorf("ajwt: unsupported hash %v", h)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ func goodClaims() AppClaims {
|
||||
StandardClaims: ajwt.StandardClaims{
|
||||
Iss: "https://example.com",
|
||||
Sub: "user123",
|
||||
Aud: "myapp",
|
||||
Aud: ajwt.Audience{"myapp"},
|
||||
Exp: now.Add(time.Hour).Unix(),
|
||||
Iat: now.Unix(),
|
||||
AuthTime: now.Unix(),
|
||||
|
||||
@ -412,14 +412,18 @@ func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) {
|
||||
return nil, fmt.Errorf("invalid RSA exponent: %w", err)
|
||||
}
|
||||
|
||||
eInt := new(big.Int).SetBytes(e).Int64()
|
||||
if eInt > int64(^uint(0)>>1) || eInt < 0 {
|
||||
return nil, fmt.Errorf("RSA exponent too large or negative")
|
||||
eInt := new(big.Int).SetBytes(e)
|
||||
if !eInt.IsInt64() {
|
||||
return nil, fmt.Errorf("RSA exponent too large")
|
||||
}
|
||||
eVal := eInt.Int64()
|
||||
if eVal <= 0 {
|
||||
return nil, fmt.Errorf("RSA exponent must be positive")
|
||||
}
|
||||
|
||||
return &rsa.PublicKey{
|
||||
N: new(big.Int).SetBytes(n),
|
||||
E: int(eInt),
|
||||
E: int(eVal),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user