diff --git a/auth/bestjwt/go.mod b/auth/bestjwt/go.mod new file mode 100644 index 0000000..f585ee0 --- /dev/null +++ b/auth/bestjwt/go.mod @@ -0,0 +1,3 @@ +module github.com/therootcompany/golib/auth/bestjwt + +go 1.24.0 diff --git a/auth/bestjwt/jwt.go b/auth/bestjwt/jwt.go new file mode 100644 index 0000000..739e6c3 --- /dev/null +++ b/auth/bestjwt/jwt.go @@ -0,0 +1,706 @@ +// Copyright 2025 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +// Package bestjwt is a lightweight JWT/JWS/JWK library that combines the +// best ergonomics from the genericjwt and embeddedjwt packages: +// +// - Claims via embedded structs: Decode(token, &myClaims) — no generics at +// call sites, no type assertions to access custom fields. +// - crypto.Signer for all signing: works with in-process keys AND +// hardware-backed keys (HSM, cloud KMS, PKCS#11) without modification. +// - Full ECDSA curve support: ES256 (P-256), ES384 (P-384), ES512 (P-521). +// The algorithm is inferred from the key's curve, not hardcoded. +// - Curve/algorithm consistency enforcement: UnsafeVerify rejects a P-256 +// key presented for an ES384 token and vice versa. +// - Generic PublicJWK[K Key]: type-safe JWKS key management with TypedKeys +// to filter a mixed []PublicJWK[Key] to a concrete key type. +package bestjwt + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/sha512" + "encoding/asn1" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "slices" + "strings" + "time" +) + +// Claims is the interface that custom claims types must satisfy. +// +// Because [StandardClaims] implements Claims with a value receiver, any struct +// that embeds StandardClaims satisfies Claims automatically via method +// promotion — no boilerplate required: +// +// type AppClaims struct { +// bestjwt.StandardClaims +// Email string `json:"email"` +// } +// // AppClaims now satisfies Claims for free. +// +// Override Validate on the outer struct to add application-specific checks. +// Call [ValidateStandardClaims] inside your override to preserve all standard +// OIDC/JWT validation. +type Claims interface { + Validate(params ValidateParams) ([]string, error) +} + +// JWS is a decoded JSON Web Signature / JWT. +// +// Claims is stored as the [Claims] interface so that any embedded-struct type +// can be used without generics. The most ergonomic access pattern is via the +// pointer you passed to [Decode]: +// +// var claims AppClaims +// jws, err := bestjwt.Decode(tokenString, &claims) +// jws.UnsafeVerify(pubKey) +// errs, err := jws.Validate(params) +// // Access claims.Email, claims.Roles, etc. directly — no type assertion. +type JWS struct { + Protected string `json:"-"` // base64url-encoded header + Header StandardHeader `json:"header"` + Payload string `json:"-"` // base64url-encoded claims + Claims Claims `json:"claims"` + Signature URLBase64 `json:"signature"` + Verified bool `json:"-"` +} + +// StandardHeader holds the standard JOSE header fields. +type StandardHeader struct { + Alg string `json:"alg"` + Kid string `json:"kid"` + Typ string `json:"typ"` +} + +// StandardClaims holds the registered JWT claim names defined in RFC 7519 +// and extended by OpenID Connect Core. +// +// Embed StandardClaims in your own struct to satisfy [Claims] automatically: +// +// type AppClaims struct { +// bestjwt.StandardClaims +// Email string `json:"email"` +// } +// // AppClaims now satisfies Claims via promoted Validate. +// // Override Validate on AppClaims to add custom checks. +type StandardClaims struct { + Iss string `json:"iss"` + Sub string `json:"sub"` + Aud string `json:"aud"` + Exp int64 `json:"exp"` + Iat int64 `json:"iat"` + AuthTime int64 `json:"auth_time"` + Nonce string `json:"nonce,omitempty"` + Amr []string `json:"amr"` + Azp string `json:"azp,omitempty"` + Jti string `json:"jti"` +} + +// Validate implements [Claims] by checking all standard OIDC/JWT claim fields. +// +// This method is promoted to any struct that embeds [StandardClaims], so +// embedding structs satisfy Claims without writing any additional code. +// params.Now must be non-zero; [JWS.Validate] ensures this before delegating. +func (c StandardClaims) Validate(params ValidateParams) ([]string, error) { + return ValidateStandardClaims(c, params) +} + +// Decode parses a compact JWT string (header.payload.signature) into a JWS. +// +// claims must be a pointer to the caller's pre-allocated claims struct +// (e.g. &AppClaims{}). The JSON payload is unmarshaled directly into it, +// and the same pointer is stored in jws.Claims. Callers can access custom +// fields through their own variable without a type assertion: +// +// var claims AppClaims +// jws, err := bestjwt.Decode(token, &claims) +// // claims.Email is already set; no type assertion needed. +// +// The signature is not verified by Decode. Call [JWS.UnsafeVerify] first, +// then [JWS.Validate]. +func Decode(tokenStr string, claims Claims) (*JWS, error) { + parts := strings.Split(tokenStr, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid JWT format") + } + + var jws JWS + var sigEnc string + jws.Protected, jws.Payload, sigEnc = parts[0], parts[1], parts[2] + + header, err := base64.RawURLEncoding.DecodeString(jws.Protected) + if err != nil { + return nil, fmt.Errorf("invalid header encoding: %v", err) + } + if err := json.Unmarshal(header, &jws.Header); err != nil { + return nil, fmt.Errorf("invalid header JSON: %v", err) + } + + payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) + if err != nil { + return nil, fmt.Errorf("invalid claims encoding: %v", err) + } + // Unmarshal into the concrete type behind the Claims interface. + // json.Unmarshal reaches the concrete pointer via reflection. + if err := json.Unmarshal(payload, claims); err != nil { + return nil, fmt.Errorf("invalid claims JSON: %v", err) + } + + if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil { + return nil, fmt.Errorf("invalid signature encoding: %v", err) + } + + jws.Claims = claims + return &jws, nil +} + +// NewJWSFromClaims creates an unsigned JWS from the provided claims. +// +// kid identifies the signing key. The "alg" header field is set automatically +// when [JWS.Sign] is called, since only the key type determines the algorithm. +// Call [JWS.Encode] to produce a compact JWT string after signing. +func NewJWSFromClaims(claims Claims, kid string) (*JWS, error) { + var jws JWS + + jws.Header = StandardHeader{ + // Alg is intentionally omitted here; Sign sets it from the key type. + Kid: kid, + Typ: "JWT", + } + headerJSON, _ := json.Marshal(jws.Header) + jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) + + claimsJSON, _ := json.Marshal(claims) + jws.Payload = base64.RawURLEncoding.EncodeToString(claimsJSON) + jws.Claims = claims + + return &jws, nil +} + +// Sign signs the JWS in-place using the provided [crypto.Signer]. +// +// The algorithm is determined from the signer's public key type and, for EC +// keys, from the curve. The "alg" header field is set and the protected header +// is re-encoded before computing the signing input, so the signed bytes are +// always consistent with the token header. +// +// Supported algorithms (inferred automatically): +// - *ecdsa.PublicKey P-256 → ES256 (ECDSA + SHA-256, raw r||s) +// - *ecdsa.PublicKey P-384 → ES384 (ECDSA + SHA-384, raw r||s) +// - *ecdsa.PublicKey P-521 → ES512 (ECDSA + SHA-512, raw r||s) +// - *rsa.PublicKey → RS256 (PKCS#1 v1.5 + SHA-256) +// +// Because the parameter is [crypto.Signer] rather than a concrete key type, +// hardware-backed signers (HSM, OS keychain, cloud KMS) work transparently. +func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) { + switch pub := key.Public().(type) { + case *ecdsa.PublicKey: + alg, h, err := algForECKey(pub) + if err != nil { + return nil, err + } + jws.Header.Alg = alg + headerJSON, _ := json.Marshal(jws.Header) + jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) + + digest := digestFor(h, jws.Protected+"."+jws.Payload) + // crypto.Signer returns ASN.1 DER for ECDSA; convert to raw r||s for JWS. + derSig, err := key.Sign(rand.Reader, digest, h) + if err != nil { + return nil, fmt.Errorf("Sign %s: %w", alg, err) + } + jws.Signature, err = ecdsaDERToRaw(derSig, pub.Curve) + return jws.Signature, err + + case *rsa.PublicKey: + jws.Header.Alg = "RS256" + headerJSON, _ := json.Marshal(jws.Header) + jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) + + digest := digestFor(crypto.SHA256, jws.Protected+"."+jws.Payload) + // crypto.Signer returns raw PKCS#1 v1.5 bytes for RSA; use directly. + var err error + jws.Signature, err = key.Sign(rand.Reader, digest, crypto.SHA256) + return jws.Signature, err + + default: + return nil, fmt.Errorf( + "Sign: unsupported public key type %T (supported: *ecdsa.PublicKey, *rsa.PublicKey)", + key.Public(), + ) + } +} + +// Encode produces the compact JWT string (header.payload.signature). +func (jws JWS) Encode() string { + sigEnc := base64.RawURLEncoding.EncodeToString(jws.Signature) + return jws.Protected + "." + jws.Payload + "." + sigEnc +} + +// UnsafeVerify checks the signature using the algorithm in the JWT header and +// sets jws.Verified on success. It only checks the signature — use +// [JWS.Validate] to check claim values (expiry, issuer, audience, etc.). +// +// pub must be of the concrete type matching the header alg (e.g. +// *ecdsa.PublicKey for ES256/ES384/ES512). Callers can pass PublicJWK.Key +// directly without a type assertion. +// +// For ECDSA tokens, the key's curve is checked against the claimed algorithm +// (e.g. P-384 key is rejected for an ES256 token) to prevent cross-algorithm +// downgrade attacks. +// +// Currently supported: ES256, ES384, ES512, RS256. +func (jws *JWS) UnsafeVerify(pub Key) bool { + signingInput := jws.Protected + "." + jws.Payload + + switch jws.Header.Alg { + case "ES256", "ES384", "ES512": + k, ok := pub.(*ecdsa.PublicKey) + if !ok || k == nil { + jws.Verified = false + return false + } + // Require the key's curve to match the token's claimed algorithm. + // A P-256 key must not verify an ES384 token and vice versa. + expectedAlg, h, err := algForECKey(k) + if err != nil || expectedAlg != jws.Header.Alg { + jws.Verified = false + return false + } + byteLen := (k.Curve.Params().BitSize + 7) / 8 + if len(jws.Signature) != 2*byteLen { + jws.Verified = false + return false + } + digest := digestFor(h, signingInput) + r := new(big.Int).SetBytes(jws.Signature[:byteLen]) + s := new(big.Int).SetBytes(jws.Signature[byteLen:]) + jws.Verified = ecdsa.Verify(k, digest, r, s) + + case "RS256": + k, ok := pub.(*rsa.PublicKey) + if !ok || k == nil { + jws.Verified = false + return false + } + digest := digestFor(crypto.SHA256, signingInput) + jws.Verified = rsa.VerifyPKCS1v15(k, crypto.SHA256, digest, jws.Signature) == nil + + default: + jws.Verified = false + } + return jws.Verified +} + +// Validate sets params.Now if zero, then delegates to jws.Claims.Validate and +// additionally enforces that the signature was verified (unless params.IgnoreSig). +// +// Returns a list of human-readable errors and a non-nil sentinel if any exist. +func (jws *JWS) Validate(params ValidateParams) ([]string, error) { + if params.Now.IsZero() { + params.Now = time.Now() + } + + errs, _ := jws.Claims.Validate(params) + + if !params.IgnoreSig && !jws.Verified { + errs = append(errs, "signature was not checked") + } + + if len(errs) > 0 { + timeInfo := fmt.Sprintf("info: server time is %s", params.Now.Format("2006-01-02 15:04:05 MST")) + if loc, err := time.LoadLocation("Local"); err == nil { + timeInfo += fmt.Sprintf(" %s", loc) + } + errs = append(errs, timeInfo) + return errs, fmt.Errorf("has errors") + } + return nil, nil +} + +// ValidateParams holds validation configuration. +// https://openid.net/specs/openid-connect-core-1_0.html#IDToken +type ValidateParams struct { + Now time.Time + IgnoreIss bool + Iss string + IgnoreSub bool + Sub string + IgnoreAud bool + Aud string + IgnoreExp bool + IgnoreJti bool + Jti string + IgnoreIat bool + IgnoreAuthTime bool + MaxAge time.Duration + IgnoreNonce bool + Nonce string + IgnoreAmr bool + RequiredAmrs []string + IgnoreAzp bool + Azp string + IgnoreSig bool +} + +// ValidateStandardClaims checks the registered JWT/OIDC claim fields. +// +// Called by [StandardClaims.Validate] and exported so that custom claims types +// can call it from an overriding Validate method: +// +// func (c AppClaims) Validate(params bestjwt.ValidateParams) ([]string, error) { +// errs, _ := bestjwt.ValidateStandardClaims(c.StandardClaims, params) +// if c.Email == "" { +// errs = append(errs, "missing email claim") +// } +// if len(errs) > 0 { +// return errs, fmt.Errorf("has errors") +// } +// return nil, nil +// } +// +// params.Now must be non-zero; [JWS.Validate] ensures this before delegating. +func ValidateStandardClaims(claims StandardClaims, params ValidateParams) ([]string, error) { + var errs []string + + // Required to exist and match + if len(params.Iss) > 0 || !params.IgnoreIss { + if len(claims.Iss) == 0 { + errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)") + } else if claims.Iss != params.Iss { + errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss)) + } + } + + // Required to exist, optional match + if len(claims.Sub) == 0 { + if !params.IgnoreSub { + errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)") + } + } else if len(params.Sub) > 0 { + if params.Sub != claims.Sub { + errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, params.Sub)) + } + } + + // Required to exist and match + if len(params.Aud) > 0 || !params.IgnoreAud { + if len(claims.Aud) == 0 { + errs = append(errs, "missing or malformed 'aud' (audience receiving token)") + } else if claims.Aud != params.Aud { + errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, params.Aud)) + } + } + + // Required to exist and not be in the past + if !params.IgnoreExp { + if claims.Exp <= 0 { + errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)") + } else if claims.Exp < params.Now.Unix() { + duration := time.Since(time.Unix(claims.Exp, 0)) + expTime := time.Unix(claims.Exp, 0).Format("2006-01-02 15:04:05 MST") + errs = append(errs, fmt.Sprintf("token expired %s ago (%s)", formatDuration(duration), expTime)) + } + } + + // Required to exist and not be in the future + if !params.IgnoreIat { + if claims.Iat <= 0 { + errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)") + } else if claims.Iat > params.Now.Unix() { + duration := time.Unix(claims.Iat, 0).Sub(params.Now) + iatTime := time.Unix(claims.Iat, 0).Format("2006-01-02 15:04:05 MST") + errs = append(errs, fmt.Sprintf("'iat' (issued at) is %s in the future (%s)", formatDuration(duration), iatTime)) + } + } + + // Should exist, in the past, with optional max age + if params.MaxAge > 0 || !params.IgnoreAuthTime { + if claims.AuthTime == 0 { + errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)") + } else { + authTime := time.Unix(claims.AuthTime, 0) + authTimeStr := authTime.Format("2006-01-02 15:04:05 MST") + age := params.Now.Sub(authTime) + diff := age - params.MaxAge + if claims.AuthTime > params.Now.Unix() { + fromNow := time.Unix(claims.AuthTime, 0).Sub(params.Now) + errs = append(errs, fmt.Sprintf( + "'auth_time' of %s is %s in the future (server time %s)", + authTimeStr, formatDuration(fromNow), params.Now.Format("2006-01-02 15:04:05 MST")), + ) + } else if params.MaxAge > 0 && age > params.MaxAge { + errs = append(errs, fmt.Sprintf( + "'auth_time' of %s is %s old, exceeding max age %s by %s", + authTimeStr, formatDuration(age), formatDuration(params.MaxAge), formatDuration(diff)), + ) + } + } + } + + // Optional exact match + if params.Jti != claims.Jti { + if len(params.Jti) > 0 { + errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti)) + } else if !params.IgnoreJti { + errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti)) + } + } + + // Optional exact match + if params.Nonce != claims.Nonce { + if len(params.Nonce) > 0 { + errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, params.Nonce)) + } else if !params.IgnoreNonce { + errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce)) + } + } + + // Should exist, optional required-set check + if !params.IgnoreAmr { + if len(claims.Amr) == 0 { + errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)") + } else if len(params.RequiredAmrs) > 0 { + for _, required := range params.RequiredAmrs { + if !slices.Contains(claims.Amr, required) { + errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required)) + } + } + } + } + + // Optional, match if present + if params.Azp != claims.Azp { + if len(params.Azp) > 0 { + errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, params.Azp)) + } else if !params.IgnoreAzp { + errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp)) + } + } + + if len(errs) > 0 { + return errs, fmt.Errorf("has errors") + } + return nil, nil +} + +// --- Signing helpers --- + +// SignES256 computes an ES256 signature over header.payload using a P-256 key. +// The signature is a fixed-width 64-byte raw r||s value (not ASN.1 DER). +// Each component is zero-padded to 32 bytes via [big.Int.FillBytes]. +func SignES256(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) { + return signEC(header, payload, key, crypto.SHA256) +} + +// SignES384 computes an ES384 signature over header.payload using a P-384 key. +// The signature is a fixed-width 96-byte raw r||s value. +func SignES384(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) { + return signEC(header, payload, key, crypto.SHA384) +} + +// SignES512 computes an ES512 signature over header.payload using a P-521 key. +// The signature is a fixed-width 132-byte raw r||s value. +func SignES512(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) { + return signEC(header, payload, key, crypto.SHA512) +} + +// SignRS256 computes an RS256 (PKCS#1 v1.5 + SHA-256) signature over header.payload. +func SignRS256(header, payload string, key *rsa.PrivateKey) ([]byte, error) { + digest := digestFor(crypto.SHA256, header+"."+payload) + sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, digest) + if err != nil { + return nil, fmt.Errorf("SignRS256: %w", err) + } + return sig, nil +} + +// EncodeToJWT appends a base64url-encoded signature to a signing input string. +func EncodeToJWT(signingInput string, signature []byte) string { + return signingInput + "." + base64.RawURLEncoding.EncodeToString(signature) +} + +// --- EC private key JWK utilities --- + +// JWK represents a private key in JSON Web Key format (EC only). +type JWK struct { + Kty string `json:"kty"` + Crv string `json:"crv"` + D string `json:"d"` + X string `json:"x"` + Y string `json:"y"` +} + +// UnmarshalJWK parses an EC private key from a JWK struct. +func UnmarshalJWK(jwk JWK) (*ecdsa.PrivateKey, error) { + x, err := base64.RawURLEncoding.DecodeString(jwk.X) + if err != nil { + return nil, fmt.Errorf("invalid JWK X: %v", err) + } + y, err := base64.RawURLEncoding.DecodeString(jwk.Y) + if err != nil { + return nil, fmt.Errorf("invalid JWK Y: %v", err) + } + d, err := base64.RawURLEncoding.DecodeString(jwk.D) + if err != nil { + return nil, fmt.Errorf("invalid JWK D: %v", err) + } + + return &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: new(big.Int).SetBytes(x), + Y: new(big.Int).SetBytes(y), + }, + D: new(big.Int).SetBytes(d), + }, nil +} + +// Thumbprint computes the RFC 7638 JWK Thumbprint for an EC public key. +func (jwk JWK) Thumbprint() (string, error) { + data := map[string]string{ + "crv": jwk.Crv, + "kty": jwk.Kty, + "x": jwk.X, + "y": jwk.Y, + } + jsonData, err := json.Marshal(data) + if err != nil { + return "", err + } + hash := sha256.Sum256(jsonData) + return base64.RawURLEncoding.EncodeToString(hash[:]), nil +} + +// --- Internal helpers --- + +// algForECKey returns the JWA algorithm name and hash function for an EC public +// key, inferred from its curve. Returns an error for unsupported curves. +func algForECKey(pub *ecdsa.PublicKey) (alg string, h crypto.Hash, err error) { + switch pub.Curve { + case elliptic.P256(): + return "ES256", crypto.SHA256, nil + case elliptic.P384(): + return "ES384", crypto.SHA384, nil + case elliptic.P521(): + return "ES512", crypto.SHA512, nil + default: + return "", 0, fmt.Errorf("unsupported EC curve: %s", pub.Curve.Params().Name) + } +} + +// digestFor hashes data with h and returns the digest. +// Uses fixed-size stack arrays for the three supported hashes to avoid +// unnecessary heap allocation on the hot signing/verification path. +func digestFor(h crypto.Hash, data string) []byte { + switch h { + case crypto.SHA256: + d := sha256.Sum256([]byte(data)) + return d[:] + case crypto.SHA384: + d := sha512.Sum384([]byte(data)) + return d[:] + case crypto.SHA512: + d := sha512.Sum512([]byte(data)) + return d[:] + default: + panic(fmt.Sprintf("bestjwt: unsupported hash %v", h)) + } +} + +// signEC is the shared implementation for SignES256/384/512. +func signEC(header, payload string, key *ecdsa.PrivateKey, h crypto.Hash) ([]byte, error) { + digest := digestFor(h, header+"."+payload) + r, s, err := ecdsa.Sign(rand.Reader, key, digest) + if err != nil { + return nil, fmt.Errorf("signEC: %w", err) + } + byteLen := (key.Curve.Params().BitSize + 7) / 8 + out := make([]byte, 2*byteLen) + r.FillBytes(out[:byteLen]) + s.FillBytes(out[byteLen:]) + return out, nil +} + +// ecdsaDERToRaw converts an ASN.1 DER-encoded ECDSA signature (as returned by +// [crypto.Signer]) to the fixed-width raw r||s format required by JWS. +// r and s are zero-padded to the curve's byte length via [big.Int.FillBytes]. +func ecdsaDERToRaw(der []byte, curve elliptic.Curve) ([]byte, error) { + var sig struct{ R, S *big.Int } + if _, err := asn1.Unmarshal(der, &sig); err != nil { + return nil, fmt.Errorf("ecdsaDERToRaw: %w", err) + } + byteLen := (curve.Params().BitSize + 7) / 8 + out := make([]byte, 2*byteLen) + sig.R.FillBytes(out[:byteLen]) + sig.S.FillBytes(out[byteLen:]) + return out, nil +} + +// URLBase64 is a []byte that marshals to/from raw base64url in JSON. +type URLBase64 []byte + +func (s URLBase64) String() string { + return base64.RawURLEncoding.EncodeToString(s) +} + +func (s URLBase64) MarshalJSON() ([]byte, error) { + encoded := base64.RawURLEncoding.EncodeToString(s) + return json.Marshal(encoded) +} + +func (s *URLBase64) UnmarshalJSON(data []byte) error { + dst, err := base64.RawURLEncoding.AppendDecode([]byte{}, data) + if err != nil { + return fmt.Errorf("decode base64url signature: %w", err) + } + *s = dst + return nil +} + +func formatDuration(d time.Duration) string { + if d < 0 { + d = -d + } + days := int(d / (24 * time.Hour)) + d -= time.Duration(days) * 24 * time.Hour + hours := int(d / time.Hour) + d -= time.Duration(hours) * time.Hour + minutes := int(d / time.Minute) + d -= time.Duration(minutes) * time.Minute + seconds := int(d / time.Second) + + var parts []string + if days > 0 { + parts = append(parts, fmt.Sprintf("%dd", days)) + } + if hours > 0 { + parts = append(parts, fmt.Sprintf("%dh", hours)) + } + if minutes > 0 { + parts = append(parts, fmt.Sprintf("%dm", minutes)) + } + if seconds > 0 || len(parts) == 0 { + parts = append(parts, fmt.Sprintf("%ds", seconds)) + } + if seconds == 0 || len(parts) == 0 { + d -= time.Duration(seconds) * time.Second + millis := int(d / time.Millisecond) + parts = append(parts, fmt.Sprintf("%dms", millis)) + } + + return strings.Join(parts, " ") +} diff --git a/auth/bestjwt/jwt_test.go b/auth/bestjwt/jwt_test.go new file mode 100644 index 0000000..8dc5706 --- /dev/null +++ b/auth/bestjwt/jwt_test.go @@ -0,0 +1,478 @@ +// Copyright 2025 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +package bestjwt_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "fmt" + "testing" + "time" + + "github.com/therootcompany/golib/auth/bestjwt" +) + +// AppClaims is an example custom claims type. +// Embedding StandardClaims promotes Validate — no boilerplate needed. +type AppClaims struct { + bestjwt.StandardClaims + Email string `json:"email"` + Roles []string `json:"roles"` +} + +// StrictAppClaims overrides Validate to also require a non-empty Email. +// This demonstrates how to add application-specific validation on top of +// the standard OIDC checks. +type StrictAppClaims struct { + bestjwt.StandardClaims + Email string `json:"email"` +} + +func (c StrictAppClaims) Validate(params bestjwt.ValidateParams) ([]string, error) { + errs, _ := bestjwt.ValidateStandardClaims(c.StandardClaims, params) + if c.Email == "" { + errs = append(errs, "missing email claim") + } + if len(errs) > 0 { + return errs, fmt.Errorf("has errors") + } + return nil, nil +} + +// goodClaims returns a valid AppClaims with all standard fields populated. +func goodClaims() AppClaims { + now := time.Now() + return AppClaims{ + StandardClaims: bestjwt.StandardClaims{ + Iss: "https://example.com", + Sub: "user123", + Aud: "myapp", + Exp: now.Add(time.Hour).Unix(), + Iat: now.Unix(), + AuthTime: now.Unix(), + Amr: []string{"pwd"}, + Jti: "abc123", + Azp: "myapp", + Nonce: "nonce1", + }, + Email: "user@example.com", + Roles: []string{"admin"}, + } +} + +// goodParams returns ValidateParams matching the claims from goodClaims. +func goodParams() bestjwt.ValidateParams { + return bestjwt.ValidateParams{ + Iss: "https://example.com", + Sub: "user123", + Aud: "myapp", + Jti: "abc123", + Nonce: "nonce1", + Azp: "myapp", + RequiredAmrs: []string{"pwd"}, + } +} + +// --- Round-trip tests (sign → encode → decode → verify → validate) --- + +// TestRoundTripES256 exercises the most common path: ECDSA P-256 / ES256. +// Demonstrates the Decode(&claims) ergonomic — no generics at the call site, +// no type assertion needed to access Email after decoding. +func TestRoundTripES256(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := bestjwt.NewJWSFromClaims(&claims, "key-1") + if err != nil { + t.Fatal(err) + } + + if _, err = jws.Sign(privKey); err != nil { + t.Fatal(err) + } + if jws.Header.Alg != "ES256" { + t.Fatalf("expected ES256, got %s", jws.Header.Alg) + } + if len(jws.Signature) != 64 { // P-256: 32 bytes each for r and s + t.Fatalf("expected 64-byte signature, got %d", len(jws.Signature)) + } + + token := jws.Encode() + + var decoded AppClaims + jws2, err := bestjwt.Decode(token, &decoded) + if err != nil { + t.Fatal(err) + } + if !jws2.UnsafeVerify(&privKey.PublicKey) { + t.Fatal("signature verification failed") + } + if errs, err := jws2.Validate(goodParams()); err != nil { + t.Fatalf("validation failed: %v", errs) + } + // Direct field access via the local variable — no type assertion. + if decoded.Email != claims.Email { + t.Errorf("email: got %s, want %s", decoded.Email, claims.Email) + } +} + +// TestRoundTripES384 exercises ECDSA P-384 / ES384, verifying that the +// algorithm is inferred from the key's curve and that the 96-byte r||s +// signature format is produced and verified correctly. +func TestRoundTripES384(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := bestjwt.NewJWSFromClaims(&claims, "key-1") + if err != nil { + t.Fatal(err) + } + + if _, err = jws.Sign(privKey); err != nil { + t.Fatal(err) + } + if jws.Header.Alg != "ES384" { + t.Fatalf("expected ES384, got %s", jws.Header.Alg) + } + if len(jws.Signature) != 96 { // P-384: 48 bytes each for r and s + t.Fatalf("expected 96-byte signature, got %d", len(jws.Signature)) + } + + token := jws.Encode() + + var decoded AppClaims + jws2, err := bestjwt.Decode(token, &decoded) + if err != nil { + t.Fatal(err) + } + if !jws2.UnsafeVerify(&privKey.PublicKey) { + t.Fatal("signature verification failed") + } + if errs, err := jws2.Validate(goodParams()); err != nil { + t.Fatalf("validation failed: %v", errs) + } +} + +// TestRoundTripES512 exercises ECDSA P-521 / ES512 and the 132-byte signature. +func TestRoundTripES512(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := bestjwt.NewJWSFromClaims(&claims, "key-1") + if err != nil { + t.Fatal(err) + } + + if _, err = jws.Sign(privKey); err != nil { + t.Fatal(err) + } + if jws.Header.Alg != "ES512" { + t.Fatalf("expected ES512, got %s", jws.Header.Alg) + } + if len(jws.Signature) != 132 { // P-521: 66 bytes each for r and s + t.Fatalf("expected 132-byte signature, got %d", len(jws.Signature)) + } + + token := jws.Encode() + + var decoded AppClaims + jws2, err := bestjwt.Decode(token, &decoded) + if err != nil { + t.Fatal(err) + } + if !jws2.UnsafeVerify(&privKey.PublicKey) { + t.Fatal("signature verification failed") + } + if errs, err := jws2.Validate(goodParams()); err != nil { + t.Fatalf("validation failed: %v", errs) + } +} + +// TestRoundTripRS256 exercises RSA PKCS#1 v1.5 / RS256. +func TestRoundTripRS256(t *testing.T) { + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := bestjwt.NewJWSFromClaims(&claims, "key-1") + if err != nil { + t.Fatal(err) + } + + if _, err = jws.Sign(privKey); err != nil { + t.Fatal(err) + } + if jws.Header.Alg != "RS256" { + t.Fatalf("expected RS256, got %s", jws.Header.Alg) + } + + token := jws.Encode() + + var decoded AppClaims + jws2, err := bestjwt.Decode(token, &decoded) + if err != nil { + t.Fatal(err) + } + if !jws2.UnsafeVerify(&privKey.PublicKey) { + t.Fatal("signature verification failed") + } + if errs, err := jws2.Validate(goodParams()); err != nil { + t.Fatalf("validation failed: %v", errs) + } +} + +// --- Security / negative tests --- + +// TestVerifyWrongKeyType verifies that an RSA public key is rejected when +// verifying a token signed with ECDSA (alg = ES256). +func TestVerifyWrongKeyType(t *testing.T) { + ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) + + claims := goodClaims() + jws, _ := bestjwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(ecKey) // alg = "ES256" + + token := jws.Encode() + var decoded AppClaims + jws2, _ := bestjwt.Decode(token, &decoded) + + if jws2.UnsafeVerify(&rsaKey.PublicKey) { + t.Fatal("expected verification to fail: RSA key for ES256 token") + } +} + +// TestVerifyAlgCurveMismatch verifies that a P-256 key is rejected when +// verifying a token whose header claims ES384 (signed with P-384). +// Without the curve/alg consistency check this would silently return false +// from ecdsa.Verify, but the explicit check makes the rejection reason clear. +func TestVerifyAlgCurveMismatch(t *testing.T) { + p384Key, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + p256Key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + claims := goodClaims() + jws, _ := bestjwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(p384Key) // alg = "ES384" + + token := jws.Encode() + var decoded AppClaims + jws2, _ := bestjwt.Decode(token, &decoded) + + // P-256 key must be rejected for an ES384 token. + if jws2.UnsafeVerify(&p256Key.PublicKey) { + t.Fatal("expected verification to fail: P-256 key for ES384 token") + } +} + +// TestVerifyUnknownAlg verifies that a tampered alg header is rejected. +func TestVerifyUnknownAlg(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + claims := goodClaims() + jws, _ := bestjwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(privKey) + + token := jws.Encode() + var decoded AppClaims + jws2, _ := bestjwt.Decode(token, &decoded) + + // Tamper: overwrite alg in the decoded header. + jws2.Header.Alg = "none" + + if jws2.UnsafeVerify(&privKey.PublicKey) { + t.Fatal("expected verification to fail for unknown alg") + } +} + +// TestValidateMissingSignatureCheck verifies that Validate fails when +// UnsafeVerify was never called (Verified is false). +func TestValidateMissingSignatureCheck(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + claims := goodClaims() + jws, _ := bestjwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(privKey) + + token := jws.Encode() + var decoded AppClaims + jws2, _ := bestjwt.Decode(token, &decoded) + + // Deliberately skip UnsafeVerify. + errs, err := jws2.Validate(goodParams()) + if err == nil { + t.Fatal("expected validation to fail: signature was not checked") + } + found := false + for _, e := range errs { + if e == "signature was not checked" { + found = true + } + } + if !found { + t.Fatalf("expected 'signature was not checked' in errors: %v", errs) + } +} + +// --- Embedded vs overridden Validate --- + +// TestPromotedValidate confirms that AppClaims (which only embeds +// StandardClaims) gets the standard OIDC validation for free via promotion, +// without writing any Validate method. +func TestPromotedValidate(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + claims := goodClaims() + jws, _ := bestjwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(privKey) + token := jws.Encode() + + var decoded AppClaims + jws2, _ := bestjwt.Decode(token, &decoded) + jws2.UnsafeVerify(&privKey.PublicKey) + + if errs, err := jws2.Validate(goodParams()); err != nil { + t.Fatalf("promoted Validate failed unexpectedly: %v", errs) + } +} + +// TestOverriddenValidate confirms that a StrictAppClaims with an empty Email +// fails validation via its overridden Validate method. +func TestOverriddenValidate(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + now := time.Now() + claims := StrictAppClaims{ + StandardClaims: bestjwt.StandardClaims{ + Iss: "https://example.com", + Sub: "user123", + Aud: "myapp", + Exp: now.Add(time.Hour).Unix(), + Iat: now.Unix(), + AuthTime: now.Unix(), + Amr: []string{"pwd"}, + Jti: "abc123", + Azp: "myapp", + Nonce: "nonce1", + }, + Email: "", // intentionally empty to trigger the override + } + + jws, _ := bestjwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(privKey) + token := jws.Encode() + + var decoded StrictAppClaims + jws2, _ := bestjwt.Decode(token, &decoded) + jws2.UnsafeVerify(&privKey.PublicKey) + + errs, err := jws2.Validate(goodParams()) + if err == nil { + t.Fatal("expected validation to fail: email is empty") + } + found := false + for _, e := range errs { + if e == "missing email claim" { + found = true + } + } + if !found { + t.Fatalf("expected 'missing email claim' in errors: %v", errs) + } +} + +// --- JWKS / key management --- + +// TestTypedKeys verifies that TypedKeys correctly filters a mixed +// []PublicJWK[Key] into typed slices without type assertions at use sites. +func TestTypedKeys(t *testing.T) { + ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) + + allKeys := []bestjwt.PublicJWK[bestjwt.Key]{ + {Key: &ecKey.PublicKey, KID: "ec-1", Use: "sig"}, + {Key: &rsaKey.PublicKey, KID: "rsa-1", Use: "sig"}, + } + + ecKeys := bestjwt.TypedKeys[*ecdsa.PublicKey](allKeys) + if len(ecKeys) != 1 || ecKeys[0].KID != "ec-1" { + t.Errorf("unexpected EC keys: %+v", ecKeys) + } + // Typed access — no assertion needed. + _ = ecKeys[0].Key.Curve + + rsaKeys := bestjwt.TypedKeys[*rsa.PublicKey](allKeys) + if len(rsaKeys) != 1 || rsaKeys[0].KID != "rsa-1" { + t.Errorf("unexpected RSA keys: %+v", rsaKeys) + } +} + +// TestVerifyWithJWKSKey verifies that PublicJWK.Key can be passed directly to +// UnsafeVerify without a type assertion when using a typed PublicJWK[Key]. +func TestVerifyWithJWKSKey(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + jwksKey := bestjwt.PublicJWK[bestjwt.Key]{Key: &privKey.PublicKey, KID: "k1"} + + claims := goodClaims() + jws, _ := bestjwt.NewJWSFromClaims(&claims, "k1") + _, _ = jws.Sign(privKey) + token := jws.Encode() + + var decoded AppClaims + jws2, _ := bestjwt.Decode(token, &decoded) + + // Pass PublicJWK.Key directly — Key interface satisfies the Key constraint. + if !jws2.UnsafeVerify(jwksKey.Key) { + t.Fatal("verification via PublicJWK.Key failed") + } +} + +// TestDecodePublicJWKJSON verifies JWKS JSON parsing and TypedKeys filtering +// with real base64url-encoded key material from RFC 7517 / OIDC examples. +func TestDecodePublicJWKJSON(t *testing.T) { + jwksJSON := []byte(`{"keys":[ + {"kty":"EC","crv":"P-256", + "x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", + "y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", + "kid":"ec-256","use":"sig"}, + {"kty":"RSA", + "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", + "e":"AQAB","kid":"rsa-2048","use":"sig"} + ]}`) + + keys, err := bestjwt.UnmarshalPublicJWKs(jwksJSON) + if err != nil { + t.Fatal(err) + } + if len(keys) != 2 { + t.Fatalf("expected 2 keys, got %d", len(keys)) + } + + ecKeys := bestjwt.TypedKeys[*ecdsa.PublicKey](keys) + if len(ecKeys) != 1 || ecKeys[0].KID != "ec-256" { + t.Errorf("EC key mismatch: %+v", ecKeys) + } + + rsaKeys := bestjwt.TypedKeys[*rsa.PublicKey](keys) + if len(rsaKeys) != 1 || rsaKeys[0].KID != "rsa-2048" { + t.Errorf("RSA key mismatch: %+v", rsaKeys) + } +} diff --git a/auth/bestjwt/pub.go b/auth/bestjwt/pub.go new file mode 100644 index 0000000..25c23d7 --- /dev/null +++ b/auth/bestjwt/pub.go @@ -0,0 +1,229 @@ +// Copyright 2025 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +package bestjwt + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "os" + "time" +) + +// Key is the constraint for the public key type parameter K used in PublicJWK. +// +// All standard-library asymmetric public key types satisfy this interface +// since Go 1.15: *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey. +// +// Note: crypto.PublicKey is defined as interface{} and does NOT satisfy Key. +// Use Key itself as the type argument for heterogeneous collections +// (e.g. []PublicJWK[Key]), since Key declares Equal and therefore satisfies +// its own constraint. Use [TypedKeys] to narrow to a concrete type. +type Key interface { + Equal(x crypto.PublicKey) bool +} + +// PublicJWK wraps a parsed public key with its JWKS metadata. +// +// K is constrained to [Key], providing type-safe access to the underlying +// key without a type assertion at each use site. +// +// For a heterogeneous JWKS endpoint (mixed RSA/EC) use PublicJWK[Key]. +// For a homogeneous store use the concrete type directly (e.g. +// PublicJWK[*ecdsa.PublicKey]). Use [TypedKeys] to narrow a mixed slice. +// +// Example — sign with a known key type, no assertion needed: +// +// ecKeys := bestjwt.TypedKeys[*ecdsa.PublicKey](allKeys) +// jws.UnsafeVerify(ecKeys[0].Key) // Key is *ecdsa.PublicKey directly +type PublicJWK[K Key] struct { + Key K + KID string + Use string +} + +// PublicJWKJSON is the JSON representation of a single key in a JWKS document. +type PublicJWKJSON struct { + Kty string `json:"kty"` + KID string `json:"kid"` + N string `json:"n,omitempty"` // RSA modulus + E string `json:"e,omitempty"` // RSA exponent + Crv string `json:"crv,omitempty"` + X string `json:"x,omitempty"` + Y string `json:"y,omitempty"` + Use string `json:"use,omitempty"` +} + +// JWKsJSON is the JSON representation of a JWKS document. +type JWKsJSON struct { + Keys []PublicJWKJSON `json:"keys"` +} + +// TypedKeys filters a heterogeneous []PublicJWK[Key] slice to only those whose +// underlying key is of concrete type K, returning a typed []PublicJWK[K]. +// Keys of other types are silently skipped. +// +// Example — extract only ECDSA keys from a mixed JWKS result: +// +// all, _ := bestjwt.FetchPublicJWKs(jwksURL) +// ecKeys := bestjwt.TypedKeys[*ecdsa.PublicKey](all) +// rsaKeys := bestjwt.TypedKeys[*rsa.PublicKey](all) +func TypedKeys[K Key](keys []PublicJWK[Key]) []PublicJWK[K] { + var result []PublicJWK[K] + for _, k := range keys { + if typed, ok := k.Key.(K); ok { + result = append(result, PublicJWK[K]{Key: typed, KID: k.KID, Use: k.Use}) + } + } + return result +} + +// FetchPublicJWKs retrieves and parses a JWKS document from url. +// Keys are returned as []PublicJWK[Key] since a JWKS endpoint may contain a +// mix of key types. Use [TypedKeys] to narrow to a concrete type. +func FetchPublicJWKs(url string) ([]PublicJWK[Key], error) { + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to fetch JWKS: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + return DecodePublicJWKs(resp.Body) +} + +// ReadPublicJWKs reads and parses a JWKS document from a file path. +func ReadPublicJWKs(filePath string) ([]PublicJWK[Key], error) { + file, err := os.Open(filePath) + if err != nil { + return nil, fmt.Errorf("failed to open JWKS file '%s': %w", filePath, err) + } + defer func() { _ = file.Close() }() + return DecodePublicJWKs(file) +} + +// UnmarshalPublicJWKs parses a JWKS document from raw JSON bytes. +func UnmarshalPublicJWKs(data []byte) ([]PublicJWK[Key], error) { + var jwks JWKsJSON + if err := json.Unmarshal(data, &jwks); err != nil { + return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err) + } + return DecodePublicJWKsJSON(jwks) +} + +// DecodePublicJWKs parses a JWKS document from an [io.Reader]. +func DecodePublicJWKs(r io.Reader) ([]PublicJWK[Key], error) { + var jwks JWKsJSON + if err := json.NewDecoder(r).Decode(&jwks); err != nil { + return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err) + } + return DecodePublicJWKsJSON(jwks) +} + +// DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys. +func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK[Key], error) { + var keys []PublicJWK[Key] + for _, jwk := range jwks.Keys { + key, err := DecodePublicJWK(jwk) + if err != nil { + return nil, fmt.Errorf("failed to parse public jwk '%s': %w", jwk.KID, err) + } + keys = append(keys, *key) + } + if len(keys) == 0 { + return nil, fmt.Errorf("no valid RSA or ECDSA keys found") + } + return keys, nil +} + +// DecodePublicJWK parses a single [PublicJWKJSON] into a PublicJWK[Key]. +// Supports RSA (minimum 1024-bit) and EC (P-256, P-384, P-521) keys. +func DecodePublicJWK(jwk PublicJWKJSON) (*PublicJWK[Key], error) { + switch jwk.Kty { + case "RSA": + key, err := decodeRSAPublicJWK(jwk) + if err != nil { + return nil, fmt.Errorf("failed to parse RSA key '%s': %w", jwk.KID, err) + } + if key.Size() < 128 { // 1024 bits minimum + return nil, fmt.Errorf("RSA key '%s' too small: %d bytes", jwk.KID, key.Size()) + } + return &PublicJWK[Key]{Key: key, KID: jwk.KID, Use: jwk.Use}, nil + + case "EC": + key, err := decodeECDSAPublicJWK(jwk) + if err != nil { + return nil, fmt.Errorf("failed to parse EC key '%s': %w", jwk.KID, err) + } + return &PublicJWK[Key]{Key: key, KID: jwk.KID, Use: jwk.Use}, nil + + default: + return nil, fmt.Errorf("unsupported key type '%s' for kid '%s'", jwk.Kty, jwk.KID) + } +} + +func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) { + n, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return nil, fmt.Errorf("invalid RSA modulus: %w", err) + } + e, err := base64.RawURLEncoding.DecodeString(jwk.E) + if err != nil { + return nil, fmt.Errorf("invalid RSA exponent: %w", err) + } + + eInt := new(big.Int).SetBytes(e).Int64() + if eInt > int64(^uint(0)>>1) || eInt < 0 { + return nil, fmt.Errorf("RSA exponent too large or negative") + } + + return &rsa.PublicKey{ + N: new(big.Int).SetBytes(n), + E: int(eInt), + }, nil +} + +func decodeECDSAPublicJWK(jwk PublicJWKJSON) (*ecdsa.PublicKey, error) { + x, err := base64.RawURLEncoding.DecodeString(jwk.X) + if err != nil { + return nil, fmt.Errorf("invalid ECDSA X: %w", err) + } + y, err := base64.RawURLEncoding.DecodeString(jwk.Y) + if err != nil { + return nil, fmt.Errorf("invalid ECDSA Y: %w", err) + } + + var curve elliptic.Curve + switch jwk.Crv { + case "P-256": + curve = elliptic.P256() + case "P-384": + curve = elliptic.P384() + case "P-521": + curve = elliptic.P521() + default: + return nil, fmt.Errorf("unsupported ECDSA curve: %s", jwk.Crv) + } + + return &ecdsa.PublicKey{ + Curve: curve, + X: new(big.Int).SetBytes(x), + Y: new(big.Int).SetBytes(y), + }, nil +}