feat(auth/embeddedjwt): add embedded-struct JWT/JWS/JWK package

Claims via embedded structs rather than generics:

- Decode(token, &claims) pattern: JSON payload unmarshaled directly into
  the caller's pre-allocated struct, stored in jws.Claims; custom fields
  accessible through the local variable without a type assertion
- StandardClaims.Validate promoted to any embedding struct via value
  receiver; override Validate on the outer struct for custom checks,
  calling ValidateStandardClaims to preserve standard OIDC validation
- Sign(crypto.Signer): algorithm set from key.Public() type switch;
  ES256 (P-256) and RS256 (PKCS#1 v1.5) supported; works with HSM/KMS
- ecdsaDERToRaw: converts ASN.1 DER output of crypto.Signer to raw r||s
- SignES256 uses FillBytes for correct zero-padded r||s (no leading-zero bug)
- UnsafeVerify(Key): dispatches on Header.Alg; ES256 and RS256 supported
- Non-generic PublicJWK with ECDSA()/RSA() typed accessor methods
  (contrast: bestjwt uses generic PublicJWK[K] + TypedKeys[K])
- JWKS fetch/parse: FetchPublicJWKs, ReadPublicJWKs, UnmarshalPublicJWKs
  for RSA and EC (P-256/384/521) keys
- 10 tests covering round trips, promoted/overridden validate, wrong key,
  wrong key type, unknown alg, JWKS accessors, and JWKS JSON parsing
This commit is contained in:
AJ ONeal 2026-03-12 17:46:04 -06:00
parent fac58cf1ad
commit 83b22dbb86
No known key found for this signature in database
4 changed files with 1179 additions and 0 deletions

3
auth/embeddedjwt/go.mod Normal file
View File

@ -0,0 +1,3 @@
module github.com/therootcompany/golib/auth/embeddedjwt
go 1.24.0

597
auth/embeddedjwt/jwt.go Normal file
View File

@ -0,0 +1,597 @@
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package embeddedjwt
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
"slices"
"strings"
"time"
)
// Claims is the interface that custom claims types must satisfy.
//
// Because [StandardClaims] implements Claims with a value receiver, any struct
// that embeds StandardClaims satisfies Claims automatically via method promotion
// — no boilerplate required. Override Validate on the outer struct to add
// application-specific checks.
type Claims interface {
Validate(params ValidateParams) ([]string, error)
}
// JWS is a decoded JSON Web Signature / JWT.
//
// Claims is stored as the [Claims] interface so that any embedded-struct type
// can be used without generics. Access the concrete type via type assertion or,
// more conveniently, via the pointer you passed to [Decode].
//
// Typical usage:
//
// var claims AppClaims
// jws, err := embeddedjwt.Decode(tokenString, &claims)
// jws.UnsafeVerify(pubKey)
// errs, err := jws.Validate(params)
// // claims.Email, claims.Roles, etc. are already populated
type JWS struct {
Protected string `json:"-"` // base64url-encoded header
Header StandardHeader `json:"header"`
Payload string `json:"-"` // base64url-encoded claims
Claims Claims `json:"claims"`
Signature URLBase64 `json:"signature"`
Verified bool `json:"-"`
}
// StandardHeader holds the standard JOSE header fields.
type StandardHeader struct {
Alg string `json:"alg"`
Kid string `json:"kid"`
Typ string `json:"typ"`
}
// StandardClaims holds the registered JWT claim names defined in RFC 7519
// and extended by OpenID Connect Core.
//
// Embed StandardClaims in your own struct to satisfy [Claims] automatically:
//
// type AppClaims struct {
// embeddedjwt.StandardClaims
// Email string `json:"email"`
// }
// // AppClaims now satisfies Claims via promoted Validate.
// // Override Validate on AppClaims to add custom checks.
type StandardClaims struct {
Iss string `json:"iss"`
Sub string `json:"sub"`
Aud string `json:"aud"`
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
AuthTime int64 `json:"auth_time"`
Nonce string `json:"nonce,omitempty"`
Amr []string `json:"amr"`
Azp string `json:"azp,omitempty"`
Jti string `json:"jti"`
}
// Validate implements [Claims] by checking all standard OIDC/JWT claim fields.
//
// This method is promoted to any struct that embeds [StandardClaims], so
// embedding structs satisfy Claims without writing any additional code.
// params.Now must be non-zero; [JWS.Validate] ensures this before delegating.
func (c StandardClaims) Validate(params ValidateParams) ([]string, error) {
return ValidateStandardClaims(c, params)
}
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
//
// claims must be a pointer to the caller's pre-allocated claims struct
// (e.g. &AppClaims{}). The JSON payload is unmarshaled directly into it,
// and the same pointer is stored in jws.Claims. This means callers can
// access custom fields through their own variable without a type assertion:
//
// var claims AppClaims
// jws, err := embeddedjwt.Decode(token, &claims)
// // claims.Email is already set; no type assertion needed
//
// The signature is not verified by Decode. Call [JWS.UnsafeVerify] first.
func Decode(tokenStr string, claims Claims) (*JWS, error) {
parts := strings.Split(tokenStr, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format")
}
var jws JWS
var sigEnc string
jws.Protected, jws.Payload, sigEnc = parts[0], parts[1], parts[2]
header, err := base64.RawURLEncoding.DecodeString(jws.Protected)
if err != nil {
return nil, fmt.Errorf("invalid header encoding: %v", err)
}
if err := json.Unmarshal(header, &jws.Header); err != nil {
return nil, fmt.Errorf("invalid header JSON: %v", err)
}
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return nil, fmt.Errorf("invalid claims encoding: %v", err)
}
// Unmarshal into the concrete type behind the Claims interface.
// json.Unmarshal receives the concrete pointer via reflection.
if err := json.Unmarshal(payload, claims); err != nil {
return nil, fmt.Errorf("invalid claims JSON: %v", err)
}
if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil {
return nil, fmt.Errorf("invalid signature encoding: %v", err)
}
jws.Claims = claims
return &jws, nil
}
// NewJWSFromClaims creates an unsigned JWS from the provided claims.
//
// kid identifies the signing key. The "alg" header field is set automatically
// when [JWS.Sign] is called. Call [JWS.Encode] to produce a compact JWT string
// after signing.
func NewJWSFromClaims(claims Claims, kid string) (*JWS, error) {
var jws JWS
jws.Header = StandardHeader{
// Alg is set by Sign based on the key type.
Kid: kid,
Typ: "JWT",
}
headerJSON, _ := json.Marshal(jws.Header)
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
claimsJSON, _ := json.Marshal(claims)
jws.Payload = base64.RawURLEncoding.EncodeToString(claimsJSON)
jws.Claims = claims
return &jws, nil
}
// Sign signs the JWS in-place using the provided [crypto.Signer].
// It sets the "alg" header field based on the public key type and re-encodes
// the protected header before signing, so the signed input is always consistent.
//
// Supported public key types (via Signer.Public()):
// - *ecdsa.PublicKey → ES256 (ECDSA P-256, raw r||s)
// - *rsa.PublicKey → RS256 (PKCS#1 v1.5 + SHA-256)
//
// Because the parameter is [crypto.Signer] rather than a concrete key type,
// hardware-backed keys (HSM, OS keychain, etc.) work without modification.
func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) {
switch pub := key.Public().(type) {
case *ecdsa.PublicKey:
jws.Header.Alg = "ES256"
headerJSON, _ := json.Marshal(jws.Header)
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
hash := sha256.Sum256([]byte(jws.Protected + "." + jws.Payload))
// crypto.Signer returns ASN.1 DER for ECDSA; convert to raw r||s for JWS.
derSig, err := key.Sign(rand.Reader, hash[:], crypto.SHA256)
if err != nil {
return nil, fmt.Errorf("Sign ES256: %w", err)
}
jws.Signature, err = ecdsaDERToRaw(derSig, pub.Curve)
return jws.Signature, err
case *rsa.PublicKey:
jws.Header.Alg = "RS256"
headerJSON, _ := json.Marshal(jws.Header)
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
hash := sha256.Sum256([]byte(jws.Protected + "." + jws.Payload))
// crypto.Signer returns raw PKCS#1 v1.5 bytes for RSA; use directly.
var err error
jws.Signature, err = key.Sign(rand.Reader, hash[:], crypto.SHA256)
return jws.Signature, err
default:
return nil, fmt.Errorf("Sign: unsupported public key type %T (supported: *ecdsa.PublicKey, *rsa.PublicKey)", key.Public())
}
}
// Encode produces the compact JWT string (header.payload.signature).
func (jws JWS) Encode() string {
sigEnc := base64.RawURLEncoding.EncodeToString(jws.Signature)
return jws.Protected + "." + jws.Payload + "." + sigEnc
}
// UnsafeVerify checks the signature using the algorithm in the JWT header and
// sets jws.Verified on success. It only checks the signature — use
// [JWS.Validate] to check claim values.
//
// pub must be of the concrete type matching the header alg (e.g.
// *ecdsa.PublicKey for ES256). Callers can pass PublicJWK.Key directly
// without a type assertion.
//
// Currently supported: ES256, RS256.
func (jws *JWS) UnsafeVerify(pub Key) bool {
signingInput := jws.Protected + "." + jws.Payload
hash := sha256.Sum256([]byte(signingInput))
switch jws.Header.Alg {
case "ES256":
k, ok := pub.(*ecdsa.PublicKey)
if !ok || len(jws.Signature) != 64 {
jws.Verified = false
return false
}
r := new(big.Int).SetBytes(jws.Signature[:32])
s := new(big.Int).SetBytes(jws.Signature[32:])
jws.Verified = ecdsa.Verify(k, hash[:], r, s)
case "RS256":
k, ok := pub.(*rsa.PublicKey)
if !ok {
jws.Verified = false
return false
}
jws.Verified = rsa.VerifyPKCS1v15(k, crypto.SHA256, hash[:], jws.Signature) == nil
default:
jws.Verified = false
}
return jws.Verified
}
// Validate sets params.Now if zero, then delegates to jws.Claims.Validate and
// additionally enforces that the signature was verified (unless params.IgnoreSig).
func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
if params.Now.IsZero() {
params.Now = time.Now()
}
errs, _ := jws.Claims.Validate(params)
if !params.IgnoreSig && !jws.Verified {
errs = append(errs, "signature was not checked")
}
if len(errs) > 0 {
timeInfo := fmt.Sprintf("info: server time is %s", params.Now.Format("2006-01-02 15:04:05 MST"))
if loc, err := time.LoadLocation("Local"); err == nil {
timeInfo += fmt.Sprintf(" %s", loc)
}
errs = append(errs, timeInfo)
return errs, fmt.Errorf("has errors")
}
return nil, nil
}
// ValidateParams holds validation configuration.
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
type ValidateParams struct {
Now time.Time
IgnoreIss bool
Iss string
IgnoreSub bool
Sub string
IgnoreAud bool
Aud string
IgnoreExp bool
IgnoreJti bool
Jti string
IgnoreIat bool
IgnoreAuthTime bool
MaxAge time.Duration
IgnoreNonce bool
Nonce string
IgnoreAmr bool
RequiredAmrs []string
IgnoreAzp bool
Azp string
IgnoreSig bool
}
// ValidateStandardClaims checks the registered JWT/OIDC claim fields.
//
// This is called by [StandardClaims.Validate] and is exported so that
// custom claims types can call it from an overriding Validate method:
//
// func (c AppClaims) Validate(params embeddedjwt.ValidateParams) ([]string, error) {
// errs, _ := embeddedjwt.ValidateStandardClaims(c.StandardClaims, params)
// if c.Email == "" {
// errs = append(errs, "missing email claim")
// }
// if len(errs) > 0 {
// return errs, fmt.Errorf("has errors")
// }
// return nil, nil
// }
//
// params.Now must be non-zero; [JWS.Validate] ensures this before delegating.
func ValidateStandardClaims(claims StandardClaims, params ValidateParams) ([]string, error) {
var errs []string
// Required to exist and match
if len(params.Iss) > 0 || !params.IgnoreIss {
if len(claims.Iss) == 0 {
errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
} else if claims.Iss != params.Iss {
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss))
}
}
// Required to exist, optional match
if len(claims.Sub) == 0 {
if !params.IgnoreSub {
errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)")
}
} else if len(params.Sub) > 0 {
if params.Sub != claims.Sub {
errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, params.Sub))
}
}
// Required to exist and match
if len(params.Aud) > 0 || !params.IgnoreAud {
if len(claims.Aud) == 0 {
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
} else if claims.Aud != params.Aud {
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, params.Aud))
}
}
// Required to exist and not be in the past
if !params.IgnoreExp {
if claims.Exp <= 0 {
errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)")
} else if claims.Exp < params.Now.Unix() {
duration := time.Since(time.Unix(claims.Exp, 0))
expTime := time.Unix(claims.Exp, 0).Format("2006-01-02 15:04:05 MST")
errs = append(errs, fmt.Sprintf("token expired %s ago (%s)", formatDuration(duration), expTime))
}
}
// Required to exist and not be in the future
if !params.IgnoreIat {
if claims.Iat <= 0 {
errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)")
} else if claims.Iat > params.Now.Unix() {
duration := time.Unix(claims.Iat, 0).Sub(params.Now)
iatTime := time.Unix(claims.Iat, 0).Format("2006-01-02 15:04:05 MST")
errs = append(errs, fmt.Sprintf("'iat' (issued at) is %s in the future (%s)", formatDuration(duration), iatTime))
}
}
// Should exist, in the past, with optional max age
if params.MaxAge > 0 || !params.IgnoreAuthTime {
if claims.AuthTime == 0 {
errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
} else {
authTime := time.Unix(claims.AuthTime, 0)
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
age := params.Now.Sub(authTime)
diff := age - params.MaxAge
if claims.AuthTime > params.Now.Unix() {
fromNow := time.Unix(claims.AuthTime, 0).Sub(params.Now)
errs = append(errs, fmt.Sprintf(
"'auth_time' of %s is %s in the future (server time %s)",
authTimeStr, formatDuration(fromNow), params.Now.Format("2006-01-02 15:04:05 MST")),
)
} else if params.MaxAge > 0 && age > params.MaxAge {
errs = append(errs, fmt.Sprintf(
"'auth_time' of %s is %s old, exceeding max age %s by %s",
authTimeStr, formatDuration(age), formatDuration(params.MaxAge), formatDuration(diff)),
)
}
}
}
// Optional exact match
if params.Jti != claims.Jti {
if len(params.Jti) > 0 {
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti))
} else if !params.IgnoreJti {
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
}
}
// Optional exact match
if params.Nonce != claims.Nonce {
if len(params.Nonce) > 0 {
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, params.Nonce))
} else if !params.IgnoreNonce {
errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce))
}
}
// Should exist, optional required-set check
if !params.IgnoreAmr {
if len(claims.Amr) == 0 {
errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)")
} else if len(params.RequiredAmrs) > 0 {
for _, required := range params.RequiredAmrs {
if !slices.Contains(claims.Amr, required) {
errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required))
}
}
}
}
// Optional, match if present
if params.Azp != claims.Azp {
if len(params.Azp) > 0 {
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, params.Azp))
} else if !params.IgnoreAzp {
errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp))
}
}
if len(errs) > 0 {
return errs, fmt.Errorf("has errors")
}
return nil, nil
}
// --- Private key / signing helpers ---
// JWK represents a private key in JSON Web Key format (EC only).
type JWK struct {
Kty string `json:"kty"`
Crv string `json:"crv"`
D string `json:"d"`
X string `json:"x"`
Y string `json:"y"`
}
// UnmarshalJWK parses an EC private key from a JWK struct.
func UnmarshalJWK(jwk JWK) (*ecdsa.PrivateKey, error) {
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("invalid JWK X: %v", err)
}
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("invalid JWK Y: %v", err)
}
d, err := base64.RawURLEncoding.DecodeString(jwk.D)
if err != nil {
return nil, fmt.Errorf("invalid JWK D: %v", err)
}
return &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
},
D: new(big.Int).SetBytes(d),
}, nil
}
// Thumbprint computes the RFC 7638 JWK Thumbprint for an EC public key.
func (jwk JWK) Thumbprint() (string, error) {
data := map[string]string{
"crv": jwk.Crv,
"kty": jwk.Kty,
"x": jwk.X,
"y": jwk.Y,
}
jsonData, err := json.Marshal(data)
if err != nil {
return "", err
}
hash := sha256.Sum256(jsonData)
return base64.RawURLEncoding.EncodeToString(hash[:]), nil
}
// SignES256 computes an ES256 signature over header.payload.
// The signature is a fixed-width raw r||s value (not ASN.1 DER).
// r and s are zero-padded to the curve's byte length via [big.Int.FillBytes].
func SignES256(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) {
hash := sha256.Sum256([]byte(header + "." + payload))
r, s, err := ecdsa.Sign(rand.Reader, key, hash[:])
if err != nil {
return nil, fmt.Errorf("SignES256: %w", err)
}
byteLen := (key.Curve.Params().BitSize + 7) / 8
out := make([]byte, 2*byteLen)
r.FillBytes(out[:byteLen])
s.FillBytes(out[byteLen:])
return out, nil
}
// SignRS256 computes an RS256 (PKCS#1 v1.5 + SHA-256) signature over header.payload.
func SignRS256(header, payload string, key *rsa.PrivateKey) ([]byte, error) {
hash := sha256.Sum256([]byte(header + "." + payload))
sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, hash[:])
if err != nil {
return nil, fmt.Errorf("SignRS256: %w", err)
}
return sig, nil
}
// ecdsaDERToRaw converts an ASN.1 DER ECDSA signature (as returned by
// [crypto.Signer]) to the fixed-width r||s format required by JWS.
func ecdsaDERToRaw(der []byte, curve elliptic.Curve) ([]byte, error) {
var sig struct{ R, S *big.Int }
if _, err := asn1.Unmarshal(der, &sig); err != nil {
return nil, fmt.Errorf("ecdsaDERToRaw: %w", err)
}
byteLen := (curve.Params().BitSize + 7) / 8
out := make([]byte, 2*byteLen)
sig.R.FillBytes(out[:byteLen])
sig.S.FillBytes(out[byteLen:])
return out, nil
}
// EncodeToJWT appends a base64url-encoded signature to a signing input.
func EncodeToJWT(signingInput string, signature []byte) string {
sigEnc := base64.RawURLEncoding.EncodeToString(signature)
return signingInput + "." + sigEnc
}
// URLBase64 is a []byte that marshals to/from raw base64url in JSON.
type URLBase64 []byte
func (s URLBase64) String() string {
return base64.RawURLEncoding.EncodeToString(s)
}
func (s URLBase64) MarshalJSON() ([]byte, error) {
encoded := base64.RawURLEncoding.EncodeToString(s)
return json.Marshal(encoded)
}
func (s *URLBase64) UnmarshalJSON(data []byte) error {
dst, err := base64.RawURLEncoding.AppendDecode([]byte{}, data)
if err != nil {
return fmt.Errorf("decode base64url signature: %w", err)
}
*s = dst
return nil
}
func formatDuration(d time.Duration) string {
if d < 0 {
d = -d
}
days := int(d / (24 * time.Hour))
d -= time.Duration(days) * 24 * time.Hour
hours := int(d / time.Hour)
d -= time.Duration(hours) * time.Hour
minutes := int(d / time.Minute)
d -= time.Duration(minutes) * time.Minute
seconds := int(d / time.Second)
var parts []string
if days > 0 {
parts = append(parts, fmt.Sprintf("%dd", days))
}
if hours > 0 {
parts = append(parts, fmt.Sprintf("%dh", hours))
}
if minutes > 0 {
parts = append(parts, fmt.Sprintf("%dm", minutes))
}
if seconds > 0 || len(parts) == 0 {
parts = append(parts, fmt.Sprintf("%ds", seconds))
}
if seconds == 0 || len(parts) == 0 {
d -= time.Duration(seconds) * time.Second
millis := int(d / time.Millisecond)
parts = append(parts, fmt.Sprintf("%dms", millis))
}
return strings.Join(parts, " ")
}

