feat(auth/ajwt): add first-principles JWT/JWS/JWK package

Design goals from first principles:

- JWS holds only parsed structure (header, payload, sig) — no Claims
  interface, no Verified flag. Removes footguns from the simpler packages.

- Issuer owns key management and verification. Verify does key lookup by
  kid, sig verification, and iss claim check — in that order, so sig is
  always authenticated before any payload data is trusted.

- ValidateParams is a stable config object with Validate(StandardClaims,
  time.Time) as a method. Time is passed at the call site, not stored in
  the params struct, so the same config object can be reused across requests.

- UnmarshalClaims(v any) accepts any type — no Claims interface to
  implement. Custom validation is a plain function call, not a method
  satisfying an interface.

- Sign uses crypto.Signer, supporting ES256/ES384/ES512 (ECDSA), RS256
  (RSA PKCS#1 v1.5), and EdDSA (Ed25519, RFC 8037).

- PublicJWK uses crypto.PublicKey (not generics) since JWKS returns
  heterogeneous key types at runtime. Typed accessors ECDSA(), RSA(), and
  EdDSA() replace TypedKeys[K] filtering.

- JWKS parsing handles kty: "EC", "RSA", and "OKP" (Ed25519).

10 tests: ES256/RS256/EdDSA round trips, custom validation, wrong key,
unknown kid, iss mismatch, tampered alg, PublicJWK accessors, JWKS JSON.
This commit is contained in:
AJ ONeal 2026-03-12 19:40:18 -06:00
parent 8ac858128e
commit 1f0b36fc6d
No known key found for this signature in database
4 changed files with 1289 additions and 0 deletions

3
auth/ajwt/go.mod Normal file
View File

@ -0,0 +1,3 @@
module github.com/therootcompany/golib/auth/ajwt
go 1.24.0

641
auth/ajwt/jwt.go Normal file
View File

@ -0,0 +1,641 @@
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
// Package ajwt is a lightweight JWT/JWS/JWK library designed from first
// principles:
//
// - [JWS] is a parsed structure only — no Claims interface, no Verified flag.
// - [Issuer] owns key management and signature verification, centralizing
// the key lookup → sig verify → iss check sequence.
// - [ValidateParams] is a stable config object; time is passed at the call
// site so the same params can be reused across requests.
// - [JWS.UnmarshalClaims] accepts any type — no interface to implement.
// - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384),
// ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037).
//
// Typical usage:
//
// // At startup:
// iss := ajwt.NewIssuer("https://accounts.example.com")
// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true}
// if err := iss.FetchKeys(ctx); err != nil { ... }
//
// // Per request:
// jws, err := ajwt.Decode(tokenStr)
// if err := iss.Verify(jws); err != nil { ... } // sig + iss check
// var claims AppClaims
// if err := jws.UnmarshalClaims(&claims); err != nil { ... }
// if errs, err := iss.Params.Validate(claims.StandardClaims, time.Now()); err != nil { ... }
package ajwt
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
"net/http"
"slices"
"strings"
"time"
)
// JWS is a decoded JSON Web Signature / JWT.
//
// It holds only the parsed structure — header, raw base64url fields, and
// decoded signature bytes. It carries no Claims interface and no Verified flag;
// use [Issuer.Verify] to authenticate the token and [JWS.UnmarshalClaims] to
// decode the payload into a typed struct.
type JWS struct {
Protected string // base64url-encoded header
Header StandardHeader
Payload string // base64url-encoded claims
Signature []byte
}
// StandardHeader holds the standard JOSE header fields.
type StandardHeader struct {
Alg string `json:"alg"`
Kid string `json:"kid"`
Typ string `json:"typ"`
}
// StandardClaims holds the registered JWT claim names defined in RFC 7519
// and extended by OpenID Connect Core.
//
// Embed StandardClaims in your own claims struct:
//
// type AppClaims struct {
// ajwt.StandardClaims
// Email string `json:"email"`
// }
type StandardClaims struct {
Iss string `json:"iss"`
Sub string `json:"sub"`
Aud string `json:"aud"`
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
AuthTime int64 `json:"auth_time"`
Nonce string `json:"nonce,omitempty"`
Amr []string `json:"amr"`
Azp string `json:"azp,omitempty"`
Jti string `json:"jti"`
}
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
//
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
// [Issuer.Verify] to populate a typed claims struct.
func Decode(tokenStr string) (*JWS, error) {
parts := strings.Split(tokenStr, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format")
}
var jws JWS
jws.Protected, jws.Payload = parts[0], parts[1]
header, err := base64.RawURLEncoding.DecodeString(jws.Protected)
if err != nil {
return nil, fmt.Errorf("invalid header encoding: %v", err)
}
if err := json.Unmarshal(header, &jws.Header); err != nil {
return nil, fmt.Errorf("invalid header JSON: %v", err)
}
jws.Signature, err = base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid signature encoding: %v", err)
}
return &jws, nil
}
// UnmarshalClaims decodes the JWT payload into v.
//
// v must be a pointer to a struct (e.g. *AppClaims). Always call
// [Issuer.Verify] before UnmarshalClaims to ensure the signature is
// authenticated before trusting the payload.
func (jws *JWS) UnmarshalClaims(v any) error {
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return fmt.Errorf("invalid claims encoding: %v", err)
}
if err := json.Unmarshal(payload, v); err != nil {
return fmt.Errorf("invalid claims JSON: %v", err)
}
return nil
}
// NewJWSFromClaims creates an unsigned JWS from the provided claims.
//
// kid identifies the signing key. The "alg" header field is set automatically
// when [JWS.Sign] is called. Call [JWS.Encode] to produce the compact JWT
// string after signing.
func NewJWSFromClaims(claims any, kid string) (*JWS, error) {
var jws JWS
jws.Header = StandardHeader{
// Alg is set by Sign based on the key type.
Kid: kid,
Typ: "JWT",
}
headerJSON, _ := json.Marshal(jws.Header)
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
claimsJSON, err := json.Marshal(claims)
if err != nil {
return nil, fmt.Errorf("marshal claims: %w", err)
}
jws.Payload = base64.RawURLEncoding.EncodeToString(claimsJSON)
return &jws, nil
}
// Sign signs the JWS in-place using the provided [crypto.Signer].
// It sets the "alg" header field based on the public key type and re-encodes
// the protected header before signing, so the signed input is always
// consistent with the token header.
//
// Supported algorithms (inferred from key type):
// - *ecdsa.PublicKey P-256 → ES256 (SHA-256, raw r||s)
// - *ecdsa.PublicKey P-384 → ES384 (SHA-384, raw r||s)
// - *ecdsa.PublicKey P-521 → ES512 (SHA-512, raw r||s)
// - *rsa.PublicKey → RS256 (PKCS#1 v1.5 + SHA-256)
// - ed25519.PublicKey → EdDSA (Ed25519, RFC 8037)
func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) {
switch pub := key.Public().(type) {
case *ecdsa.PublicKey:
alg, h, err := algForECKey(pub)
if err != nil {
return nil, err
}
jws.Header.Alg = alg
headerJSON, _ := json.Marshal(jws.Header)
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
digest := digestFor(h, jws.Protected+"."+jws.Payload)
// crypto.Signer returns ASN.1 DER for ECDSA; convert to raw r||s for JWS.
derSig, err := key.Sign(rand.Reader, digest, h)
if err != nil {
return nil, fmt.Errorf("Sign %s: %w", alg, err)
}
jws.Signature, err = ecdsaDERToRaw(derSig, pub.Curve)
return jws.Signature, err
case *rsa.PublicKey:
jws.Header.Alg = "RS256"
headerJSON, _ := json.Marshal(jws.Header)
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
digest := digestFor(crypto.SHA256, jws.Protected+"."+jws.Payload)
// crypto.Signer returns raw PKCS#1 v1.5 bytes for RSA; use directly.
var err error
jws.Signature, err = key.Sign(rand.Reader, digest, crypto.SHA256)
return jws.Signature, err
case ed25519.PublicKey:
jws.Header.Alg = "EdDSA"
headerJSON, _ := json.Marshal(jws.Header)
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
// Ed25519 signs the raw message with no pre-hashing; pass crypto.Hash(0).
signingInput := jws.Protected + "." + jws.Payload
var err error
jws.Signature, err = key.Sign(rand.Reader, []byte(signingInput), crypto.Hash(0))
return jws.Signature, err
default:
return nil, fmt.Errorf(
"Sign: unsupported public key type %T (supported: *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey)",
key.Public(),
)
}
}
// Encode produces the compact JWT string (header.payload.signature).
func (jws *JWS) Encode() string {
return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature)
}
// ValidateParams holds claim validation configuration.
//
// Configure once at startup; call [ValidateParams.Validate] per request,
// passing the current time. This keeps the config stable and makes the
// time dependency explicit at the call site.
//
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
type ValidateParams struct {
IgnoreIss bool
Iss string
IgnoreSub bool
Sub string
IgnoreAud bool
Aud string
IgnoreExp bool
IgnoreJti bool
Jti string
IgnoreIat bool
IgnoreAuthTime bool
MaxAge time.Duration
IgnoreNonce bool
Nonce string
IgnoreAmr bool
RequiredAmrs []string
IgnoreAzp bool
Azp string
}
// Validate checks the standard JWT/OIDC claim fields against this config.
//
// now is typically time.Now() — passing it explicitly keeps the config stable
// across requests and avoids hidden time dependencies in the params struct.
func (p ValidateParams) Validate(claims StandardClaims, now time.Time) ([]string, error) {
return ValidateStandardClaims(claims, p, now)
}
// ValidateStandardClaims checks the registered JWT/OIDC claim fields against params.
//
// Exported so callers can use it directly without a [ValidateParams] receiver:
//
// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, params, time.Now())
func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now time.Time) ([]string, error) {
var errs []string
// Required to exist and match
if len(params.Iss) > 0 || !params.IgnoreIss {
if len(claims.Iss) == 0 {
errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
} else if claims.Iss != params.Iss {
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss))
}
}
// Required to exist, optional match
if len(claims.Sub) == 0 {
if !params.IgnoreSub {
errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)")
}
} else if len(params.Sub) > 0 {
if params.Sub != claims.Sub {
errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, params.Sub))
}
}
// Required to exist and match
if len(params.Aud) > 0 || !params.IgnoreAud {
if len(claims.Aud) == 0 {
errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
} else if claims.Aud != params.Aud {
errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, params.Aud))
}
}
// Required to exist and not be in the past
if !params.IgnoreExp {
if claims.Exp <= 0 {
errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)")
} else if claims.Exp < now.Unix() {
duration := now.Sub(time.Unix(claims.Exp, 0))
expTime := time.Unix(claims.Exp, 0).Format("2006-01-02 15:04:05 MST")
errs = append(errs, fmt.Sprintf("token expired %s ago (%s)", formatDuration(duration), expTime))
}
}
// Required to exist and not be in the future
if !params.IgnoreIat {
if claims.Iat <= 0 {
errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)")
} else if claims.Iat > now.Unix() {
duration := time.Unix(claims.Iat, 0).Sub(now)
iatTime := time.Unix(claims.Iat, 0).Format("2006-01-02 15:04:05 MST")
errs = append(errs, fmt.Sprintf("'iat' (issued at) is %s in the future (%s)", formatDuration(duration), iatTime))
}
}
// Should exist, in the past, with optional max age
if params.MaxAge > 0 || !params.IgnoreAuthTime {
if claims.AuthTime == 0 {
errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
} else {
authTime := time.Unix(claims.AuthTime, 0)
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
age := now.Sub(authTime)
diff := age - params.MaxAge
if claims.AuthTime > now.Unix() {
fromNow := time.Unix(claims.AuthTime, 0).Sub(now)
errs = append(errs, fmt.Sprintf(
"'auth_time' of %s is %s in the future (server time %s)",
authTimeStr, formatDuration(fromNow), now.Format("2006-01-02 15:04:05 MST")),
)
} else if params.MaxAge > 0 && age > params.MaxAge {
errs = append(errs, fmt.Sprintf(
"'auth_time' of %s is %s old, exceeding max age %s by %s",
authTimeStr, formatDuration(age), formatDuration(params.MaxAge), formatDuration(diff)),
)
}
}
}
// Optional exact match
if params.Jti != claims.Jti {
if len(params.Jti) > 0 {
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti))
} else if !params.IgnoreJti {
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
}
}
// Optional exact match
if params.Nonce != claims.Nonce {
if len(params.Nonce) > 0 {
errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, params.Nonce))
} else if !params.IgnoreNonce {
errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce))
}
}
// Should exist, optional required-set check
if !params.IgnoreAmr {
if len(claims.Amr) == 0 {
errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)")
} else if len(params.RequiredAmrs) > 0 {
for _, required := range params.RequiredAmrs {
if !slices.Contains(claims.Amr, required) {
errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required))
}
}
}
}
// Optional, match if present
if params.Azp != claims.Azp {
if len(params.Azp) > 0 {
errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, params.Azp))
} else if !params.IgnoreAzp {
errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp))
}
}
if len(errs) > 0 {
timeInfo := fmt.Sprintf("info: server time is %s", now.Format("2006-01-02 15:04:05 MST"))
if loc, err := time.LoadLocation("Local"); err == nil {
timeInfo += fmt.Sprintf(" %s", loc)
}
errs = append(errs, timeInfo)
return errs, fmt.Errorf("has errors")
}
return nil, nil
}
// Issuer holds public keys and validation config for a trusted token issuer.
//
// [Issuer.FetchKeys] loads keys from the issuer's JWKS endpoint.
// [Issuer.SetKeys] injects keys directly (useful in tests).
// [Issuer.Verify] authenticates the token: key lookup → sig verify → iss check.
//
// Typical setup:
//
// iss := ajwt.NewIssuer("https://accounts.example.com")
// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true}
// if err := iss.FetchKeys(ctx); err != nil { ... }
type Issuer struct {
URL string
JWKsURL string // optional; defaults to URL + "/.well-known/jwks.json"
Params ValidateParams
keys map[string]crypto.PublicKey // kid → key
}
// NewIssuer creates an Issuer for the given base URL.
func NewIssuer(url string) *Issuer {
return &Issuer{
URL: url,
keys: make(map[string]crypto.PublicKey),
}
}
// SetKeys stores public keys by their KID, replacing any previously stored keys.
// Useful for injecting keys in tests without an HTTP round-trip.
func (iss *Issuer) SetKeys(keys []PublicJWK) {
m := make(map[string]crypto.PublicKey, len(keys))
for _, k := range keys {
m[k.KID] = k.Key
}
iss.keys = m
}
// FetchKeys retrieves and stores the JWKS from the issuer's endpoint.
// If JWKsURL is empty, it defaults to URL + "/.well-known/jwks.json".
func (iss *Issuer) FetchKeys(ctx context.Context) error {
url := iss.JWKsURL
if url == "" {
url = strings.TrimRight(iss.URL, "/") + "/.well-known/jwks.json"
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("fetch JWKS: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("fetch JWKS: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode)
}
keys, err := DecodePublicJWKs(resp.Body)
if err != nil {
return fmt.Errorf("parse JWKS: %w", err)
}
iss.SetKeys(keys)
return nil
}
// Verify authenticates jws against this issuer:
// 1. Looks up the signing key by jws.Header.Kid.
// 2. Verifies the signature before trusting any payload data.
// 3. Checks that the token's "iss" claim matches iss.URL.
//
// Call [JWS.UnmarshalClaims] after Verify to safely decode the payload into a
// typed struct, then [ValidateParams.Validate] to check claim values.
func (iss *Issuer) Verify(jws *JWS) error {
if jws.Header.Kid == "" {
return fmt.Errorf("missing 'kid' header")
}
key, ok := iss.keys[jws.Header.Kid]
if !ok {
return fmt.Errorf("unknown kid: %q", jws.Header.Kid)
}
signingInput := jws.Protected + "." + jws.Payload
if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
}
// Signature verified — now safe to inspect the payload.
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return fmt.Errorf("invalid claims encoding: %w", err)
}
var partial struct {
Iss string `json:"iss"`
}
if err := json.Unmarshal(payload, &partial); err != nil {
return fmt.Errorf("invalid claims JSON: %w", err)
}
if partial.Iss != iss.URL {
return fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL)
}
return nil
}
// verifyWith checks a JWS signature using the given algorithm and public key.
// Returns nil on success, a descriptive error on failure.
func verifyWith(signingInput string, sig []byte, alg string, key crypto.PublicKey) error {
switch alg {
case "ES256", "ES384", "ES512":
k, ok := key.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("alg %s requires *ecdsa.PublicKey, got %T", alg, key)
}
expectedAlg, h, err := algForECKey(k)
if err != nil {
return err
}
if expectedAlg != alg {
return fmt.Errorf("key curve mismatch: key is %s, token alg is %s", expectedAlg, alg)
}
byteLen := (k.Curve.Params().BitSize + 7) / 8
if len(sig) != 2*byteLen {
return fmt.Errorf("invalid %s signature length: got %d, want %d", alg, len(sig), 2*byteLen)
}
digest := digestFor(h, signingInput)
r := new(big.Int).SetBytes(sig[:byteLen])
s := new(big.Int).SetBytes(sig[byteLen:])
if !ecdsa.Verify(k, digest, r, s) {
return fmt.Errorf("%s signature invalid", alg)
}
return nil
case "RS256":
k, ok := key.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("alg RS256 requires *rsa.PublicKey, got %T", key)
}
digest := digestFor(crypto.SHA256, signingInput)
if err := rsa.VerifyPKCS1v15(k, crypto.SHA256, digest, sig); err != nil {
return fmt.Errorf("RS256 signature invalid: %w", err)
}
return nil
case "EdDSA":
k, ok := key.(ed25519.PublicKey)
if !ok {
return fmt.Errorf("alg EdDSA requires ed25519.PublicKey, got %T", key)
}
if !ed25519.Verify(k, []byte(signingInput), sig) {
return fmt.Errorf("EdDSA signature invalid")
}
return nil
default:
return fmt.Errorf("unsupported alg: %q", alg)
}
}
// --- Internal helpers ---
func algForECKey(pub *ecdsa.PublicKey) (alg string, h crypto.Hash, err error) {
switch pub.Curve {
case elliptic.P256():
return "ES256", crypto.SHA256, nil
case elliptic.P384():
return "ES384", crypto.SHA384, nil
case elliptic.P521():
return "ES512", crypto.SHA512, nil
default:
return "", 0, fmt.Errorf("unsupported EC curve: %s", pub.Curve.Params().Name)
}
}
func digestFor(h crypto.Hash, data string) []byte {
switch h {
case crypto.SHA256:
d := sha256.Sum256([]byte(data))
return d[:]
case crypto.SHA384:
d := sha512.Sum384([]byte(data))
return d[:]
case crypto.SHA512:
d := sha512.Sum512([]byte(data))
return d[:]
default:
panic(fmt.Sprintf("ajwt: unsupported hash %v", h))
}
}
func ecdsaDERToRaw(der []byte, curve elliptic.Curve) ([]byte, error) {
var sig struct{ R, S *big.Int }
if _, err := asn1.Unmarshal(der, &sig); err != nil {
return nil, fmt.Errorf("ecdsaDERToRaw: %w", err)
}
byteLen := (curve.Params().BitSize + 7) / 8
out := make([]byte, 2*byteLen)
sig.R.FillBytes(out[:byteLen])
sig.S.FillBytes(out[byteLen:])
return out, nil
}
func formatDuration(d time.Duration) string {
if d < 0 {
d = -d
}
days := int(d / (24 * time.Hour))
d -= time.Duration(days) * 24 * time.Hour
hours := int(d / time.Hour)
d -= time.Duration(hours) * time.Hour
minutes := int(d / time.Minute)
d -= time.Duration(minutes) * time.Minute
seconds := int(d / time.Second)
var parts []string
if days > 0 {
parts = append(parts, fmt.Sprintf("%dd", days))
}
if hours > 0 {
parts = append(parts, fmt.Sprintf("%dh", hours))
}
if minutes > 0 {
parts = append(parts, fmt.Sprintf("%dm", minutes))
}
if seconds > 0 || len(parts) == 0 {
parts = append(parts, fmt.Sprintf("%ds", seconds))
}
if seconds == 0 || len(parts) == 0 {
d -= time.Duration(seconds) * time.Second
millis := int(d / time.Millisecond)
parts = append(parts, fmt.Sprintf("%dms", millis))
}
return strings.Join(parts, " ")
}

