golib/auth/bestjwt/jwt_test.go
AJ ONeal fac58cf1ad
feat(auth/bestjwt): add hybrid JWT/JWS/JWK package
Combines the best ergonomics from genericjwt and embeddedjwt:

- Decode(&claims) pattern (embedded structs, no generics at call sites,
  no type assertion to access custom fields)
- StandardClaims.Validate promoted to any embedding struct via value
  receiver; override Validate on the outer struct for custom checks
- Sign(crypto.Signer): algorithm inferred from key.Public() type switch,
  supports HSM/cloud KMS transparently
- Full ECDSA curve support: ES256 (P-256), ES384 (P-384), ES512 (P-521)
  all inferred automatically from key curve via algForECKey
- Curve/alg consistency check in UnsafeVerify: P-256 key rejected for
  ES384 token and vice versa (prevents cross-algorithm downgrade)
- digestFor: fixed-size stack arrays for SHA-256/384/512 digests
- ecdsaDERToRaw + FillBytes: correct zero-padded r||s conversion from
  ASN.1 DER output of crypto.Signer
- Generic PublicJWK[K Key] + TypedKeys[K]: type-safe JWKS key management,
  filter mixed []PublicJWK[Key] to concrete type without assertions
- JWKS fetch/parse: FetchPublicJWKs, ReadPublicJWKs, UnmarshalPublicJWKs,
  DecodePublicJWKs for RSA and EC (P-256/384/521)
