refactor(auth/embeddedjwt): split Decode from claims unmarshaling

Remove the claims parameter from Decode — it now only parses the JWS
structure (header, payload, signature). Add UnmarshalClaims(v Claims)
to decode the payload into a typed struct as a separate step. Guard
Validate against nil Claims with a clear error message.
This commit is contained in:
AJ ONeal 2026-03-12 18:07:19 -06:00
parent ab898e4444
commit 8ac858128e
No known key found for this signature in database
2 changed files with 58 additions and 42 deletions

View File

@ -38,13 +38,14 @@ type Claims interface {
// JWS is a decoded JSON Web Signature / JWT.
//
// Claims is stored as the [Claims] interface so that any embedded-struct type
// can be used without generics. Access the concrete type via type assertion or,
// more conveniently, via the pointer you passed to [Decode].
// can be used without generics. Access the concrete type via the pointer you
// passed to [JWS.UnmarshalClaims].
//
// Typical usage:
//
// jws, err := embeddedjwt.Decode(tokenString)
// var claims AppClaims
// jws, err := embeddedjwt.Decode(tokenString, &claims)
// err = jws.UnmarshalClaims(&claims)
// jws.UnsafeVerify(pubKey)
// errs, err := jws.Validate(params)
// // claims.Email, claims.Roles, etc. are already populated
@ -99,17 +100,10 @@ func (c StandardClaims) Validate(params ValidateParams) ([]string, error) {
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
//
// claims must be a pointer to the caller's pre-allocated claims struct
// (e.g. &AppClaims{}). The JSON payload is unmarshaled directly into it,
// and the same pointer is stored in jws.Claims. This means callers can
// access custom fields through their own variable without a type assertion:
//
// var claims AppClaims
// jws, err := embeddedjwt.Decode(token, &claims)
// // claims.Email is already set; no type assertion needed
//
// The signature is not verified by Decode. Call [JWS.UnsafeVerify] first.
func Decode(tokenStr string, claims Claims) (*JWS, error) {
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
// Decode to populate a typed claims struct. The signature is not verified;
// call [JWS.UnsafeVerify] before [JWS.Validate].
func Decode(tokenStr string) (*JWS, error) {
parts := strings.Split(tokenStr, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format")
@ -127,24 +121,37 @@ func Decode(tokenStr string, claims Claims) (*JWS, error) {
return nil, fmt.Errorf("invalid header JSON: %v", err)
}
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return nil, fmt.Errorf("invalid claims encoding: %v", err)
}
// Unmarshal into the concrete type behind the Claims interface.
// json.Unmarshal receives the concrete pointer via reflection.
if err := json.Unmarshal(payload, claims); err != nil {
return nil, fmt.Errorf("invalid claims JSON: %v", err)
}
if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil {
return nil, fmt.Errorf("invalid signature encoding: %v", err)
}
jws.Claims = claims
return &jws, nil
}
// UnmarshalClaims decodes the JWT payload into v and stores v in jws.Claims.
//
// v must be a pointer to a concrete type that satisfies [Claims] (e.g.
// *AppClaims). After this call, the caller's variable is populated and
// jws.Validate will use it — no type assertion needed:
//
// jws, _ := embeddedjwt.Decode(token)
// var claims AppClaims
// _ = jws.UnmarshalClaims(&claims)
// jws.UnsafeVerify(pubKey)
// jws.Validate(params)
// // claims.Email, claims.Roles, etc. are already set
func (jws *JWS) UnmarshalClaims(v Claims) error {
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return fmt.Errorf("invalid claims encoding: %v", err)
}
if err := json.Unmarshal(payload, v); err != nil {
return fmt.Errorf("invalid claims JSON: %v", err)
}
jws.Claims = v
return nil
}
// NewJWSFromClaims creates an unsigned JWS from the provided claims.
//
// kid identifies the signing key. The "alg" header field is set automatically
@ -255,7 +262,12 @@ func (jws *JWS) UnsafeVerify(pub Key) bool {
// Validate sets params.Now if zero, then delegates to jws.Claims.Validate and
// additionally enforces that the signature was verified (unless params.IgnoreSig).
// Returns an error if [JWS.UnmarshalClaims] has not been called.
func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
if jws.Claims == nil {
return []string{"claims not decoded: call UnmarshalClaims before Validate"}, fmt.Errorf("has errors")
}
if params.Now.IsZero() {
params.Now = time.Now()
}

View File

@ -79,9 +79,9 @@ func goodParams() embeddedjwt.ValidateParams {
}
}
// TestRoundTrip is the primary happy path: sign, encode, decode, verify,
// validate — and confirm that custom fields are accessible without a type
// assertion via the local &claims pointer.
// TestRoundTrip is the primary happy path: sign, encode, decode, unmarshal,
// verify, validate — and confirm that custom fields are accessible via the
// local claims pointer without a type assertion.
func TestRoundTrip(t *testing.T) {
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
@ -103,11 +103,14 @@ func TestRoundTrip(t *testing.T) {
token := jws.Encode()
var decoded AppClaims
jws2, err := embeddedjwt.Decode(token, &decoded)
jws2, err := embeddedjwt.Decode(token)
if err != nil {
t.Fatal(err)
}
var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatal(err)
}
if !jws2.UnsafeVerify(&privKey.PublicKey) {
t.Fatal("signature verification failed")
}
@ -142,11 +145,14 @@ func TestRoundTripRS256(t *testing.T) {
token := jws.Encode()
var decoded AppClaims
jws2, err := embeddedjwt.Decode(token, &decoded)
jws2, err := embeddedjwt.Decode(token)
if err != nil {
t.Fatal(err)
}
var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatal(err)
}
if !jws2.UnsafeVerify(&privKey.PublicKey) {
t.Fatal("signature verification failed")
}
@ -165,8 +171,9 @@ func TestPromotedValidate(t *testing.T) {
_, _ = jws.Sign(privKey)
token := jws.Encode()
jws2, _ := embeddedjwt.Decode(token)
var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
_ = jws2.UnmarshalClaims(&decoded)
jws2.UnsafeVerify(&privKey.PublicKey)
if errs, err := jws2.Validate(goodParams()); err != nil {
@ -200,8 +207,9 @@ func TestOverriddenValidate(t *testing.T) {
_, _ = jws.Sign(privKey)
token := jws.Encode()
jws2, _ := embeddedjwt.Decode(token)
var decoded StrictAppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
_ = jws2.UnmarshalClaims(&decoded)
jws2.UnsafeVerify(&privKey.PublicKey)
errs, err := jws2.Validate(goodParams())
@ -230,8 +238,7 @@ func TestUnsafeVerifyWrongKey(t *testing.T) {
_, _ = jws.Sign(signingKey)
token := jws.Encode()
var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
jws2, _ := embeddedjwt.Decode(token)
if jws2.UnsafeVerify(&wrongKey.PublicKey) {
t.Fatal("expected verification to fail with wrong key")
@ -248,8 +255,7 @@ func TestVerifyWrongKeyType(t *testing.T) {
_, _ = jws.Sign(ecKey)
token := jws.Encode()
var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
jws2, _ := embeddedjwt.Decode(token)
if jws2.UnsafeVerify(&rsaKey.PublicKey) {
t.Fatal("expected verification to fail: RSA key for ES256 token")
@ -265,8 +271,7 @@ func TestVerifyUnknownAlg(t *testing.T) {
_, _ = jws.Sign(privKey)
token := jws.Encode()
var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
jws2, _ := embeddedjwt.Decode(token)
jws2.Header.Alg = "none"
if jws2.UnsafeVerify(&privKey.PublicKey) {
@ -285,8 +290,7 @@ func TestVerifyWithJWKSKey(t *testing.T) {
_, _ = jws.Sign(privKey)
token := jws.Encode()
var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded)
jws2, _ := embeddedjwt.Decode(token)
if !jws2.UnsafeVerify(jwksKey.Key) {
t.Fatal("verification via PublicJWK.Key failed")