410
auth/ajwt/jwt_test.go Normal file
View File

@ -0,0 +1,410 @@
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package ajwt_test
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"fmt"
"testing"
"time"
"github.com/therootcompany/golib/auth/ajwt"
)
// AppClaims embeds StandardClaims and adds application-specific fields.
//
// Unlike embeddedjwt and bestjwt, AppClaims does NOT implement a Validate
// interface — there is none. Validation is explicit: call
// ValidateStandardClaims or ValidateParams.Validate at the call site.
type AppClaims struct {
ajwt.StandardClaims
Email string `json:"email"`
Roles []string `json:"roles"`
}
// validateAppClaims is a plain function — not a method satisfying an interface.
// Custom validation logic lives here, calling ValidateStandardClaims directly.
func validateAppClaims(c AppClaims, params ajwt.ValidateParams, now time.Time) ([]string, error) {
errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, params, now)
if c.Email == "" {
errs = append(errs, "missing email claim")
}
if len(errs) > 0 {
return errs, fmt.Errorf("has errors")
}
return nil, nil
}
func goodClaims() AppClaims {
now := time.Now()
return AppClaims{
StandardClaims: ajwt.StandardClaims{
Iss: "https://example.com",
Sub: "user123",
Aud: "myapp",
Exp: now.Add(time.Hour).Unix(),
Iat: now.Unix(),
AuthTime: now.Unix(),
Amr: []string{"pwd"},
Jti: "abc123",
Azp: "myapp",
Nonce: "nonce1",
},
Email: "user@example.com",
Roles: []string{"admin"},
}
}
// goodParams configures the validator. Iss is omitted because Issuer.Verify
// already enforces the iss claim — no need to check it twice.
func goodParams() ajwt.ValidateParams {
return ajwt.ValidateParams{
IgnoreIss: true, // Issuer.Verify handles iss
Sub: "user123",
Aud: "myapp",
Jti: "abc123",
Nonce: "nonce1",
Azp: "myapp",
RequiredAmrs: []string{"pwd"},
}
}
func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer {
iss := ajwt.NewIssuer("https://example.com")
iss.Params = goodParams()
iss.SetKeys([]ajwt.PublicJWK{pub})
return iss
}
// TestRoundTrip is the primary happy path using ES256.
// It demonstrates the full Issuer-based flow:
//
// Decode → Issuer.Verify → UnmarshalClaims → Params.Validate
//
// No Claims interface, no Verified flag, no type assertions on jws.
func TestRoundTrip(t *testing.T) {
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
jws, err := ajwt.NewJWSFromClaims(&claims, "key-1")
if err != nil {
t.Fatal(err)
}
if _, err = jws.Sign(privKey); err != nil {
t.Fatal(err)
}
if jws.Header.Alg != "ES256" {
t.Fatalf("expected ES256, got %s", jws.Header.Alg)
}
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
jws2, err := ajwt.Decode(token)
if err != nil {
t.Fatal(err)
}
if err = iss.Verify(jws2); err != nil {
t.Fatalf("Verify failed: %v", err)
}
var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatal(err)
}
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
t.Fatalf("validation failed: %v", errs)
}
// Direct field access — no type assertion needed, no jws.Claims interface.
if decoded.Email != claims.Email {
t.Errorf("email: got %s, want %s", decoded.Email, claims.Email)
}
}
// TestRoundTripRS256 exercises RSA PKCS#1 v1.5 / RS256.
func TestRoundTripRS256(t *testing.T) {
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
jws, err := ajwt.NewJWSFromClaims(&claims, "key-1")
if err != nil {
t.Fatal(err)
}
if _, err = jws.Sign(privKey); err != nil {
t.Fatal(err)
}
if jws.Header.Alg != "RS256" {
t.Fatalf("expected RS256, got %s", jws.Header.Alg)
}
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"})
jws2, err := ajwt.Decode(token)
if err != nil {
t.Fatal(err)
}
if err = iss.Verify(jws2); err != nil {
t.Fatalf("Verify failed: %v", err)
}
var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatal(err)
}
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
t.Fatalf("validation failed: %v", errs)
}
}
// TestRoundTripEdDSA exercises Ed25519 / EdDSA (RFC 8037).
func TestRoundTripEdDSA(t *testing.T) {
pubKeyBytes, privKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
jws, err := ajwt.NewJWSFromClaims(&claims, "key-1")
if err != nil {
t.Fatal(err)
}
if _, err = jws.Sign(privKey); err != nil {
t.Fatal(err)
}
if jws.Header.Alg != "EdDSA" {
t.Fatalf("expected EdDSA, got %s", jws.Header.Alg)
}
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"})
jws2, err := ajwt.Decode(token)
if err != nil {
t.Fatal(err)
}
if err = iss.Verify(jws2); err != nil {
t.Fatalf("Verify failed: %v", err)
}
var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatal(err)
}
if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil {
t.Fatalf("validation failed: %v", errs)
}
}
// TestCustomValidation demonstrates custom claim validation without any interface.
// The caller owns the validation logic and calls ValidateStandardClaims directly.
func TestCustomValidation(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
// Token with empty Email — our custom validator should reject it.
claims := goodClaims()
claims.Email = ""
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(privKey)
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
_ = iss.Verify(jws2)
var decoded AppClaims
_ = jws2.UnmarshalClaims(&decoded)
errs, err := validateAppClaims(decoded, goodParams(), time.Now())
if err == nil {
t.Fatal("expected validation to fail: email is empty")
}
found := false
for _, e := range errs {
if e == "missing email claim" {
found = true
}
}
if !found {
t.Fatalf("expected 'missing email claim' in errors: %v", errs)
}
}
// TestIssuerWrongKey confirms that a different key's public key is rejected.
func TestIssuerWrongKey(t *testing.T) {
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
claims := goodClaims()
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(signingKey)
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
if err := iss.Verify(jws2); err == nil {
t.Fatal("expected Verify to fail with wrong key")
}
}
// TestIssuerUnknownKid confirms that an unknown kid is rejected.
func TestIssuerUnknownKid(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
claims := goodClaims()
jws, _ := ajwt.NewJWSFromClaims(&claims, "unknown-kid")
_, _ = jws.Sign(privKey)
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"})
jws2, _ := ajwt.Decode(token)
if err := iss.Verify(jws2); err == nil {
t.Fatal("expected Verify to fail for unknown kid")
}
}
// TestIssuerIssMismatch confirms that a token with a mismatched iss is rejected
// even if the signature is valid.
func TestIssuerIssMismatch(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
claims := goodClaims()
claims.Iss = "https://evil.example.com" // not the issuer URL
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(privKey)
token := jws.Encode()
// Issuer expects "https://example.com" but token says "https://evil.example.com"
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
if err := iss.Verify(jws2); err == nil {
t.Fatal("expected Verify to fail: iss mismatch")
}
}
// TestVerifyTamperedAlg confirms that a tampered alg header is rejected.
func TestVerifyTamperedAlg(t *testing.T) {
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
claims := goodClaims()
jws, _ := ajwt.NewJWSFromClaims(&claims, "k")
_, _ = jws.Sign(privKey)
token := jws.Encode()
iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"})
jws2, _ := ajwt.Decode(token)
jws2.Header.Alg = "none" // tamper
if err := iss.Verify(jws2); err == nil {
t.Fatal("expected Verify to fail for tampered alg")
}
}
// TestPublicJWKAccessors confirms the ECDSA, RSA, and EdDSA typed accessor methods.
func TestPublicJWKAccessors(t *testing.T) {
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
edPub, _, _ := ed25519.GenerateKey(rand.Reader)
ecJWK := ajwt.PublicJWK{Key: &ecKey.PublicKey, KID: "ec-1"}
rsaJWK := ajwt.PublicJWK{Key: &rsaKey.PublicKey, KID: "rsa-1"}
edJWK := ajwt.PublicJWK{Key: edPub, KID: "ed-1"}
if k, ok := ecJWK.ECDSA(); !ok || k == nil {
t.Error("expected ECDSA() to succeed for EC key")
}
if _, ok := ecJWK.RSA(); ok {
t.Error("expected RSA() to fail for EC key")
}
if _, ok := ecJWK.EdDSA(); ok {
t.Error("expected EdDSA() to fail for EC key")
}
if k, ok := rsaJWK.RSA(); !ok || k == nil {
t.Error("expected RSA() to succeed for RSA key")
}
if _, ok := rsaJWK.ECDSA(); ok {
t.Error("expected ECDSA() to fail for RSA key")
}
if _, ok := rsaJWK.EdDSA(); ok {
t.Error("expected EdDSA() to fail for RSA key")
}
if k, ok := edJWK.EdDSA(); !ok || k == nil {
t.Error("expected EdDSA() to succeed for Ed25519 key")
}
if _, ok := edJWK.ECDSA(); ok {
t.Error("expected ECDSA() to fail for Ed25519 key")
}
if _, ok := edJWK.RSA(); ok {
t.Error("expected RSA() to fail for Ed25519 key")
}
}
// TestDecodePublicJWKJSON verifies JWKS JSON parsing with real base64url-encoded
// key material from RFC 7517 / OIDC examples.
func TestDecodePublicJWKJSON(t *testing.T) {
jwksJSON := []byte(`{"keys":[
{"kty":"EC","crv":"P-256",
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
"kid":"ec-256","use":"sig"},
{"kty":"RSA",
"n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e":"AQAB","kid":"rsa-2048","use":"sig"}
]}`)
keys, err := ajwt.UnmarshalPublicJWKs(jwksJSON)
if err != nil {
t.Fatal(err)
}
if len(keys) != 2 {
t.Fatalf("expected 2 keys, got %d", len(keys))
}
var ecCount, rsaCount int
for _, k := range keys {
if _, ok := k.ECDSA(); ok {
ecCount++
if k.KID != "ec-256" {
t.Errorf("unexpected EC kid: %s", k.KID)
}
}
if _, ok := k.RSA(); ok {
rsaCount++
if k.KID != "rsa-2048" {
t.Errorf("unexpected RSA kid: %s", k.KID)
}
}
}
if ecCount != 1 {
t.Errorf("expected 1 EC key, got %d", ecCount)
}
if rsaCount != 1 {
t.Errorf("expected 1 RSA key, got %d", rsaCount)
}
}

