mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-13 12:27:59 +00:00
Implements JWS[C Validatable] generic over the claims type, with a package-level Decode[C] function (Go disallows generic methods). Claims are directly typed as C — no interface or type assertion needed at use sites. PublicJWK[K Key] and TypedKeys[K] provide generic key management. Supports ES256 and RS256 via crypto.Signer.
295 lines
8.5 KiB
Go
295 lines
8.5 KiB
Go
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
//
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package genericjwt_test
|
|
|
|
import (
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/therootcompany/golib/auth/genericjwt"
|
|
)
|
|
|
|
// AppClaims is an example custom claims type that satisfies [genericjwt.Validatable]
|
|
// by explicitly delegating to [genericjwt.ValidateStandardClaims].
|
|
//
|
|
// Unlike embeddedjwt, there is no promoted Validate here — genericjwt's
|
|
// StandardClaims has no Validate method, so the application type always owns
|
|
// the implementation. This keeps the generics constraint explicit.
|
|
type AppClaims struct {
|
|
genericjwt.StandardClaims
|
|
Email string `json:"email"`
|
|
Roles []string `json:"roles"`
|
|
}
|
|
|
|
func (c AppClaims) Validate(params genericjwt.ValidateParams) ([]string, error) {
|
|
return genericjwt.ValidateStandardClaims(c.StandardClaims, params)
|
|
}
|
|
|
|
// StrictAppClaims adds an application-specific check (non-empty Email) on top
|
|
// of the standard OIDC validation.
|
|
type StrictAppClaims struct {
|
|
genericjwt.StandardClaims
|
|
Email string `json:"email"`
|
|
}
|
|
|
|
func (c StrictAppClaims) Validate(params genericjwt.ValidateParams) ([]string, error) {
|
|
errs, _ := genericjwt.ValidateStandardClaims(c.StandardClaims, params)
|
|
if c.Email == "" {
|
|
errs = append(errs, "missing email claim")
|
|
}
|
|
if len(errs) > 0 {
|
|
return errs, fmt.Errorf("has errors")
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func goodClaims() AppClaims {
|
|
now := time.Now()
|
|
return AppClaims{
|
|
StandardClaims: genericjwt.StandardClaims{
|
|
Iss: "https://example.com",
|
|
Sub: "user123",
|
|
Aud: "myapp",
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
Iat: now.Unix(),
|
|
AuthTime: now.Unix(),
|
|
Amr: []string{"pwd"},
|
|
Jti: "abc123",
|
|
Azp: "myapp",
|
|
Nonce: "nonce1",
|
|
},
|
|
Email: "user@example.com",
|
|
Roles: []string{"admin"},
|
|
}
|
|
}
|
|
|
|
func goodParams() genericjwt.ValidateParams {
|
|
return genericjwt.ValidateParams{
|
|
Iss: "https://example.com",
|
|
Sub: "user123",
|
|
Aud: "myapp",
|
|
Jti: "abc123",
|
|
Nonce: "nonce1",
|
|
Azp: "myapp",
|
|
RequiredAmrs: []string{"pwd"},
|
|
}
|
|
}
|
|
|
|
// TestRoundTrip is the primary happy path demonstrating the core genericjwt
|
|
// ergonomic: Decode[AppClaims] places the type parameter at the call site and
|
|
// returns a JWS[AppClaims] whose Claims field is directly typed — no interface,
|
|
// no type assertion ever needed.
|
|
func TestRoundTrip(t *testing.T) {
|
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
claims := goodClaims()
|
|
jws, err := genericjwt.NewJWSFromClaims(claims, "key-1")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if _, err = jws.Sign(privKey); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if jws.Header.Alg != "ES256" {
|
|
t.Fatalf("expected ES256, got %s", jws.Header.Alg)
|
|
}
|
|
|
|
token := jws.Encode()
|
|
|
|
// Type parameter at the call site — no pre-allocated claims pointer needed.
|
|
jws2, err := genericjwt.Decode[AppClaims](token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
|
t.Fatal("signature verification failed")
|
|
}
|
|
if errs, err := jws2.Validate(goodParams()); err != nil {
|
|
t.Fatalf("validation failed: %v", errs)
|
|
}
|
|
// Direct field access on jws2.Claims — zero type assertions.
|
|
if jws2.Claims.Email != claims.Email {
|
|
t.Errorf("email: got %s, want %s", jws2.Claims.Email, claims.Email)
|
|
}
|
|
}
|
|
|
|
// TestRoundTripRS256 exercises RSA PKCS#1 v1.5 / RS256.
|
|
func TestRoundTripRS256(t *testing.T) {
|
|
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
claims := goodClaims()
|
|
jws, err := genericjwt.NewJWSFromClaims(claims, "key-1")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if _, err = jws.Sign(privKey); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if jws.Header.Alg != "RS256" {
|
|
t.Fatalf("expected RS256, got %s", jws.Header.Alg)
|
|
}
|
|
|
|
token := jws.Encode()
|
|
|
|
jws2, err := genericjwt.Decode[AppClaims](token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
|
t.Fatal("signature verification failed")
|
|
}
|
|
if errs, err := jws2.Validate(goodParams()); err != nil {
|
|
t.Fatalf("validation failed: %v", errs)
|
|
}
|
|
}
|
|
|
|
// TestUnsafeVerifyWrongKey confirms that a different key's public key does
|
|
// not verify the signature.
|
|
func TestUnsafeVerifyWrongKey(t *testing.T) {
|
|
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
claims := goodClaims()
|
|
jws, _ := genericjwt.NewJWSFromClaims(claims, "k")
|
|
_, _ = jws.Sign(signingKey)
|
|
token := jws.Encode()
|
|
|
|
jws2, _ := genericjwt.Decode[AppClaims](token)
|
|
|
|
if jws2.UnsafeVerify(&wrongKey.PublicKey) {
|
|
t.Fatal("expected verification to fail with wrong key")
|
|
}
|
|
}
|
|
|
|
// TestVerifyWrongKeyType confirms that an RSA key is rejected for an ES256 token.
|
|
func TestVerifyWrongKeyType(t *testing.T) {
|
|
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
|
|
|
claims := goodClaims()
|
|
jws, _ := genericjwt.NewJWSFromClaims(claims, "k")
|
|
_, _ = jws.Sign(ecKey)
|
|
token := jws.Encode()
|
|
|
|
jws2, _ := genericjwt.Decode[AppClaims](token)
|
|
|
|
if jws2.UnsafeVerify(&rsaKey.PublicKey) {
|
|
t.Fatal("expected verification to fail: RSA key for ES256 token")
|
|
}
|
|
}
|
|
|
|
// TestVerifyUnknownAlg confirms that a tampered alg header is rejected.
|
|
func TestVerifyUnknownAlg(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
claims := goodClaims()
|
|
jws, _ := genericjwt.NewJWSFromClaims(claims, "k")
|
|
_, _ = jws.Sign(privKey)
|
|
token := jws.Encode()
|
|
|
|
jws2, _ := genericjwt.Decode[AppClaims](token)
|
|
jws2.Header.Alg = "none"
|
|
|
|
if jws2.UnsafeVerify(&privKey.PublicKey) {
|
|
t.Fatal("expected verification to fail for unknown alg")
|
|
}
|
|
}
|
|
|
|
// TestValidateMissingSignatureCheck verifies that Validate fails when
|
|
// UnsafeVerify was never called (Verified is false).
|
|
func TestValidateMissingSignatureCheck(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
claims := goodClaims()
|
|
jws, _ := genericjwt.NewJWSFromClaims(claims, "k")
|
|
_, _ = jws.Sign(privKey)
|
|
token := jws.Encode()
|
|
|
|
jws2, _ := genericjwt.Decode[AppClaims](token)
|
|
|
|
// Deliberately skip UnsafeVerify.
|
|
errs, err := jws2.Validate(goodParams())
|
|
if err == nil {
|
|
t.Fatal("expected validation to fail: signature was not checked")
|
|
}
|
|
found := false
|
|
for _, e := range errs {
|
|
if e == "signature was not checked" {
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatalf("expected 'signature was not checked' in errors: %v", errs)
|
|
}
|
|
}
|
|
|
|
// TestVerifyWithJWKSKey verifies that PublicJWK[Key].Key can be passed
|
|
// directly to UnsafeVerify — the Key interface satisfies UnsafeVerify's
|
|
// parameter constraint without a type assertion.
|
|
func TestVerifyWithJWKSKey(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
jwksKey := genericjwt.PublicJWK[genericjwt.Key]{Key: &privKey.PublicKey, KID: "k1"}
|
|
|
|
claims := goodClaims()
|
|
jws, _ := genericjwt.NewJWSFromClaims(claims, "k1")
|
|
_, _ = jws.Sign(privKey)
|
|
token := jws.Encode()
|
|
|
|
jws2, _ := genericjwt.Decode[AppClaims](token)
|
|
|
|
if !jws2.UnsafeVerify(jwksKey.Key) {
|
|
t.Fatal("verification via PublicJWK.Key failed")
|
|
}
|
|
}
|
|
|
|
// TestDecodePublicJWKJSON verifies JWKS JSON parsing and TypedKeys filtering
|
|
// with real base64url-encoded key material from RFC 7517 / OIDC examples.
|
|
func TestDecodePublicJWKJSON(t *testing.T) {
|
|
jwksJSON := []byte(`{"keys":[
|
|
{"kty":"EC","crv":"P-256",
|
|
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
|
|
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
|
|
"kid":"ec-256","use":"sig"},
|
|
{"kty":"RSA",
|
|
"n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
|
|
"e":"AQAB","kid":"rsa-2048","use":"sig"}
|
|
]}`)
|
|
|
|
keys, err := genericjwt.UnmarshalPublicJWKs(jwksJSON)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(keys) != 2 {
|
|
t.Fatalf("expected 2 keys, got %d", len(keys))
|
|
}
|
|
|
|
ecKeys := genericjwt.TypedKeys[*ecdsa.PublicKey](keys)
|
|
if len(ecKeys) != 1 || ecKeys[0].KID != "ec-256" {
|
|
t.Errorf("EC key mismatch: %+v", ecKeys)
|
|
}
|
|
|
|
rsaKeys := genericjwt.TypedKeys[*rsa.PublicKey](keys)
|
|
if len(rsaKeys) != 1 || rsaKeys[0].KID != "rsa-2048" {
|
|
t.Errorf("RSA key mismatch: %+v", rsaKeys)
|
|
}
|
|
}
|