- RS256 (PKCS#1 v1.5 + SHA-256) support via crypto.Signer
- 13 tests covering all algorithms, negative cases, and JWKS integration
2026-03-12 17:40:24 -06:00

479 lines
14 KiB
Go

// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package bestjwt_test
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"fmt"
"testing"
"time"
"github.com/therootcompany/golib/auth/bestjwt"
)
// AppClaims is the sample application claims type shared by most tests here.
// It embeds bestjwt.StandardClaims and adds two custom JSON fields. It declares
// no methods of its own, so the Validate method is the one promoted from the
// embedded StandardClaims.
type AppClaims struct {
	bestjwt.StandardClaims
	Email string   `json:"email"`
	Roles []string `json:"roles"`
}
// StrictAppClaims is a claims type that declares its own Validate method,
// shadowing the one promoted from the embedded StandardClaims. Its Validate
// runs the standard checks and additionally rejects an empty Email, showing
// how application-specific rules are layered on top of the standard ones.
type StrictAppClaims struct {
	bestjwt.StandardClaims
	Email string `json:"email"`
}
// Validate runs the standard claim checks against params and then requires a
// non-empty Email. It returns every collected problem message together with a
// generic error when at least one problem was found, and (nil, nil) otherwise.
func (c StrictAppClaims) Validate(params bestjwt.ValidateParams) ([]string, error) {
	// The error from the standard check is dropped on purpose: the final
	// error is re-derived from the combined message list below.
	// NOTE(review): this assumes a non-nil error from ValidateStandardClaims
	// always comes with at least one message — confirm.
	problems, _ := bestjwt.ValidateStandardClaims(c.StandardClaims, params)
	if c.Email == "" {
		problems = append(problems, "missing email claim")
	}
	if len(problems) == 0 {
		return nil, nil
	}
	return problems, fmt.Errorf("has errors")
}
// goodClaims builds an AppClaims value with every standard field filled in:
// issued and authenticated "now", expiring one hour later. The values line up
// with what goodParams expects.
func goodClaims() AppClaims {
	now := time.Now()
	issuedAt := now.Unix()
	expiresAt := now.Add(time.Hour).Unix()

	std := bestjwt.StandardClaims{
		Iss:      "https://example.com",
		Sub:      "user123",
		Aud:      "myapp",
		Exp:      expiresAt,
		Iat:      issuedAt,
		AuthTime: issuedAt,
		Amr:      []string{"pwd"},
		Jti:      "abc123",
		Azp:      "myapp",
		Nonce:    "nonce1",
	}

	return AppClaims{
		StandardClaims: std,
		Email:          "user@example.com",
		Roles:          []string{"admin"},
	}
}
// goodParams builds the ValidateParams whose expected values are the same
// literals goodClaims puts into the token.
func goodParams() bestjwt.ValidateParams {
	var p bestjwt.ValidateParams
	p.Iss = "https://example.com"
	p.Sub = "user123"
	p.Aud = "myapp"
	p.Jti = "abc123"
	p.Nonce = "nonce1"
	p.Azp = "myapp"
	p.RequiredAmrs = []string{"pwd"}
	return p
}
// --- Round-trip tests (sign → encode → decode → verify → validate) ---
// TestRoundTripES256 walks a token through the full lifecycle with an ECDSA
// P-256 key: sign, encode, decode, verify, validate. It expects Sign to set
// the header alg to ES256 and to produce a 64-byte signature, and it reads the
// custom Email field straight from the struct that was handed to Decode — no
// generics at the call site and no type assertion.
func TestRoundTripES256(t *testing.T) {
	signingKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	want := goodClaims()
	signed, err := bestjwt.NewJWSFromClaims(&want, "key-1")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = signed.Sign(signingKey); err != nil {
		t.Fatal(err)
	}
	if alg := signed.Header.Alg; alg != "ES256" {
		t.Fatalf("expected ES256, got %s", alg)
	}
	// P-256: r and s take 32 bytes each.
	if sigLen := len(signed.Signature); sigLen != 64 {
		t.Fatalf("expected 64-byte signature, got %d", sigLen)
	}

	var got AppClaims
	parsed, err := bestjwt.Decode(signed.Encode(), &got)
	if err != nil {
		t.Fatal(err)
	}
	if ok := parsed.UnsafeVerify(&signingKey.PublicKey); !ok {
		t.Fatal("signature verification failed")
	}
	if errs, err := parsed.Validate(goodParams()); err != nil {
		t.Fatalf("validation failed: %v", errs)
	}

	// got is the caller's own struct, so Email is a plain field read.
	if got.Email != want.Email {
		t.Errorf("email: got %s, want %s", got.Email, want.Email)
	}
}
// TestRoundTripES384 runs the sign → encode → decode → verify → validate cycle
// with an ECDSA P-384 key. It expects Sign to pick ES384 from the key alone
// and to emit a 96-byte r||s signature that then verifies.
func TestRoundTripES384(t *testing.T) {
	signingKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	want := goodClaims()
	signed, err := bestjwt.NewJWSFromClaims(&want, "key-1")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = signed.Sign(signingKey); err != nil {
		t.Fatal(err)
	}
	if alg := signed.Header.Alg; alg != "ES384" {
		t.Fatalf("expected ES384, got %s", alg)
	}
	// P-384: r and s take 48 bytes each.
	if sigLen := len(signed.Signature); sigLen != 96 {
		t.Fatalf("expected 96-byte signature, got %d", sigLen)
	}

	var got AppClaims
	parsed, err := bestjwt.Decode(signed.Encode(), &got)
	if err != nil {
		t.Fatal(err)
	}
	if ok := parsed.UnsafeVerify(&signingKey.PublicKey); !ok {
		t.Fatal("signature verification failed")
	}
	if errs, err := parsed.Validate(goodParams()); err != nil {
		t.Fatalf("validation failed: %v", errs)
	}
}
// TestRoundTripES512 runs the sign → encode → decode → verify → validate cycle
// with an ECDSA P-521 key, expecting the ES512 alg and a 132-byte signature.
func TestRoundTripES512(t *testing.T) {
	signingKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	want := goodClaims()
	signed, err := bestjwt.NewJWSFromClaims(&want, "key-1")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = signed.Sign(signingKey); err != nil {
		t.Fatal(err)
	}
	if alg := signed.Header.Alg; alg != "ES512" {
		t.Fatalf("expected ES512, got %s", alg)
	}
	// P-521: r and s take 66 bytes each.
	if sigLen := len(signed.Signature); sigLen != 132 {
		t.Fatalf("expected 132-byte signature, got %d", sigLen)
	}

	var got AppClaims
	parsed, err := bestjwt.Decode(signed.Encode(), &got)
	if err != nil {
		t.Fatal(err)
	}
	if ok := parsed.UnsafeVerify(&signingKey.PublicKey); !ok {
		t.Fatal("signature verification failed")
	}
	if errs, err := parsed.Validate(goodParams()); err != nil {
		t.Fatalf("validation failed: %v", errs)
	}
}
// TestRoundTripRS256 runs the sign → encode → decode → verify → validate cycle
// with a 2048-bit RSA key, expecting Sign to select RS256 (RSA PKCS#1 v1.5).
func TestRoundTripRS256(t *testing.T) {
	signingKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}

	want := goodClaims()
	signed, err := bestjwt.NewJWSFromClaims(&want, "key-1")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = signed.Sign(signingKey); err != nil {
		t.Fatal(err)
	}
	if alg := signed.Header.Alg; alg != "RS256" {
		t.Fatalf("expected RS256, got %s", alg)
	}

	var got AppClaims
	parsed, err := bestjwt.Decode(signed.Encode(), &got)
	if err != nil {
		t.Fatal(err)
	}
	if ok := parsed.UnsafeVerify(&signingKey.PublicKey); !ok {
		t.Fatal("signature verification failed")
	}
	if errs, err := parsed.Validate(goodParams()); err != nil {
		t.Fatalf("validation failed: %v", errs)
	}
}
// --- Security / negative tests ---
// TestVerifyWrongKeyType verifies that an RSA public key is rejected when
// verifying a token signed with ECDSA (alg = ES256).
//
// Every setup step is checked: if signing or decoding failed silently, the
// final "must not verify" assertion would pass for the wrong reason.
func TestVerifyWrongKeyType(t *testing.T) {
	ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}

	claims := goodClaims()
	jws, err := bestjwt.NewJWSFromClaims(&claims, "k")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = jws.Sign(ecKey); err != nil {
		t.Fatal(err)
	}
	// Precondition for the scenario: the token really is an ES256 token.
	if jws.Header.Alg != "ES256" {
		t.Fatalf("expected ES256, got %s", jws.Header.Alg)
	}
	token := jws.Encode()

	var decoded AppClaims
	jws2, err := bestjwt.Decode(token, &decoded)
	if err != nil {
		t.Fatal(err)
	}

	if jws2.UnsafeVerify(&rsaKey.PublicKey) {
		t.Fatal("expected verification to fail: RSA key for ES256 token")
	}
}
// TestVerifyAlgCurveMismatch verifies that a P-256 key is rejected when
// verifying a token whose header claims ES384 (signed with P-384).
// Without the curve/alg consistency check this would silently return false
// from ecdsa.Verify, but the explicit check makes the rejection reason clear.
//
// Every setup step is checked so the negative assertion cannot pass merely
// because signing or decoding failed.
func TestVerifyAlgCurveMismatch(t *testing.T) {
	p384Key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	p256Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	claims := goodClaims()
	jws, err := bestjwt.NewJWSFromClaims(&claims, "k")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = jws.Sign(p384Key); err != nil {
		t.Fatal(err)
	}
	// Precondition for the scenario: the token really is an ES384 token.
	if jws.Header.Alg != "ES384" {
		t.Fatalf("expected ES384, got %s", jws.Header.Alg)
	}
	token := jws.Encode()

	var decoded AppClaims
	jws2, err := bestjwt.Decode(token, &decoded)
	if err != nil {
		t.Fatal(err)
	}

	// P-256 key must be rejected for an ES384 token.
	if jws2.UnsafeVerify(&p256Key.PublicKey) {
		t.Fatal("expected verification to fail: P-256 key for ES384 token")
	}
}
// TestVerifyUnknownAlg verifies that a tampered alg header is rejected: after
// a valid ES256 token is decoded, its header alg is overwritten with "none"
// and verification with the otherwise-correct key must fail.
//
// Every setup step is checked so the negative assertion cannot pass merely
// because signing or decoding failed.
func TestVerifyUnknownAlg(t *testing.T) {
	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	claims := goodClaims()
	jws, err := bestjwt.NewJWSFromClaims(&claims, "k")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = jws.Sign(privKey); err != nil {
		t.Fatal(err)
	}
	token := jws.Encode()

	var decoded AppClaims
	jws2, err := bestjwt.Decode(token, &decoded)
	if err != nil {
		t.Fatal(err)
	}

	// Tamper: overwrite alg in the decoded header.
	jws2.Header.Alg = "none"

	if jws2.UnsafeVerify(&privKey.PublicKey) {
		t.Fatal("expected verification to fail for unknown alg")
	}
}
// TestValidateMissingSignatureCheck verifies that Validate fails when
// UnsafeVerify was never called (Verified is false), and that the returned
// messages include "signature was not checked".
//
// Setup errors are checked so that a failure while building or decoding the
// token is reported as such instead of surfacing as a nil dereference.
func TestValidateMissingSignatureCheck(t *testing.T) {
	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	claims := goodClaims()
	jws, err := bestjwt.NewJWSFromClaims(&claims, "k")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = jws.Sign(privKey); err != nil {
		t.Fatal(err)
	}
	token := jws.Encode()

	var decoded AppClaims
	jws2, err := bestjwt.Decode(token, &decoded)
	if err != nil {
		t.Fatal(err)
	}

	// Deliberately skip UnsafeVerify.
	errs, err := jws2.Validate(goodParams())
	if err == nil {
		t.Fatal("expected validation to fail: signature was not checked")
	}

	found := false
	for _, e := range errs {
		if e == "signature was not checked" {
			found = true
			break
		}
	}
	if !found {
		t.Fatalf("expected 'signature was not checked' in errors: %v", errs)
	}
}
// --- Embedded vs overridden Validate ---
// TestPromotedValidate confirms that AppClaims (which only embeds
// StandardClaims) gets the standard OIDC validation for free via promotion,
// without writing any Validate method.
//
// Setup errors and the UnsafeVerify result are checked so that a signing,
// decoding, or verification problem is reported at its source rather than as
// a misleading validation failure.
func TestPromotedValidate(t *testing.T) {
	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	claims := goodClaims()
	jws, err := bestjwt.NewJWSFromClaims(&claims, "k")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = jws.Sign(privKey); err != nil {
		t.Fatal(err)
	}
	token := jws.Encode()

	var decoded AppClaims
	jws2, err := bestjwt.Decode(token, &decoded)
	if err != nil {
		t.Fatal(err)
	}
	if !jws2.UnsafeVerify(&privKey.PublicKey) {
		t.Fatal("signature verification failed")
	}

	if errs, err := jws2.Validate(goodParams()); err != nil {
		t.Fatalf("promoted Validate failed unexpectedly: %v", errs)
	}
}
// TestOverriddenValidate confirms that a StrictAppClaims with an empty Email
// fails validation via its overridden Validate method, and that the returned
// messages include "missing email claim".
//
// Setup errors and the UnsafeVerify result are checked so the expected
// validation failure can only come from the claims, not from a token that was
// never signed, decoded, or verified correctly.
func TestOverriddenValidate(t *testing.T) {
	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	now := time.Now()
	claims := StrictAppClaims{
		StandardClaims: bestjwt.StandardClaims{
			Iss:      "https://example.com",
			Sub:      "user123",
			Aud:      "myapp",
			Exp:      now.Add(time.Hour).Unix(),
			Iat:      now.Unix(),
			AuthTime: now.Unix(),
			Amr:      []string{"pwd"},
			Jti:      "abc123",
			Azp:      "myapp",
			Nonce:    "nonce1",
		},
		Email: "", // intentionally empty to trigger the override
	}

	jws, err := bestjwt.NewJWSFromClaims(&claims, "k")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = jws.Sign(privKey); err != nil {
		t.Fatal(err)
	}
	token := jws.Encode()

	var decoded StrictAppClaims
	jws2, err := bestjwt.Decode(token, &decoded)
	if err != nil {
		t.Fatal(err)
	}
	if !jws2.UnsafeVerify(&privKey.PublicKey) {
		t.Fatal("signature verification failed")
	}

	errs, err := jws2.Validate(goodParams())
	if err == nil {
		t.Fatal("expected validation to fail: email is empty")
	}

	found := false
	for _, e := range errs {
		if e == "missing email claim" {
			found = true
			break
		}
	}
	if !found {
		t.Fatalf("expected 'missing email claim' in errors: %v", errs)
	}
}
// --- JWKS / key management ---
// TestTypedKeys verifies that TypedKeys correctly filters a mixed
// []PublicJWK[Key] into typed slices without type assertions at use sites.
//
// The typed field access is guarded by a length check so that an empty filter
// result is reported through t.Errorf instead of an index-out-of-range panic.
func TestTypedKeys(t *testing.T) {
	ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}

	allKeys := []bestjwt.PublicJWK[bestjwt.Key]{
		{Key: &ecKey.PublicKey, KID: "ec-1", Use: "sig"},
		{Key: &rsaKey.PublicKey, KID: "rsa-1", Use: "sig"},
	}

	ecKeys := bestjwt.TypedKeys[*ecdsa.PublicKey](allKeys)
	if len(ecKeys) != 1 || ecKeys[0].KID != "ec-1" {
		t.Errorf("unexpected EC keys: %+v", ecKeys)
	}
	if len(ecKeys) > 0 {
		// Typed access — no assertion needed.
		_ = ecKeys[0].Key.Curve
	}

	rsaKeys := bestjwt.TypedKeys[*rsa.PublicKey](allKeys)
	if len(rsaKeys) != 1 || rsaKeys[0].KID != "rsa-1" {
		t.Errorf("unexpected RSA keys: %+v", rsaKeys)
	}
}
// TestVerifyWithJWKSKey verifies that PublicJWK.Key can be passed directly to
// UnsafeVerify without a type assertion when using a typed PublicJWK[Key].
//
// Setup errors are checked so that a verification failure can be attributed
// to the key-passing path under test rather than to a token that was never
// signed or decoded.
func TestVerifyWithJWKSKey(t *testing.T) {
	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	jwksKey := bestjwt.PublicJWK[bestjwt.Key]{Key: &privKey.PublicKey, KID: "k1"}

	claims := goodClaims()
	jws, err := bestjwt.NewJWSFromClaims(&claims, "k1")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = jws.Sign(privKey); err != nil {
		t.Fatal(err)
	}
	token := jws.Encode()

	var decoded AppClaims
	jws2, err := bestjwt.Decode(token, &decoded)
	if err != nil {
		t.Fatal(err)
	}

	// Pass PublicJWK.Key directly — Key interface satisfies the Key constraint.
	if !jws2.UnsafeVerify(jwksKey.Key) {
		t.Fatal("verification via PublicJWK.Key failed")
	}
}
// TestDecodePublicJWKJSON feeds a two-key JWKS document (one EC P-256 key, one
// RSA key, both base64url-encoded) to UnmarshalPublicJWKs and checks that both
// keys are returned and that TypedKeys splits them by concrete key type with
// the expected key IDs.
func TestDecodePublicJWKJSON(t *testing.T) {
	const doc = `{"keys":[
{"kty":"EC","crv":"P-256",
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
"kid":"ec-256","use":"sig"},
{"kty":"RSA",
"n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e":"AQAB","kid":"rsa-2048","use":"sig"}
]}`

	parsed, err := bestjwt.UnmarshalPublicJWKs([]byte(doc))
	if err != nil {
		t.Fatal(err)
	}
	if n := len(parsed); n != 2 {
		t.Fatalf("expected 2 keys, got %d", n)
	}

	// Exactly one key of each concrete type is expected after filtering.
	ecOnly := bestjwt.TypedKeys[*ecdsa.PublicKey](parsed)
	if len(ecOnly) != 1 || ecOnly[0].KID != "ec-256" {
		t.Errorf("EC key mismatch: %+v", ecOnly)
	}

	rsaOnly := bestjwt.TypedKeys[*rsa.PublicKey](parsed)
	if len(rsaOnly) != 1 || rsaOnly[0].KID != "rsa-2048" {
		t.Errorf("RSA key mismatch: %+v", rsaOnly)
	}
}