235
auth/ajwt/pub.go Normal file
View File

@ -0,0 +1,235 @@
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package ajwt
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"os"
"time"
)
// PublicJWK wraps a parsed public key with its JWKS metadata.
//
// Key is [crypto.PublicKey] (= any) since a JWKS endpoint returns a
// heterogeneous mix of key types determined at runtime by the "kty" field.
// Use the typed accessor methods [PublicJWK.ECDSA], [PublicJWK.RSA], and
// [PublicJWK.EdDSA] to assert the underlying type without a raw type switch.
type PublicJWK struct {
Key crypto.PublicKey
KID string
Use string
}
// ECDSA returns the key as *ecdsa.PublicKey if it is one, else (nil, false).
func (k PublicJWK) ECDSA() (*ecdsa.PublicKey, bool) {
key, ok := k.Key.(*ecdsa.PublicKey)
return key, ok
}
// RSA returns the key as *rsa.PublicKey if it is one, else (nil, false).
func (k PublicJWK) RSA() (*rsa.PublicKey, bool) {
key, ok := k.Key.(*rsa.PublicKey)
return key, ok
}
// EdDSA returns the key as ed25519.PublicKey if it is one, else (nil, false).
func (k PublicJWK) EdDSA() (ed25519.PublicKey, bool) {
key, ok := k.Key.(ed25519.PublicKey)
return key, ok
}
// PublicJWKJSON is the JSON representation of a single key in a JWKS document.
type PublicJWKJSON struct {
Kty string `json:"kty"`
KID string `json:"kid"`
Crv string `json:"crv,omitempty"` // EC / OKP curve
X string `json:"x,omitempty"` // EC / OKP public key x (or Ed25519 key bytes)
Y string `json:"y,omitempty"` // EC public key y
N string `json:"n,omitempty"` // RSA modulus
E string `json:"e,omitempty"` // RSA exponent
Use string `json:"use,omitempty"`
}
// JWKsJSON is the JSON representation of a JWKS document.
type JWKsJSON struct {
Keys []PublicJWKJSON `json:"keys"`
}
// FetchPublicJWKs retrieves and parses a JWKS document from url.
//
// For issuer-scoped key management with context support, use
// [Issuer.FetchKeys] instead.
func FetchPublicJWKs(url string) ([]PublicJWK, error) {
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(url)
if err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
return DecodePublicJWKs(resp.Body)
}
// ReadPublicJWKs reads and parses a JWKS document from a file path.
func ReadPublicJWKs(filePath string) ([]PublicJWK, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("failed to open JWKS file %q: %w", filePath, err)
}
defer func() { _ = file.Close() }()
return DecodePublicJWKs(file)
}
// UnmarshalPublicJWKs parses a JWKS document from raw JSON bytes.
func UnmarshalPublicJWKs(data []byte) ([]PublicJWK, error) {
var jwks JWKsJSON
if err := json.Unmarshal(data, &jwks); err != nil {
return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err)
}
return DecodePublicJWKsJSON(jwks)
}
// DecodePublicJWKs parses a JWKS document from an [io.Reader].
func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) {
var jwks JWKsJSON
if err := json.NewDecoder(r).Decode(&jwks); err != nil {
return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err)
}
return DecodePublicJWKsJSON(jwks)
}
// DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys.
func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
var keys []PublicJWK
for _, jwk := range jwks.Keys {
key, err := DecodePublicJWK(jwk)
if err != nil {
return nil, fmt.Errorf("failed to parse public jwk %q: %w", jwk.KID, err)
}
keys = append(keys, *key)
}
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in JWKS")
}
return keys, nil
}
// DecodePublicJWK parses a single [PublicJWKJSON] into a [PublicJWK].
//
// Supported key types:
// - "RSA" — minimum 1024-bit (RS256)
// - "EC" — P-256, P-384, P-521 (ES256, ES384, ES512)
// - "OKP" — Ed25519 crv (EdDSA / RFC 8037)
func DecodePublicJWK(jwk PublicJWKJSON) (*PublicJWK, error) {
switch jwk.Kty {
case "RSA":
key, err := decodeRSAPublicJWK(jwk)
if err != nil {
return nil, fmt.Errorf("failed to parse RSA key %q: %w", jwk.KID, err)
}
if key.Size() < 128 { // 1024 bits minimum
return nil, fmt.Errorf("RSA key %q too small: %d bytes", jwk.KID, key.Size())
}
return &PublicJWK{Key: key, KID: jwk.KID, Use: jwk.Use}, nil
case "EC":
key, err := decodeECPublicJWK(jwk)
if err != nil {
return nil, fmt.Errorf("failed to parse EC key %q: %w", jwk.KID, err)
}
return &PublicJWK{Key: key, KID: jwk.KID, Use: jwk.Use}, nil
case "OKP":
key, err := decodeOKPPublicJWK(jwk)
if err != nil {
return nil, fmt.Errorf("failed to parse OKP key %q: %w", jwk.KID, err)
}
return &PublicJWK{Key: key, KID: jwk.KID, Use: jwk.Use}, nil
default:
return nil, fmt.Errorf("unsupported key type %q for kid %q", jwk.Kty, jwk.KID)
}
}
func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) {
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, fmt.Errorf("invalid RSA modulus: %w", err)
}
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, fmt.Errorf("invalid RSA exponent: %w", err)
}
eInt := new(big.Int).SetBytes(e).Int64()
if eInt > int64(^uint(0)>>1) || eInt < 0 {
return nil, fmt.Errorf("RSA exponent too large or negative")
}
return &rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: int(eInt),
}, nil
}
func decodeECPublicJWK(jwk PublicJWKJSON) (*ecdsa.PublicKey, error) {
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("invalid ECDSA X: %w", err)
}
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("invalid ECDSA Y: %w", err)
}
var curve elliptic.Curve
switch jwk.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported EC curve: %s", jwk.Crv)
}
return &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
}, nil
}
func decodeOKPPublicJWK(jwk PublicJWKJSON) (ed25519.PublicKey, error) {
if jwk.Crv != "Ed25519" {
return nil, fmt.Errorf("unsupported OKP curve: %q (only Ed25519 supported)", jwk.Crv)
}
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("invalid OKP X: %w", err)
}
if len(x) != ed25519.PublicKeySize {
return nil, fmt.Errorf("invalid Ed25519 key size: got %d bytes, want %d", len(x), ed25519.PublicKeySize)
}
return ed25519.PublicKey(x), nil
}