View File

@ -0,0 +1,361 @@
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package embeddedjwt_test
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"fmt"
"testing"
"time"
"github.com/therootcompany/golib/auth/embeddedjwt"
)
// AppClaims embeds StandardClaims and gains Validate via promotion.
// No Validate override — demonstrates zero-boilerplate satisfaction of Claims.
type AppClaims struct {
embeddedjwt.StandardClaims
Email string `json:"email"`
Roles []string `json:"roles"`
}
// StrictAppClaims overrides Validate to also require a non-empty Email,
// demonstrating how to layer application-specific checks on top of the
// promoted standard validation.
type StrictAppClaims struct {
embeddedjwt.StandardClaims
Email string `json:"email"`
}
func (c StrictAppClaims) Validate(params embeddedjwt.ValidateParams) ([]string, error) {
errs, _ := embeddedjwt.ValidateStandardClaims(c.StandardClaims, params)
if c.Email == "" {
errs = append(errs, "missing email claim")
}
if len(errs) > 0 {
return errs, fmt.Errorf("has errors")
}
return nil, nil
}
func goodClaims() AppClaims {
now := time.Now()
return AppClaims{
StandardClaims: embeddedjwt.StandardClaims{
Iss: "https://example.com",
Sub: "user123",
Aud: "myapp",
Exp: now.Add(time.Hour).Unix(),
Iat: now.Unix(),
AuthTime: now.Unix(),
Amr: []string{"pwd"},
Jti: "abc123",
Azp: "myapp",
Nonce: "nonce1",
},
Email: "user@example.com",
Roles: []string{"admin"},
}
}
func goodParams() embeddedjwt.ValidateParams {
return embeddedjwt.ValidateParams{
Iss: "https://example.com",
Sub: "user123",
Aud: "myapp",
Jti: "abc123",
Nonce: "nonce1",
Azp: "myapp",
RequiredAmrs: []string{"pwd"},
}
}
// TestRoundTrip is the primary happy path: sign, encode, decode, verify,
// validate — and confirm that custom fields are accessible without a type
// assertion via the local &claims pointer.
func TestRoundTrip(t *testing.T) {
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
jws, err := embeddedjwt.NewJWSFromClaims(&claims, "key-1")
if err != nil {
t.Fatal(err)
}
if _, err = jws.Sign(privKey); err != nil {
t.Fatal(err)
}
if jws.Header.Alg != "ES256" {
t.Fatalf("expected ES256, got %s", jws.Header.Alg)
}
token := jws.Encode()
var decoded AppClaims
jws2, err := embeddedjwt.Decode(token, &decoded)
if err != nil {
t.Fatal(err)
}
if !jws2.UnsafeVerify(&privKey.PublicKey) {
t.Fatal("signature verification failed")
}
if errs, err := jws2.Validate(goodParams()); err != nil {
t.Fatalf("validation failed: %v", errs)
}
// Access custom field directly — no type assertion on jws2.Claims needed.
if decoded.Email != claims.Email {
t.Errorf("email: got %s, want %s", decoded.Email, claims.Email)
}
}
// TestRoundTripRS256 exercises RSA PKCS#1 v1.5 / RS256.
func TestRoundTripRS256(t *testing.T) {
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
jws, err := embeddedjwt.NewJWSFromClaims(&claims, "key-1")
if err != nil {
t.Fatal(err)
}
if _, err = jws.Sign(privKey); err != nil {
t.Fatal(err)
}
if jws.Header.Alg != "RS256" {
t.Fatalf("expected RS256, got %s", jws.Header.Alg)
}
token := jws.Encode()
var decoded AppClaims
jws2, err := embeddedjwt.Decode(token, &decoded)
if err != nil {
t.Fatal(err)
}
if !jws2.UnsafeVerify(&privKey.PublicKey) {
t.Fatal("signature verification failed")
}
if errs, err := jws2.Validate(goodParams()); err != nil {
t.Fatalf("validation failed: %v", errs)
}
}
// TestPromotedValidate confirms that AppClaims satisfies Claims via the
// promoted Validate from embedded StandardClaims, with no method written.
func TestPromotedValidate(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
claims := goodClaims()
jws, _ := embeddedjwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(privKey)
token := jws.Encode()
var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
jws2.UnsafeVerify(&privKey.PublicKey)
if errs, err := jws2.Validate(goodParams()); err != nil {
t.Fatalf("promoted Validate failed unexpectedly: %v", errs)
}
}
// TestOverriddenValidate confirms that StrictAppClaims.Validate is called
// (not the promoted one) and that the missing Email is caught.
func TestOverriddenValidate(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
now := time.Now()
claims := StrictAppClaims{
StandardClaims: embeddedjwt.StandardClaims{
Iss: "https://example.com",
Sub: "user123",
Aud: "myapp",
Exp: now.Add(time.Hour).Unix(),
Iat: now.Unix(),
AuthTime: now.Unix(),
Amr: []string{"pwd"},
Jti: "abc123",
Azp: "myapp",
Nonce: "nonce1",
},
Email: "", // intentionally empty
}
jws, _ := embeddedjwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(privKey)
token := jws.Encode()
var decoded StrictAppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
jws2.UnsafeVerify(&privKey.PublicKey)
errs, err := jws2.Validate(goodParams())
if err == nil {
t.Fatal("expected validation to fail: email is empty")
}
found := false
for _, e := range errs {
if e == "missing email claim" {
found = true
}
}
if !found {
t.Fatalf("expected 'missing email claim' in errors: %v", errs)
}
}
// TestUnsafeVerifyWrongKey confirms that a different key's public key does
// not verify the signature.
func TestUnsafeVerifyWrongKey(t *testing.T) {
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
claims := goodClaims()
jws, _ := embeddedjwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(signingKey)
token := jws.Encode()
var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
if jws2.UnsafeVerify(&wrongKey.PublicKey) {
t.Fatal("expected verification to fail with wrong key")
}
}
// TestVerifyWrongKeyType confirms that an RSA key is rejected for an ES256 token.
func TestVerifyWrongKeyType(t *testing.T) {
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
claims := goodClaims()
jws, _ := embeddedjwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(ecKey)
token := jws.Encode()
var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
if jws2.UnsafeVerify(&rsaKey.PublicKey) {
t.Fatal("expected verification to fail: RSA key for ES256 token")
}
}
// TestVerifyUnknownAlg confirms that a tampered alg header is rejected.
func TestVerifyUnknownAlg(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
claims := goodClaims()
jws, _ := embeddedjwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(privKey)
token := jws.Encode()
var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
jws2.Header.Alg = "none"
if jws2.UnsafeVerify(&privKey.PublicKey) {
t.Fatal("expected verification to fail for unknown alg")
}
}
// TestVerifyWithJWKSKey confirms that PublicJWK.Key can be passed directly to
// UnsafeVerify without a type assertion.
func TestVerifyWithJWKSKey(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
jwksKey := embeddedjwt.PublicJWK{Key: &privKey.PublicKey, KID: "k1"}
claims := goodClaims()
jws, _ := embeddedjwt.NewJWSFromClaims(&claims, "k1")
_, _ = jws.Sign(privKey)
token := jws.Encode()
var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
if !jws2.UnsafeVerify(jwksKey.Key) {
t.Fatal("verification via PublicJWK.Key failed")
}
}
// TestPublicJWKAccessors confirms the ECDSA() and RSA() typed accessor methods.
func TestPublicJWKAccessors(t *testing.T) {
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
ecJWK := embeddedjwt.PublicJWK{Key: &ecKey.PublicKey, KID: "ec-1"}
rsaJWK := embeddedjwt.PublicJWK{Key: &rsaKey.PublicKey, KID: "rsa-1"}
if k, ok := ecJWK.ECDSA(); !ok || k == nil {
t.Error("expected ECDSA() to succeed for EC key")
}
if _, ok := ecJWK.RSA(); ok {
t.Error("expected RSA() to fail for EC key")
}
if k, ok := rsaJWK.RSA(); !ok || k == nil {
t.Error("expected RSA() to succeed for RSA key")
}
if _, ok := rsaJWK.ECDSA(); ok {
t.Error("expected ECDSA() to fail for RSA key")
}
}
// TestDecodePublicJWKJSON verifies JWKS JSON parsing and the typed accessors
// with real base64url-encoded key material from RFC 7517 / OIDC examples.
func TestDecodePublicJWKJSON(t *testing.T) {
jwksJSON := []byte(`{"keys":[
{"kty":"EC","crv":"P-256",
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
"kid":"ec-256","use":"sig"},
{"kty":"RSA",
"n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e":"AQAB","kid":"rsa-2048","use":"sig"}
]}`)
keys, err := embeddedjwt.UnmarshalPublicJWKs(jwksJSON)
if err != nil {
t.Fatal(err)
}
if len(keys) != 2 {
t.Fatalf("expected 2 keys, got %d", len(keys))
}
var ecCount, rsaCount int
for _, k := range keys {
if _, ok := k.ECDSA(); ok {
ecCount++
if k.KID != "ec-256" {
t.Errorf("unexpected EC kid: %s", k.KID)
}
}
if _, ok := k.RSA(); ok {
rsaCount++
if k.KID != "rsa-2048" {
t.Errorf("unexpected RSA kid: %s", k.KID)
}
}
}
if ecCount != 1 {
t.Errorf("expected 1 EC key, got %d", ecCount)
}
if rsaCount != 1 {
t.Errorf("expected 1 RSA key, got %d", rsaCount)
}
}

