mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-13 12:27:59 +00:00
feat(auth/bestjwt): add hybrid JWT/JWS/JWK package
Combines the best ergonomics from genericjwt and embeddedjwt: - Decode(&claims) pattern (embedded structs, no generics at call sites, no type assertion to access custom fields) - StandardClaims.Validate promoted to any embedding struct via value receiver; override Validate on the outer struct for custom checks - Sign(crypto.Signer): algorithm inferred from key.Public() type switch, supports HSM/cloud KMS transparently - Full ECDSA curve support: ES256 (P-256), ES384 (P-384), ES512 (P-521) all inferred automatically from key curve via algForECKey - Curve/alg consistency check in UnsafeVerify: P-256 key rejected for ES384 token and vice versa (prevents cross-algorithm downgrade) - digestFor: fixed-size stack arrays for SHA-256/384/512 digests - ecdsaDERToRaw + FillBytes: correct zero-padded r||s conversion from ASN.1 DER output of crypto.Signer - Generic PublicJWK[K Key] + TypedKeys[K]: type-safe JWKS key management, filter mixed []PublicJWK[Key] to concrete type without assertions - JWKS fetch/parse: FetchPublicJWKs, ReadPublicJWKs, UnmarshalPublicJWKs, DecodePublicJWKs for RSA and EC (P-256/384/521) - RS256 (PKCS#1 v1.5 + SHA-256) support via crypto.Signer - 13 tests covering all algorithms, negative cases, and JWKS integration
This commit is contained in:
parent
55a7b9b2f4
commit
fac58cf1ad
3
auth/bestjwt/go.mod
Normal file
3
auth/bestjwt/go.mod
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
module github.com/therootcompany/golib/auth/bestjwt
|
||||||
|
|
||||||
|
go 1.24.0
|
||||||
706
auth/bestjwt/jwt.go
Normal file
706
auth/bestjwt/jwt.go
Normal file
@ -0,0 +1,706 @@
|
|||||||
|
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
// Package bestjwt is a lightweight JWT/JWS/JWK library that combines the
|
||||||
|
// best ergonomics from the genericjwt and embeddedjwt packages:
|
||||||
|
//
|
||||||
|
// - Claims via embedded structs: Decode(token, &myClaims) — no generics at
|
||||||
|
// call sites, no type assertions to access custom fields.
|
||||||
|
// - crypto.Signer for all signing: works with in-process keys AND
|
||||||
|
// hardware-backed keys (HSM, cloud KMS, PKCS#11) without modification.
|
||||||
|
// - Full ECDSA curve support: ES256 (P-256), ES384 (P-384), ES512 (P-521).
|
||||||
|
// The algorithm is inferred from the key's curve, not hardcoded.
|
||||||
|
// - Curve/algorithm consistency enforcement: UnsafeVerify rejects a P-256
|
||||||
|
// key presented for an ES384 token and vice versa.
|
||||||
|
// - Generic PublicJWK[K Key]: type-safe JWKS key management with TypedKeys
|
||||||
|
// to filter a mixed []PublicJWK[Key] to a concrete key type.
|
||||||
|
package bestjwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/sha512"
|
||||||
|
"encoding/asn1"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claims is the interface that custom claims types must satisfy.
|
||||||
|
//
|
||||||
|
// Because [StandardClaims] implements Claims with a value receiver, any struct
|
||||||
|
// that embeds StandardClaims satisfies Claims automatically via method
|
||||||
|
// promotion — no boilerplate required:
|
||||||
|
//
|
||||||
|
// type AppClaims struct {
|
||||||
|
// bestjwt.StandardClaims
|
||||||
|
// Email string `json:"email"`
|
||||||
|
// }
|
||||||
|
// // AppClaims now satisfies Claims for free.
|
||||||
|
//
|
||||||
|
// Override Validate on the outer struct to add application-specific checks.
|
||||||
|
// Call [ValidateStandardClaims] inside your override to preserve all standard
|
||||||
|
// OIDC/JWT validation.
|
||||||
|
type Claims interface {
|
||||||
|
Validate(params ValidateParams) ([]string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWS is a decoded JSON Web Signature / JWT.
|
||||||
|
//
|
||||||
|
// Claims is stored as the [Claims] interface so that any embedded-struct type
|
||||||
|
// can be used without generics. The most ergonomic access pattern is via the
|
||||||
|
// pointer you passed to [Decode]:
|
||||||
|
//
|
||||||
|
// var claims AppClaims
|
||||||
|
// jws, err := bestjwt.Decode(tokenString, &claims)
|
||||||
|
// jws.UnsafeVerify(pubKey)
|
||||||
|
// errs, err := jws.Validate(params)
|
||||||
|
// // Access claims.Email, claims.Roles, etc. directly — no type assertion.
|
||||||
|
type JWS struct {
|
||||||
|
Protected string `json:"-"` // base64url-encoded header
|
||||||
|
Header StandardHeader `json:"header"`
|
||||||
|
Payload string `json:"-"` // base64url-encoded claims
|
||||||
|
Claims Claims `json:"claims"`
|
||||||
|
Signature URLBase64 `json:"signature"`
|
||||||
|
Verified bool `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StandardHeader holds the standard JOSE header fields.
|
||||||
|
type StandardHeader struct {
|
||||||
|
Alg string `json:"alg"`
|
||||||
|
Kid string `json:"kid"`
|
||||||
|
Typ string `json:"typ"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StandardClaims holds the registered JWT claim names defined in RFC 7519
|
||||||
|
// and extended by OpenID Connect Core.
|
||||||
|
//
|
||||||
|
// Embed StandardClaims in your own struct to satisfy [Claims] automatically:
|
||||||
|
//
|
||||||
|
// type AppClaims struct {
|
||||||
|
// bestjwt.StandardClaims
|
||||||
|
// Email string `json:"email"`
|
||||||
|
// }
|
||||||
|
// // AppClaims now satisfies Claims via promoted Validate.
|
||||||
|
// // Override Validate on AppClaims to add custom checks.
|
||||||
|
type StandardClaims struct {
|
||||||
|
Iss string `json:"iss"`
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
Aud string `json:"aud"`
|
||||||
|
Exp int64 `json:"exp"`
|
||||||
|
Iat int64 `json:"iat"`
|
||||||
|
AuthTime int64 `json:"auth_time"`
|
||||||
|
Nonce string `json:"nonce,omitempty"`
|
||||||
|
Amr []string `json:"amr"`
|
||||||
|
Azp string `json:"azp,omitempty"`
|
||||||
|
Jti string `json:"jti"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate implements [Claims] by checking all standard OIDC/JWT claim fields.
|
||||||
|
//
|
||||||
|
// This method is promoted to any struct that embeds [StandardClaims], so
|
||||||
|
// embedding structs satisfy Claims without writing any additional code.
|
||||||
|
// params.Now must be non-zero; [JWS.Validate] ensures this before delegating.
|
||||||
|
func (c StandardClaims) Validate(params ValidateParams) ([]string, error) {
|
||||||
|
return ValidateStandardClaims(c, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
|
||||||
|
//
|
||||||
|
// claims must be a pointer to the caller's pre-allocated claims struct
|
||||||
|
// (e.g. &AppClaims{}). The JSON payload is unmarshaled directly into it,
|
||||||
|
// and the same pointer is stored in jws.Claims. Callers can access custom
|
||||||
|
// fields through their own variable without a type assertion:
|
||||||
|
//
|
||||||
|
// var claims AppClaims
|
||||||
|
// jws, err := bestjwt.Decode(token, &claims)
|
||||||
|
// // claims.Email is already set; no type assertion needed.
|
||||||
|
//
|
||||||
|
// The signature is not verified by Decode. Call [JWS.UnsafeVerify] first,
|
||||||
|
// then [JWS.Validate].
|
||||||
|
func Decode(tokenStr string, claims Claims) (*JWS, error) {
|
||||||
|
parts := strings.Split(tokenStr, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, fmt.Errorf("invalid JWT format")
|
||||||
|
}
|
||||||
|
|
||||||
|
var jws JWS
|
||||||
|
var sigEnc string
|
||||||
|
jws.Protected, jws.Payload, sigEnc = parts[0], parts[1], parts[2]
|
||||||
|
|
||||||
|
header, err := base64.RawURLEncoding.DecodeString(jws.Protected)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid header encoding: %v", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(header, &jws.Header); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid header JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid claims encoding: %v", err)
|
||||||
|
}
|
||||||
|
// Unmarshal into the concrete type behind the Claims interface.
|
||||||
|
// json.Unmarshal reaches the concrete pointer via reflection.
|
||||||
|
if err := json.Unmarshal(payload, claims); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid claims JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid signature encoding: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jws.Claims = claims
|
||||||
|
return &jws, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJWSFromClaims creates an unsigned JWS from the provided claims.
|
||||||
|
//
|
||||||
|
// kid identifies the signing key. The "alg" header field is set automatically
|
||||||
|
// when [JWS.Sign] is called, since only the key type determines the algorithm.
|
||||||
|
// Call [JWS.Encode] to produce a compact JWT string after signing.
|
||||||
|
func NewJWSFromClaims(claims Claims, kid string) (*JWS, error) {
|
||||||
|
var jws JWS
|
||||||
|
|
||||||
|
jws.Header = StandardHeader{
|
||||||
|
// Alg is intentionally omitted here; Sign sets it from the key type.
|
||||||
|
Kid: kid,
|
||||||
|
Typ: "JWT",
|
||||||
|
}
|
||||||
|
headerJSON, _ := json.Marshal(jws.Header)
|
||||||
|
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||||
|
|
||||||
|
claimsJSON, _ := json.Marshal(claims)
|
||||||
|
jws.Payload = base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||||
|
jws.Claims = claims
|
||||||
|
|
||||||
|
return &jws, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign signs the JWS in-place using the provided [crypto.Signer].
|
||||||
|
//
|
||||||
|
// The algorithm is determined from the signer's public key type and, for EC
|
||||||
|
// keys, from the curve. The "alg" header field is set and the protected header
|
||||||
|
// is re-encoded before computing the signing input, so the signed bytes are
|
||||||
|
// always consistent with the token header.
|
||||||
|
//
|
||||||
|
// Supported algorithms (inferred automatically):
|
||||||
|
// - *ecdsa.PublicKey P-256 → ES256 (ECDSA + SHA-256, raw r||s)
|
||||||
|
// - *ecdsa.PublicKey P-384 → ES384 (ECDSA + SHA-384, raw r||s)
|
||||||
|
// - *ecdsa.PublicKey P-521 → ES512 (ECDSA + SHA-512, raw r||s)
|
||||||
|
// - *rsa.PublicKey → RS256 (PKCS#1 v1.5 + SHA-256)
|
||||||
|
//
|
||||||
|
// Because the parameter is [crypto.Signer] rather than a concrete key type,
|
||||||
|
// hardware-backed signers (HSM, OS keychain, cloud KMS) work transparently.
|
||||||
|
func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) {
|
||||||
|
switch pub := key.Public().(type) {
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
alg, h, err := algForECKey(pub)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jws.Header.Alg = alg
|
||||||
|
headerJSON, _ := json.Marshal(jws.Header)
|
||||||
|
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||||
|
|
||||||
|
digest := digestFor(h, jws.Protected+"."+jws.Payload)
|
||||||
|
// crypto.Signer returns ASN.1 DER for ECDSA; convert to raw r||s for JWS.
|
||||||
|
derSig, err := key.Sign(rand.Reader, digest, h)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Sign %s: %w", alg, err)
|
||||||
|
}
|
||||||
|
jws.Signature, err = ecdsaDERToRaw(derSig, pub.Curve)
|
||||||
|
return jws.Signature, err
|
||||||
|
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
jws.Header.Alg = "RS256"
|
||||||
|
headerJSON, _ := json.Marshal(jws.Header)
|
||||||
|
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||||
|
|
||||||
|
digest := digestFor(crypto.SHA256, jws.Protected+"."+jws.Payload)
|
||||||
|
// crypto.Signer returns raw PKCS#1 v1.5 bytes for RSA; use directly.
|
||||||
|
var err error
|
||||||
|
jws.Signature, err = key.Sign(rand.Reader, digest, crypto.SHA256)
|
||||||
|
return jws.Signature, err
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"Sign: unsupported public key type %T (supported: *ecdsa.PublicKey, *rsa.PublicKey)",
|
||||||
|
key.Public(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode produces the compact JWT string (header.payload.signature).
|
||||||
|
func (jws JWS) Encode() string {
|
||||||
|
sigEnc := base64.RawURLEncoding.EncodeToString(jws.Signature)
|
||||||
|
return jws.Protected + "." + jws.Payload + "." + sigEnc
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnsafeVerify checks the signature using the algorithm in the JWT header and
|
||||||
|
// sets jws.Verified on success. It only checks the signature — use
|
||||||
|
// [JWS.Validate] to check claim values (expiry, issuer, audience, etc.).
|
||||||
|
//
|
||||||
|
// pub must be of the concrete type matching the header alg (e.g.
|
||||||
|
// *ecdsa.PublicKey for ES256/ES384/ES512). Callers can pass PublicJWK.Key
|
||||||
|
// directly without a type assertion.
|
||||||
|
//
|
||||||
|
// For ECDSA tokens, the key's curve is checked against the claimed algorithm
|
||||||
|
// (e.g. P-384 key is rejected for an ES256 token) to prevent cross-algorithm
|
||||||
|
// downgrade attacks.
|
||||||
|
//
|
||||||
|
// Currently supported: ES256, ES384, ES512, RS256.
|
||||||
|
func (jws *JWS) UnsafeVerify(pub Key) bool {
|
||||||
|
signingInput := jws.Protected + "." + jws.Payload
|
||||||
|
|
||||||
|
switch jws.Header.Alg {
|
||||||
|
case "ES256", "ES384", "ES512":
|
||||||
|
k, ok := pub.(*ecdsa.PublicKey)
|
||||||
|
if !ok || k == nil {
|
||||||
|
jws.Verified = false
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Require the key's curve to match the token's claimed algorithm.
|
||||||
|
// A P-256 key must not verify an ES384 token and vice versa.
|
||||||
|
expectedAlg, h, err := algForECKey(k)
|
||||||
|
if err != nil || expectedAlg != jws.Header.Alg {
|
||||||
|
jws.Verified = false
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
byteLen := (k.Curve.Params().BitSize + 7) / 8
|
||||||
|
if len(jws.Signature) != 2*byteLen {
|
||||||
|
jws.Verified = false
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
digest := digestFor(h, signingInput)
|
||||||
|
r := new(big.Int).SetBytes(jws.Signature[:byteLen])
|
||||||
|
s := new(big.Int).SetBytes(jws.Signature[byteLen:])
|
||||||
|
jws.Verified = ecdsa.Verify(k, digest, r, s)
|
||||||
|
|
||||||
|
case "RS256":
|
||||||
|
k, ok := pub.(*rsa.PublicKey)
|
||||||
|
if !ok || k == nil {
|
||||||
|
jws.Verified = false
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
digest := digestFor(crypto.SHA256, signingInput)
|
||||||
|
jws.Verified = rsa.VerifyPKCS1v15(k, crypto.SHA256, digest, jws.Signature) == nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
jws.Verified = false
|
||||||
|
}
|
||||||
|
return jws.Verified
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate sets params.Now if zero, then delegates to jws.Claims.Validate and
|
||||||
|
// additionally enforces that the signature was verified (unless params.IgnoreSig).
|
||||||
|
//
|
||||||
|
// Returns a list of human-readable errors and a non-nil sentinel if any exist.
|
||||||
|
func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
|
||||||
|
if params.Now.IsZero() {
|
||||||
|
params.Now = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
errs, _ := jws.Claims.Validate(params)
|
||||||
|
|
||||||
|
if !params.IgnoreSig && !jws.Verified {
|
||||||
|
errs = append(errs, "signature was not checked")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
timeInfo := fmt.Sprintf("info: server time is %s", params.Now.Format("2006-01-02 15:04:05 MST"))
|
||||||
|
if loc, err := time.LoadLocation("Local"); err == nil {
|
||||||
|
timeInfo += fmt.Sprintf(" %s", loc)
|
||||||
|
}
|
||||||
|
errs = append(errs, timeInfo)
|
||||||
|
return errs, fmt.Errorf("has errors")
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateParams holds validation configuration.
|
||||||
|
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
|
||||||
|
type ValidateParams struct {
|
||||||
|
Now time.Time
|
||||||
|
IgnoreIss bool
|
||||||
|
Iss string
|
||||||
|
IgnoreSub bool
|
||||||
|
Sub string
|
||||||
|
IgnoreAud bool
|
||||||
|
Aud string
|
||||||
|
IgnoreExp bool
|
||||||
|
IgnoreJti bool
|
||||||
|
Jti string
|
||||||
|
IgnoreIat bool
|
||||||
|
IgnoreAuthTime bool
|
||||||
|
MaxAge time.Duration
|
||||||
|
IgnoreNonce bool
|
||||||
|
Nonce string
|
||||||
|
IgnoreAmr bool
|
||||||
|
RequiredAmrs []string
|
||||||
|
IgnoreAzp bool
|
||||||
|
Azp string
|
||||||
|
IgnoreSig bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateStandardClaims checks the registered JWT/OIDC claim fields.
|
||||||
|
//
|
||||||
|
// Called by [StandardClaims.Validate] and exported so that custom claims types
|
||||||
|
// can call it from an overriding Validate method:
|
||||||
|
//
|
||||||
|
// func (c AppClaims) Validate(params bestjwt.ValidateParams) ([]string, error) {
|
||||||
|
// errs, _ := bestjwt.ValidateStandardClaims(c.StandardClaims, params)
|
||||||
|
// if c.Email == "" {
|
||||||
|
// errs = append(errs, "missing email claim")
|
||||||
|
// }
|
||||||
|
// if len(errs) > 0 {
|
||||||
|
// return errs, fmt.Errorf("has errors")
|
||||||
|
// }
|
||||||
|
// return nil, nil
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// params.Now must be non-zero; [JWS.Validate] ensures this before delegating.
|
||||||
|
func ValidateStandardClaims(claims StandardClaims, params ValidateParams) ([]string, error) {
|
||||||
|
var errs []string
|
||||||
|
|
||||||
|
// Required to exist and match
|
||||||
|
if len(params.Iss) > 0 || !params.IgnoreIss {
|
||||||
|
if len(claims.Iss) == 0 {
|
||||||
|
errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
|
||||||
|
} else if claims.Iss != params.Iss {
|
||||||
|
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required to exist, optional match
|
||||||
|
if len(claims.Sub) == 0 {
|
||||||
|
if !params.IgnoreSub {
|
||||||
|
errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)")
|
||||||
|
}
|
||||||
|
} else if len(params.Sub) > 0 {
|
||||||
|
if params.Sub != claims.Sub {
|
||||||
|
errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, params.Sub))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required to exist and match
|
||||||
|
if len(params.Aud) > 0 || !params.IgnoreAud {
|
||||||
|
if len(claims.Aud) == 0 {
|
||||||
|
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
|
||||||
|
} else if claims.Aud != params.Aud {
|
||||||
|
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, params.Aud))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required to exist and not be in the past
|
||||||
|
if !params.IgnoreExp {
|
||||||
|
if claims.Exp <= 0 {
|
||||||
|
errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)")
|
||||||
|
} else if claims.Exp < params.Now.Unix() {
|
||||||
|
duration := time.Since(time.Unix(claims.Exp, 0))
|
||||||
|
expTime := time.Unix(claims.Exp, 0).Format("2006-01-02 15:04:05 MST")
|
||||||
|
errs = append(errs, fmt.Sprintf("token expired %s ago (%s)", formatDuration(duration), expTime))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required to exist and not be in the future
|
||||||
|
if !params.IgnoreIat {
|
||||||
|
if claims.Iat <= 0 {
|
||||||
|
errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)")
|
||||||
|
} else if claims.Iat > params.Now.Unix() {
|
||||||
|
duration := time.Unix(claims.Iat, 0).Sub(params.Now)
|
||||||
|
iatTime := time.Unix(claims.Iat, 0).Format("2006-01-02 15:04:05 MST")
|
||||||
|
errs = append(errs, fmt.Sprintf("'iat' (issued at) is %s in the future (%s)", formatDuration(duration), iatTime))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should exist, in the past, with optional max age
|
||||||
|
if params.MaxAge > 0 || !params.IgnoreAuthTime {
|
||||||
|
if claims.AuthTime == 0 {
|
||||||
|
errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
|
||||||
|
} else {
|
||||||
|
authTime := time.Unix(claims.AuthTime, 0)
|
||||||
|
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
|
||||||
|
age := params.Now.Sub(authTime)
|
||||||
|
diff := age - params.MaxAge
|
||||||
|
if claims.AuthTime > params.Now.Unix() {
|
||||||
|
fromNow := time.Unix(claims.AuthTime, 0).Sub(params.Now)
|
||||||
|
errs = append(errs, fmt.Sprintf(
|
||||||
|
"'auth_time' of %s is %s in the future (server time %s)",
|
||||||
|
authTimeStr, formatDuration(fromNow), params.Now.Format("2006-01-02 15:04:05 MST")),
|
||||||
|
)
|
||||||
|
} else if params.MaxAge > 0 && age > params.MaxAge {
|
||||||
|
errs = append(errs, fmt.Sprintf(
|
||||||
|
"'auth_time' of %s is %s old, exceeding max age %s by %s",
|
||||||
|
authTimeStr, formatDuration(age), formatDuration(params.MaxAge), formatDuration(diff)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional exact match
|
||||||
|
if params.Jti != claims.Jti {
|
||||||
|
if len(params.Jti) > 0 {
|
||||||
|
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti))
|
||||||
|
} else if !params.IgnoreJti {
|
||||||
|
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional exact match
|
||||||
|
if params.Nonce != claims.Nonce {
|
||||||
|
if len(params.Nonce) > 0 {
|
||||||
|
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, params.Nonce))
|
||||||
|
} else if !params.IgnoreNonce {
|
||||||
|
errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should exist, optional required-set check
|
||||||
|
if !params.IgnoreAmr {
|
||||||
|
if len(claims.Amr) == 0 {
|
||||||
|
errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)")
|
||||||
|
} else if len(params.RequiredAmrs) > 0 {
|
||||||
|
for _, required := range params.RequiredAmrs {
|
||||||
|
if !slices.Contains(claims.Amr, required) {
|
||||||
|
errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional, match if present
|
||||||
|
if params.Azp != claims.Azp {
|
||||||
|
if len(params.Azp) > 0 {
|
||||||
|
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, params.Azp))
|
||||||
|
} else if !params.IgnoreAzp {
|
||||||
|
errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return errs, fmt.Errorf("has errors")
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Signing helpers ---
|
||||||
|
|
||||||
|
// SignES256 computes an ES256 signature over header.payload using a P-256 key.
|
||||||
|
// The signature is a fixed-width 64-byte raw r||s value (not ASN.1 DER).
|
||||||
|
// Each component is zero-padded to 32 bytes via [big.Int.FillBytes].
|
||||||
|
func SignES256(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) {
|
||||||
|
return signEC(header, payload, key, crypto.SHA256)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignES384 computes an ES384 signature over header.payload using a P-384 key.
|
||||||
|
// The signature is a fixed-width 96-byte raw r||s value.
|
||||||
|
func SignES384(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) {
|
||||||
|
return signEC(header, payload, key, crypto.SHA384)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignES512 computes an ES512 signature over header.payload using a P-521 key.
|
||||||
|
// The signature is a fixed-width 132-byte raw r||s value.
|
||||||
|
func SignES512(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) {
|
||||||
|
return signEC(header, payload, key, crypto.SHA512)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignRS256 computes an RS256 (PKCS#1 v1.5 + SHA-256) signature over header.payload.
|
||||||
|
func SignRS256(header, payload string, key *rsa.PrivateKey) ([]byte, error) {
|
||||||
|
digest := digestFor(crypto.SHA256, header+"."+payload)
|
||||||
|
sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, digest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("SignRS256: %w", err)
|
||||||
|
}
|
||||||
|
return sig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeToJWT appends a base64url-encoded signature to a signing input string.
|
||||||
|
func EncodeToJWT(signingInput string, signature []byte) string {
|
||||||
|
return signingInput + "." + base64.RawURLEncoding.EncodeToString(signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- EC private key JWK utilities ---
|
||||||
|
|
||||||
|
// JWK represents a private key in JSON Web Key format (EC only).
|
||||||
|
type JWK struct {
|
||||||
|
Kty string `json:"kty"`
|
||||||
|
Crv string `json:"crv"`
|
||||||
|
D string `json:"d"`
|
||||||
|
X string `json:"x"`
|
||||||
|
Y string `json:"y"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJWK parses an EC private key from a JWK struct.
|
||||||
|
func UnmarshalJWK(jwk JWK) (*ecdsa.PrivateKey, error) {
|
||||||
|
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid JWK X: %v", err)
|
||||||
|
}
|
||||||
|
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid JWK Y: %v", err)
|
||||||
|
}
|
||||||
|
d, err := base64.RawURLEncoding.DecodeString(jwk.D)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid JWK D: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ecdsa.PrivateKey{
|
||||||
|
PublicKey: ecdsa.PublicKey{
|
||||||
|
Curve: elliptic.P256(),
|
||||||
|
X: new(big.Int).SetBytes(x),
|
||||||
|
Y: new(big.Int).SetBytes(y),
|
||||||
|
},
|
||||||
|
D: new(big.Int).SetBytes(d),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Thumbprint computes the RFC 7638 JWK Thumbprint for an EC public key.
|
||||||
|
func (jwk JWK) Thumbprint() (string, error) {
|
||||||
|
data := map[string]string{
|
||||||
|
"crv": jwk.Crv,
|
||||||
|
"kty": jwk.Kty,
|
||||||
|
"x": jwk.X,
|
||||||
|
"y": jwk.Y,
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
hash := sha256.Sum256(jsonData)
|
||||||
|
return base64.RawURLEncoding.EncodeToString(hash[:]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Internal helpers ---
|
||||||
|
|
||||||
|
// algForECKey returns the JWA algorithm name and hash function for an EC public
|
||||||
|
// key, inferred from its curve. Returns an error for unsupported curves.
|
||||||
|
func algForECKey(pub *ecdsa.PublicKey) (alg string, h crypto.Hash, err error) {
|
||||||
|
switch pub.Curve {
|
||||||
|
case elliptic.P256():
|
||||||
|
return "ES256", crypto.SHA256, nil
|
||||||
|
case elliptic.P384():
|
||||||
|
return "ES384", crypto.SHA384, nil
|
||||||
|
case elliptic.P521():
|
||||||
|
return "ES512", crypto.SHA512, nil
|
||||||
|
default:
|
||||||
|
return "", 0, fmt.Errorf("unsupported EC curve: %s", pub.Curve.Params().Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// digestFor hashes data with h and returns the digest.
|
||||||
|
// Uses fixed-size stack arrays for the three supported hashes to avoid
|
||||||
|
// unnecessary heap allocation on the hot signing/verification path.
|
||||||
|
func digestFor(h crypto.Hash, data string) []byte {
|
||||||
|
switch h {
|
||||||
|
case crypto.SHA256:
|
||||||
|
d := sha256.Sum256([]byte(data))
|
||||||
|
return d[:]
|
||||||
|
case crypto.SHA384:
|
||||||
|
d := sha512.Sum384([]byte(data))
|
||||||
|
return d[:]
|
||||||
|
case crypto.SHA512:
|
||||||
|
d := sha512.Sum512([]byte(data))
|
||||||
|
return d[:]
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("bestjwt: unsupported hash %v", h))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// signEC is the shared implementation for SignES256/384/512.
|
||||||
|
func signEC(header, payload string, key *ecdsa.PrivateKey, h crypto.Hash) ([]byte, error) {
|
||||||
|
digest := digestFor(h, header+"."+payload)
|
||||||
|
r, s, err := ecdsa.Sign(rand.Reader, key, digest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("signEC: %w", err)
|
||||||
|
}
|
||||||
|
byteLen := (key.Curve.Params().BitSize + 7) / 8
|
||||||
|
out := make([]byte, 2*byteLen)
|
||||||
|
r.FillBytes(out[:byteLen])
|
||||||
|
s.FillBytes(out[byteLen:])
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ecdsaDERToRaw converts an ASN.1 DER-encoded ECDSA signature (as returned by
|
||||||
|
// [crypto.Signer]) to the fixed-width raw r||s format required by JWS.
|
||||||
|
// r and s are zero-padded to the curve's byte length via [big.Int.FillBytes].
|
||||||
|
func ecdsaDERToRaw(der []byte, curve elliptic.Curve) ([]byte, error) {
|
||||||
|
var sig struct{ R, S *big.Int }
|
||||||
|
if _, err := asn1.Unmarshal(der, &sig); err != nil {
|
||||||
|
return nil, fmt.Errorf("ecdsaDERToRaw: %w", err)
|
||||||
|
}
|
||||||
|
byteLen := (curve.Params().BitSize + 7) / 8
|
||||||
|
out := make([]byte, 2*byteLen)
|
||||||
|
sig.R.FillBytes(out[:byteLen])
|
||||||
|
sig.S.FillBytes(out[byteLen:])
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// URLBase64 is a []byte that marshals to/from raw base64url in JSON.
|
||||||
|
type URLBase64 []byte
|
||||||
|
|
||||||
|
func (s URLBase64) String() string {
|
||||||
|
return base64.RawURLEncoding.EncodeToString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s URLBase64) MarshalJSON() ([]byte, error) {
|
||||||
|
encoded := base64.RawURLEncoding.EncodeToString(s)
|
||||||
|
return json.Marshal(encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *URLBase64) UnmarshalJSON(data []byte) error {
|
||||||
|
dst, err := base64.RawURLEncoding.AppendDecode([]byte{}, data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decode base64url signature: %w", err)
|
||||||
|
}
|
||||||
|
*s = dst
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatDuration(d time.Duration) string {
|
||||||
|
if d < 0 {
|
||||||
|
d = -d
|
||||||
|
}
|
||||||
|
days := int(d / (24 * time.Hour))
|
||||||
|
d -= time.Duration(days) * 24 * time.Hour
|
||||||
|
hours := int(d / time.Hour)
|
||||||
|
d -= time.Duration(hours) * time.Hour
|
||||||
|
minutes := int(d / time.Minute)
|
||||||
|
d -= time.Duration(minutes) * time.Minute
|
||||||
|
seconds := int(d / time.Second)
|
||||||
|
|
||||||
|
var parts []string
|
||||||
|
if days > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%dd", days))
|
||||||
|
}
|
||||||
|
if hours > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%dh", hours))
|
||||||
|
}
|
||||||
|
if minutes > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%dm", minutes))
|
||||||
|
}
|
||||||
|
if seconds > 0 || len(parts) == 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%ds", seconds))
|
||||||
|
}
|
||||||
|
if seconds == 0 || len(parts) == 0 {
|
||||||
|
d -= time.Duration(seconds) * time.Second
|
||||||
|
millis := int(d / time.Millisecond)
|
||||||
|
parts = append(parts, fmt.Sprintf("%dms", millis))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(parts, " ")
|
||||||
|
}
|
||||||
478
auth/bestjwt/jwt_test.go
Normal file
478
auth/bestjwt/jwt_test.go
Normal file
@ -0,0 +1,478 @@
|
|||||||
|
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
package bestjwt_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/therootcompany/golib/auth/bestjwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AppClaims is an example custom claims type.
|
||||||
|
// Embedding StandardClaims promotes Validate — no boilerplate needed.
|
||||||
|
type AppClaims struct {
|
||||||
|
bestjwt.StandardClaims
|
||||||
|
Email string `json:"email"`
|
||||||
|
Roles []string `json:"roles"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StrictAppClaims overrides Validate to also require a non-empty Email.
|
||||||
|
// This demonstrates how to add application-specific validation on top of
|
||||||
|
// the standard OIDC checks.
|
||||||
|
type StrictAppClaims struct {
|
||||||
|
bestjwt.StandardClaims
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c StrictAppClaims) Validate(params bestjwt.ValidateParams) ([]string, error) {
|
||||||
|
errs, _ := bestjwt.ValidateStandardClaims(c.StandardClaims, params)
|
||||||
|
if c.Email == "" {
|
||||||
|
errs = append(errs, "missing email claim")
|
||||||
|
}
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return errs, fmt.Errorf("has errors")
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// goodClaims returns a valid AppClaims with all standard fields populated.
|
||||||
|
func goodClaims() AppClaims {
|
||||||
|
now := time.Now()
|
||||||
|
return AppClaims{
|
||||||
|
StandardClaims: bestjwt.StandardClaims{
|
||||||
|
Iss: "https://example.com",
|
||||||
|
Sub: "user123",
|
||||||
|
Aud: "myapp",
|
||||||
|
Exp: now.Add(time.Hour).Unix(),
|
||||||
|
Iat: now.Unix(),
|
||||||
|
AuthTime: now.Unix(),
|
||||||
|
Amr: []string{"pwd"},
|
||||||
|
Jti: "abc123",
|
||||||
|
Azp: "myapp",
|
||||||
|
Nonce: "nonce1",
|
||||||
|
},
|
||||||
|
Email: "user@example.com",
|
||||||
|
Roles: []string{"admin"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// goodParams returns ValidateParams matching the claims from goodClaims.
|
||||||
|
func goodParams() bestjwt.ValidateParams {
|
||||||
|
return bestjwt.ValidateParams{
|
||||||
|
Iss: "https://example.com",
|
||||||
|
Sub: "user123",
|
||||||
|
Aud: "myapp",
|
||||||
|
Jti: "abc123",
|
||||||
|
Nonce: "nonce1",
|
||||||
|
Azp: "myapp",
|
||||||
|
RequiredAmrs: []string{"pwd"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Round-trip tests (sign → encode → decode → verify → validate) ---
|
||||||
|
|
||||||
|
// TestRoundTripES256 exercises the most common path: ECDSA P-256 / ES256.
|
||||||
|
// Demonstrates the Decode(&claims) ergonomic — no generics at the call site,
|
||||||
|
// no type assertion needed to access Email after decoding.
|
||||||
|
func TestRoundTripES256(t *testing.T) {
|
||||||
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, err := bestjwt.NewJWSFromClaims(&claims, "key-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = jws.Sign(privKey); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if jws.Header.Alg != "ES256" {
|
||||||
|
t.Fatalf("expected ES256, got %s", jws.Header.Alg)
|
||||||
|
}
|
||||||
|
if len(jws.Signature) != 64 { // P-256: 32 bytes each for r and s
|
||||||
|
t.Fatalf("expected 64-byte signature, got %d", len(jws.Signature))
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jws.Encode()
|
||||||
|
|
||||||
|
var decoded AppClaims
|
||||||
|
jws2, err := bestjwt.Decode(token, &decoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||||
|
t.Fatal("signature verification failed")
|
||||||
|
}
|
||||||
|
if errs, err := jws2.Validate(goodParams()); err != nil {
|
||||||
|
t.Fatalf("validation failed: %v", errs)
|
||||||
|
}
|
||||||
|
// Direct field access via the local variable — no type assertion.
|
||||||
|
if decoded.Email != claims.Email {
|
||||||
|
t.Errorf("email: got %s, want %s", decoded.Email, claims.Email)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRoundTripES384 exercises ECDSA P-384 / ES384, verifying that the
|
||||||
|
// algorithm is inferred from the key's curve and that the 96-byte r||s
|
||||||
|
// signature format is produced and verified correctly.
|
||||||
|
func TestRoundTripES384(t *testing.T) {
|
||||||
|
privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, err := bestjwt.NewJWSFromClaims(&claims, "key-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = jws.Sign(privKey); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if jws.Header.Alg != "ES384" {
|
||||||
|
t.Fatalf("expected ES384, got %s", jws.Header.Alg)
|
||||||
|
}
|
||||||
|
if len(jws.Signature) != 96 { // P-384: 48 bytes each for r and s
|
||||||
|
t.Fatalf("expected 96-byte signature, got %d", len(jws.Signature))
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jws.Encode()
|
||||||
|
|
||||||
|
var decoded AppClaims
|
||||||
|
jws2, err := bestjwt.Decode(token, &decoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||||
|
t.Fatal("signature verification failed")
|
||||||
|
}
|
||||||
|
if errs, err := jws2.Validate(goodParams()); err != nil {
|
||||||
|
t.Fatalf("validation failed: %v", errs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRoundTripES512 exercises ECDSA P-521 / ES512 and the 132-byte signature.
|
||||||
|
func TestRoundTripES512(t *testing.T) {
|
||||||
|
privKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, err := bestjwt.NewJWSFromClaims(&claims, "key-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = jws.Sign(privKey); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if jws.Header.Alg != "ES512" {
|
||||||
|
t.Fatalf("expected ES512, got %s", jws.Header.Alg)
|
||||||
|
}
|
||||||
|
if len(jws.Signature) != 132 { // P-521: 66 bytes each for r and s
|
||||||
|
t.Fatalf("expected 132-byte signature, got %d", len(jws.Signature))
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jws.Encode()
|
||||||
|
|
||||||
|
var decoded AppClaims
|
||||||
|
jws2, err := bestjwt.Decode(token, &decoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||||
|
t.Fatal("signature verification failed")
|
||||||
|
}
|
||||||
|
if errs, err := jws2.Validate(goodParams()); err != nil {
|
||||||
|
t.Fatalf("validation failed: %v", errs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRoundTripRS256 exercises RSA PKCS#1 v1.5 / RS256.
|
||||||
|
func TestRoundTripRS256(t *testing.T) {
|
||||||
|
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, err := bestjwt.NewJWSFromClaims(&claims, "key-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = jws.Sign(privKey); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if jws.Header.Alg != "RS256" {
|
||||||
|
t.Fatalf("expected RS256, got %s", jws.Header.Alg)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jws.Encode()
|
||||||
|
|
||||||
|
var decoded AppClaims
|
||||||
|
jws2, err := bestjwt.Decode(token, &decoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||||
|
t.Fatal("signature verification failed")
|
||||||
|
}
|
||||||
|
if errs, err := jws2.Validate(goodParams()); err != nil {
|
||||||
|
t.Fatalf("validation failed: %v", errs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Security / negative tests ---
|
||||||
|
|
||||||
|
// TestVerifyWrongKeyType verifies that an RSA public key is rejected when
|
||||||
|
// verifying a token signed with ECDSA (alg = ES256).
|
||||||
|
func TestVerifyWrongKeyType(t *testing.T) {
|
||||||
|
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, _ := bestjwt.NewJWSFromClaims(&claims, "k")
|
||||||
|
_, _ = jws.Sign(ecKey) // alg = "ES256"
|
||||||
|
|
||||||
|
token := jws.Encode()
|
||||||
|
var decoded AppClaims
|
||||||
|
jws2, _ := bestjwt.Decode(token, &decoded)
|
||||||
|
|
||||||
|
if jws2.UnsafeVerify(&rsaKey.PublicKey) {
|
||||||
|
t.Fatal("expected verification to fail: RSA key for ES256 token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyAlgCurveMismatch verifies that a P-256 key is rejected when
|
||||||
|
// verifying a token whose header claims ES384 (signed with P-384).
|
||||||
|
// Without the curve/alg consistency check this would silently return false
|
||||||
|
// from ecdsa.Verify, but the explicit check makes the rejection reason clear.
|
||||||
|
func TestVerifyAlgCurveMismatch(t *testing.T) {
|
||||||
|
p384Key, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||||
|
p256Key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, _ := bestjwt.NewJWSFromClaims(&claims, "k")
|
||||||
|
_, _ = jws.Sign(p384Key) // alg = "ES384"
|
||||||
|
|
||||||
|
token := jws.Encode()
|
||||||
|
var decoded AppClaims
|
||||||
|
jws2, _ := bestjwt.Decode(token, &decoded)
|
||||||
|
|
||||||
|
// P-256 key must be rejected for an ES384 token.
|
||||||
|
if jws2.UnsafeVerify(&p256Key.PublicKey) {
|
||||||
|
t.Fatal("expected verification to fail: P-256 key for ES384 token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyUnknownAlg verifies that a tampered alg header is rejected.
|
||||||
|
func TestVerifyUnknownAlg(t *testing.T) {
|
||||||
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, _ := bestjwt.NewJWSFromClaims(&claims, "k")
|
||||||
|
_, _ = jws.Sign(privKey)
|
||||||
|
|
||||||
|
token := jws.Encode()
|
||||||
|
var decoded AppClaims
|
||||||
|
jws2, _ := bestjwt.Decode(token, &decoded)
|
||||||
|
|
||||||
|
// Tamper: overwrite alg in the decoded header.
|
||||||
|
jws2.Header.Alg = "none"
|
||||||
|
|
||||||
|
if jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||||
|
t.Fatal("expected verification to fail for unknown alg")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateMissingSignatureCheck verifies that Validate fails when
|
||||||
|
// UnsafeVerify was never called (Verified is false).
|
||||||
|
func TestValidateMissingSignatureCheck(t *testing.T) {
|
||||||
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, _ := bestjwt.NewJWSFromClaims(&claims, "k")
|
||||||
|
_, _ = jws.Sign(privKey)
|
||||||
|
|
||||||
|
token := jws.Encode()
|
||||||
|
var decoded AppClaims
|
||||||
|
jws2, _ := bestjwt.Decode(token, &decoded)
|
||||||
|
|
||||||
|
// Deliberately skip UnsafeVerify.
|
||||||
|
errs, err := jws2.Validate(goodParams())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected validation to fail: signature was not checked")
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, e := range errs {
|
||||||
|
if e == "signature was not checked" {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected 'signature was not checked' in errors: %v", errs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Embedded vs overridden Validate ---
|
||||||
|
|
||||||
|
// TestPromotedValidate confirms that AppClaims (which only embeds
|
||||||
|
// StandardClaims) gets the standard OIDC validation for free via promotion,
|
||||||
|
// without writing any Validate method.
|
||||||
|
func TestPromotedValidate(t *testing.T) {
|
||||||
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, _ := bestjwt.NewJWSFromClaims(&claims, "k")
|
||||||
|
_, _ = jws.Sign(privKey)
|
||||||
|
token := jws.Encode()
|
||||||
|
|
||||||
|
var decoded AppClaims
|
||||||
|
jws2, _ := bestjwt.Decode(token, &decoded)
|
||||||
|
jws2.UnsafeVerify(&privKey.PublicKey)
|
||||||
|
|
||||||
|
if errs, err := jws2.Validate(goodParams()); err != nil {
|
||||||
|
t.Fatalf("promoted Validate failed unexpectedly: %v", errs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOverriddenValidate confirms that a StrictAppClaims with an empty Email
|
||||||
|
// fails validation via its overridden Validate method.
|
||||||
|
func TestOverriddenValidate(t *testing.T) {
|
||||||
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
claims := StrictAppClaims{
|
||||||
|
StandardClaims: bestjwt.StandardClaims{
|
||||||
|
Iss: "https://example.com",
|
||||||
|
Sub: "user123",
|
||||||
|
Aud: "myapp",
|
||||||
|
Exp: now.Add(time.Hour).Unix(),
|
||||||
|
Iat: now.Unix(),
|
||||||
|
AuthTime: now.Unix(),
|
||||||
|
Amr: []string{"pwd"},
|
||||||
|
Jti: "abc123",
|
||||||
|
Azp: "myapp",
|
||||||
|
Nonce: "nonce1",
|
||||||
|
},
|
||||||
|
Email: "", // intentionally empty to trigger the override
|
||||||
|
}
|
||||||
|
|
||||||
|
jws, _ := bestjwt.NewJWSFromClaims(&claims, "k")
|
||||||
|
_, _ = jws.Sign(privKey)
|
||||||
|
token := jws.Encode()
|
||||||
|
|
||||||
|
var decoded StrictAppClaims
|
||||||
|
jws2, _ := bestjwt.Decode(token, &decoded)
|
||||||
|
jws2.UnsafeVerify(&privKey.PublicKey)
|
||||||
|
|
||||||
|
errs, err := jws2.Validate(goodParams())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected validation to fail: email is empty")
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, e := range errs {
|
||||||
|
if e == "missing email claim" {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected 'missing email claim' in errors: %v", errs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- JWKS / key management ---
|
||||||
|
|
||||||
|
// TestTypedKeys verifies that TypedKeys correctly filters a mixed
|
||||||
|
// []PublicJWK[Key] into typed slices without type assertions at use sites.
|
||||||
|
func TestTypedKeys(t *testing.T) {
|
||||||
|
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
|
||||||
|
allKeys := []bestjwt.PublicJWK[bestjwt.Key]{
|
||||||
|
{Key: &ecKey.PublicKey, KID: "ec-1", Use: "sig"},
|
||||||
|
{Key: &rsaKey.PublicKey, KID: "rsa-1", Use: "sig"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ecKeys := bestjwt.TypedKeys[*ecdsa.PublicKey](allKeys)
|
||||||
|
if len(ecKeys) != 1 || ecKeys[0].KID != "ec-1" {
|
||||||
|
t.Errorf("unexpected EC keys: %+v", ecKeys)
|
||||||
|
}
|
||||||
|
// Typed access — no assertion needed.
|
||||||
|
_ = ecKeys[0].Key.Curve
|
||||||
|
|
||||||
|
rsaKeys := bestjwt.TypedKeys[*rsa.PublicKey](allKeys)
|
||||||
|
if len(rsaKeys) != 1 || rsaKeys[0].KID != "rsa-1" {
|
||||||
|
t.Errorf("unexpected RSA keys: %+v", rsaKeys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyWithJWKSKey verifies that PublicJWK.Key can be passed directly to
|
||||||
|
// UnsafeVerify without a type assertion when using a typed PublicJWK[Key].
|
||||||
|
func TestVerifyWithJWKSKey(t *testing.T) {
|
||||||
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
jwksKey := bestjwt.PublicJWK[bestjwt.Key]{Key: &privKey.PublicKey, KID: "k1"}
|
||||||
|
|
||||||
|
claims := goodClaims()
|
||||||
|
jws, _ := bestjwt.NewJWSFromClaims(&claims, "k1")
|
||||||
|
_, _ = jws.Sign(privKey)
|
||||||
|
token := jws.Encode()
|
||||||
|
|
||||||
|
var decoded AppClaims
|
||||||
|
jws2, _ := bestjwt.Decode(token, &decoded)
|
||||||
|
|
||||||
|
// Pass PublicJWK.Key directly — Key interface satisfies the Key constraint.
|
||||||
|
if !jws2.UnsafeVerify(jwksKey.Key) {
|
||||||
|
t.Fatal("verification via PublicJWK.Key failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecodePublicJWKJSON verifies JWKS JSON parsing and TypedKeys filtering
|
||||||
|
// with real base64url-encoded key material from RFC 7517 / OIDC examples.
|
||||||
|
func TestDecodePublicJWKJSON(t *testing.T) {
|
||||||
|
jwksJSON := []byte(`{"keys":[
|
||||||
|
{"kty":"EC","crv":"P-256",
|
||||||
|
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
|
||||||
|
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
|
||||||
|
"kid":"ec-256","use":"sig"},
|
||||||
|
{"kty":"RSA",
|
||||||
|
"n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
|
||||||
|
"e":"AQAB","kid":"rsa-2048","use":"sig"}
|
||||||
|
]}`)
|
||||||
|
|
||||||
|
keys, err := bestjwt.UnmarshalPublicJWKs(jwksJSON)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(keys) != 2 {
|
||||||
|
t.Fatalf("expected 2 keys, got %d", len(keys))
|
||||||
|
}
|
||||||
|
|
||||||
|
ecKeys := bestjwt.TypedKeys[*ecdsa.PublicKey](keys)
|
||||||
|
if len(ecKeys) != 1 || ecKeys[0].KID != "ec-256" {
|
||||||
|
t.Errorf("EC key mismatch: %+v", ecKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
rsaKeys := bestjwt.TypedKeys[*rsa.PublicKey](keys)
|
||||||
|
if len(rsaKeys) != 1 || rsaKeys[0].KID != "rsa-2048" {
|
||||||
|
t.Errorf("RSA key mismatch: %+v", rsaKeys)
|
||||||
|
}
|
||||||
|
}
|
||||||
229
auth/bestjwt/pub.go
Normal file
229
auth/bestjwt/pub.go
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
package bestjwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Key is the constraint for the public key type parameter K used in PublicJWK.
|
||||||
|
//
|
||||||
|
// All standard-library asymmetric public key types satisfy this interface
|
||||||
|
// since Go 1.15: *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey.
|
||||||
|
//
|
||||||
|
// Note: crypto.PublicKey is defined as interface{} and does NOT satisfy Key.
|
||||||
|
// Use Key itself as the type argument for heterogeneous collections
|
||||||
|
// (e.g. []PublicJWK[Key]), since Key declares Equal and therefore satisfies
|
||||||
|
// its own constraint. Use [TypedKeys] to narrow to a concrete type.
|
||||||
|
type Key interface {
|
||||||
|
Equal(x crypto.PublicKey) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublicJWK wraps a parsed public key with its JWKS metadata.
|
||||||
|
//
|
||||||
|
// K is constrained to [Key], providing type-safe access to the underlying
|
||||||
|
// key without a type assertion at each use site.
|
||||||
|
//
|
||||||
|
// For a heterogeneous JWKS endpoint (mixed RSA/EC) use PublicJWK[Key].
|
||||||
|
// For a homogeneous store use the concrete type directly (e.g.
|
||||||
|
// PublicJWK[*ecdsa.PublicKey]). Use [TypedKeys] to narrow a mixed slice.
|
||||||
|
//
|
||||||
|
// Example — sign with a known key type, no assertion needed:
|
||||||
|
//
|
||||||
|
// ecKeys := bestjwt.TypedKeys[*ecdsa.PublicKey](allKeys)
|
||||||
|
// jws.UnsafeVerify(ecKeys[0].Key) // Key is *ecdsa.PublicKey directly
|
||||||
|
type PublicJWK[K Key] struct {
|
||||||
|
Key K
|
||||||
|
KID string
|
||||||
|
Use string
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublicJWKJSON is the JSON representation of a single key in a JWKS document.
|
||||||
|
type PublicJWKJSON struct {
|
||||||
|
Kty string `json:"kty"`
|
||||||
|
KID string `json:"kid"`
|
||||||
|
N string `json:"n,omitempty"` // RSA modulus
|
||||||
|
E string `json:"e,omitempty"` // RSA exponent
|
||||||
|
Crv string `json:"crv,omitempty"`
|
||||||
|
X string `json:"x,omitempty"`
|
||||||
|
Y string `json:"y,omitempty"`
|
||||||
|
Use string `json:"use,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWKsJSON is the JSON representation of a JWKS document.
|
||||||
|
type JWKsJSON struct {
|
||||||
|
Keys []PublicJWKJSON `json:"keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypedKeys filters a heterogeneous []PublicJWK[Key] slice to only those whose
|
||||||
|
// underlying key is of concrete type K, returning a typed []PublicJWK[K].
|
||||||
|
// Keys of other types are silently skipped.
|
||||||
|
//
|
||||||
|
// Example — extract only ECDSA keys from a mixed JWKS result:
|
||||||
|
//
|
||||||
|
// all, _ := bestjwt.FetchPublicJWKs(jwksURL)
|
||||||
|
// ecKeys := bestjwt.TypedKeys[*ecdsa.PublicKey](all)
|
||||||
|
// rsaKeys := bestjwt.TypedKeys[*rsa.PublicKey](all)
|
||||||
|
func TypedKeys[K Key](keys []PublicJWK[Key]) []PublicJWK[K] {
|
||||||
|
var result []PublicJWK[K]
|
||||||
|
for _, k := range keys {
|
||||||
|
if typed, ok := k.Key.(K); ok {
|
||||||
|
result = append(result, PublicJWK[K]{Key: typed, KID: k.KID, Use: k.Use})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchPublicJWKs retrieves and parses a JWKS document from url.
|
||||||
|
// Keys are returned as []PublicJWK[Key] since a JWKS endpoint may contain a
|
||||||
|
// mix of key types. Use [TypedKeys] to narrow to a concrete type.
|
||||||
|
func FetchPublicJWKs(url string) ([]PublicJWK[Key], error) {
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
resp, err := client.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return DecodePublicJWKs(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadPublicJWKs reads and parses a JWKS document from a file path.
|
||||||
|
func ReadPublicJWKs(filePath string) ([]PublicJWK[Key], error) {
|
||||||
|
file, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open JWKS file '%s': %w", filePath, err)
|
||||||
|
}
|
||||||
|
defer func() { _ = file.Close() }()
|
||||||
|
return DecodePublicJWKs(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalPublicJWKs parses a JWKS document from raw JSON bytes.
|
||||||
|
func UnmarshalPublicJWKs(data []byte) ([]PublicJWK[Key], error) {
|
||||||
|
var jwks JWKsJSON
|
||||||
|
if err := json.Unmarshal(data, &jwks); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err)
|
||||||
|
}
|
||||||
|
return DecodePublicJWKsJSON(jwks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodePublicJWKs parses a JWKS document from an [io.Reader].
|
||||||
|
func DecodePublicJWKs(r io.Reader) ([]PublicJWK[Key], error) {
|
||||||
|
var jwks JWKsJSON
|
||||||
|
if err := json.NewDecoder(r).Decode(&jwks); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err)
|
||||||
|
}
|
||||||
|
return DecodePublicJWKsJSON(jwks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys.
|
||||||
|
func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK[Key], error) {
|
||||||
|
var keys []PublicJWK[Key]
|
||||||
|
for _, jwk := range jwks.Keys {
|
||||||
|
key, err := DecodePublicJWK(jwk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse public jwk '%s': %w", jwk.KID, err)
|
||||||
|
}
|
||||||
|
keys = append(keys, *key)
|
||||||
|
}
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil, fmt.Errorf("no valid RSA or ECDSA keys found")
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodePublicJWK parses a single [PublicJWKJSON] into a PublicJWK[Key].
|
||||||
|
// Supports RSA (minimum 1024-bit) and EC (P-256, P-384, P-521) keys.
|
||||||
|
func DecodePublicJWK(jwk PublicJWKJSON) (*PublicJWK[Key], error) {
|
||||||
|
switch jwk.Kty {
|
||||||
|
case "RSA":
|
||||||
|
key, err := decodeRSAPublicJWK(jwk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse RSA key '%s': %w", jwk.KID, err)
|
||||||
|
}
|
||||||
|
if key.Size() < 128 { // 1024 bits minimum
|
||||||
|
return nil, fmt.Errorf("RSA key '%s' too small: %d bytes", jwk.KID, key.Size())
|
||||||
|
}
|
||||||
|
return &PublicJWK[Key]{Key: key, KID: jwk.KID, Use: jwk.Use}, nil
|
||||||
|
|
||||||
|
case "EC":
|
||||||
|
key, err := decodeECDSAPublicJWK(jwk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse EC key '%s': %w", jwk.KID, err)
|
||||||
|
}
|
||||||
|
return &PublicJWK[Key]{Key: key, KID: jwk.KID, Use: jwk.Use}, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported key type '%s' for kid '%s'", jwk.Kty, jwk.KID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) {
|
||||||
|
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid RSA modulus: %w", err)
|
||||||
|
}
|
||||||
|
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid RSA exponent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
eInt := new(big.Int).SetBytes(e).Int64()
|
||||||
|
if eInt > int64(^uint(0)>>1) || eInt < 0 {
|
||||||
|
return nil, fmt.Errorf("RSA exponent too large or negative")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &rsa.PublicKey{
|
||||||
|
N: new(big.Int).SetBytes(n),
|
||||||
|
E: int(eInt),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeECDSAPublicJWK(jwk PublicJWKJSON) (*ecdsa.PublicKey, error) {
|
||||||
|
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid ECDSA X: %w", err)
|
||||||
|
}
|
||||||
|
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid ECDSA Y: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var curve elliptic.Curve
|
||||||
|
switch jwk.Crv {
|
||||||
|
case "P-256":
|
||||||
|
curve = elliptic.P256()
|
||||||
|
case "P-384":
|
||||||
|
curve = elliptic.P384()
|
||||||
|
case "P-521":
|
||||||
|
curve = elliptic.P521()
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported ECDSA curve: %s", jwk.Crv)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ecdsa.PublicKey{
|
||||||
|
Curve: curve,
|
||||||
|
X: new(big.Int).SetBytes(x),
|
||||||
|
Y: new(big.Int).SetBytes(y),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user