218
auth/embeddedjwt/pub.go Normal file
View File

@ -0,0 +1,218 @@
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package embeddedjwt
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"os"
"time"
)
// Key is the interface satisfied by all standard-library asymmetric public key
// types since Go 1.15: *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey.
//
// It is used as the field type in [PublicJWK] so that a single slice can hold
// mixed key types, and as the parameter type of [JWS.UnsafeVerify] so that
// callers can pass PublicJWK.Key directly without a type assertion.
type Key interface {
Equal(x crypto.PublicKey) bool
}
// PublicJWK wraps a parsed public key with its JWKS metadata.
//
// Key is stored as the [Key] interface to allow mixed RSA/EC slices from a
// real JWKS endpoint. Use the [PublicJWK.ECDSA] and [PublicJWK.RSA] accessor
// methods to obtain a typed key when the algorithm is known.
//
// Example:
//
// keys, _ := embeddedjwt.FetchPublicJWKs(jwksURL)
// for _, k := range keys {
// if ec, ok := k.ECDSA(); ok {
// jws.UnsafeVerify(ec)
// }
// }
type PublicJWK struct {
Key Key
KID string
Use string
}
// ECDSA returns the underlying key as *ecdsa.PublicKey, or (nil, false).
func (p PublicJWK) ECDSA() (*ecdsa.PublicKey, bool) {
k, ok := p.Key.(*ecdsa.PublicKey)
return k, ok
}
// RSA returns the underlying key as *rsa.PublicKey, or (nil, false).
func (p PublicJWK) RSA() (*rsa.PublicKey, bool) {
k, ok := p.Key.(*rsa.PublicKey)
return k, ok
}
// PublicJWKJSON is the JSON representation of a single key in a JWKS document.
type PublicJWKJSON struct {
Kty string `json:"kty"`
KID string `json:"kid"`
N string `json:"n,omitempty"` // RSA modulus
E string `json:"e,omitempty"` // RSA exponent
Crv string `json:"crv,omitempty"`
X string `json:"x,omitempty"`
Y string `json:"y,omitempty"`
Use string `json:"use,omitempty"`
}
// JWKsJSON is the JSON representation of a JWKS document.
type JWKsJSON struct {
Keys []PublicJWKJSON `json:"keys"`
}
// FetchPublicJWKs retrieves and parses a JWKS document from url.
func FetchPublicJWKs(url string) ([]PublicJWK, error) {
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(url)
if err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
return DecodePublicJWKs(resp.Body)
}
// ReadPublicJWKs reads and parses a JWKS document from a file path.
func ReadPublicJWKs(filePath string) ([]PublicJWK, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("failed to open JWKS file '%s': %w", filePath, err)
}
defer func() { _ = file.Close() }()
return DecodePublicJWKs(file)
}
// UnmarshalPublicJWKs parses a JWKS document from raw JSON bytes.
func UnmarshalPublicJWKs(data []byte) ([]PublicJWK, error) {
var jwks JWKsJSON
if err := json.Unmarshal(data, &jwks); err != nil {
return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err)
}
return DecodePublicJWKsJSON(jwks)
}
// DecodePublicJWKs parses a JWKS document from an [io.Reader].
func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) {
var jwks JWKsJSON
if err := json.NewDecoder(r).Decode(&jwks); err != nil {
return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err)
}
return DecodePublicJWKsJSON(jwks)
}
// DecodePublicJWKsJSON converts a parsed [JWKsJSON] into public keys.
func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
var keys []PublicJWK
for _, jwk := range jwks.Keys {
key, err := DecodePublicJWK(jwk)
if err != nil {
return nil, fmt.Errorf("failed to parse public jwk '%s': %w", jwk.KID, err)
}
keys = append(keys, *key)
}
if len(keys) == 0 {
return nil, fmt.Errorf("no valid RSA or ECDSA keys found")
}
return keys, nil
}
// DecodePublicJWK parses a single [PublicJWKJSON] into a PublicJWK.
// Supports RSA (minimum 1024-bit) and EC (P-256, P-384, P-521) keys.
func DecodePublicJWK(jwk PublicJWKJSON) (*PublicJWK, error) {
switch jwk.Kty {
case "RSA":
key, err := decodeRSAPublicJWK(jwk)
if err != nil {
return nil, fmt.Errorf("failed to parse RSA key '%s': %w", jwk.KID, err)
}
if key.Size() < 128 { // 1024 bits minimum
return nil, fmt.Errorf("RSA key '%s' too small: %d bytes", jwk.KID, key.Size())
}
return &PublicJWK{Key: key, KID: jwk.KID, Use: jwk.Use}, nil
case "EC":
key, err := decodeECDSAPublicJWK(jwk)
if err != nil {
return nil, fmt.Errorf("failed to parse EC key '%s': %w", jwk.KID, err)
}
return &PublicJWK{Key: key, KID: jwk.KID, Use: jwk.Use}, nil
default:
return nil, fmt.Errorf("unsupported key type '%s' for kid '%s'", jwk.Kty, jwk.KID)
}
}
func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) {
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, fmt.Errorf("invalid RSA modulus: %w", err)
}
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, fmt.Errorf("invalid RSA exponent: %w", err)
}
eInt := new(big.Int).SetBytes(e).Int64()
if eInt > int64(^uint(0)>>1) || eInt < 0 {
return nil, fmt.Errorf("RSA exponent too large or negative")
}
return &rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: int(eInt),
}, nil
}
func decodeECDSAPublicJWK(jwk PublicJWKJSON) (*ecdsa.PublicKey, error) {
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("invalid ECDSA X: %w", err)
}
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("invalid ECDSA Y: %w", err)
}
var curve elliptic.Curve
switch jwk.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported ECDSA curve: %s", jwk.Crv)
}
return &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
}, nil
}