refactor(auth/embeddedjwt): split Decode from claims unmarshaling

Remove the claims parameter from Decode — it now only parses the JWS
structure (header, payload, signature). Add UnmarshalClaims(v Claims)
to decode the payload into a typed struct as a separate step. Guard
Validate against nil Claims with a clear error message.
This commit is contained in:
AJ ONeal 2026-03-12 18:07:19 -06:00
parent ab898e4444
commit 8ac858128e
No known key found for this signature in database
2 changed files with 58 additions and 42 deletions

View File

@ -38,13 +38,14 @@ type Claims interface {
// JWS is a decoded JSON Web Signature / JWT. // JWS is a decoded JSON Web Signature / JWT.
// //
// Claims is stored as the [Claims] interface so that any embedded-struct type // Claims is stored as the [Claims] interface so that any embedded-struct type
// can be used without generics. Access the concrete type via type assertion or, // can be used without generics. Access the concrete type via the pointer you
// more conveniently, via the pointer you passed to [Decode]. // passed to [JWS.UnmarshalClaims].
// //
// Typical usage: // Typical usage:
// //
// jws, err := embeddedjwt.Decode(tokenString)
// var claims AppClaims // var claims AppClaims
// jws, err := embeddedjwt.Decode(tokenString, &claims) // err = jws.UnmarshalClaims(&claims)
// jws.UnsafeVerify(pubKey) // jws.UnsafeVerify(pubKey)
// errs, err := jws.Validate(params) // errs, err := jws.Validate(params)
// // claims.Email, claims.Roles, etc. are already populated // // claims.Email, claims.Roles, etc. are already populated
@ -99,17 +100,10 @@ func (c StandardClaims) Validate(params ValidateParams) ([]string, error) {
// Decode parses a compact JWT string (header.payload.signature) into a JWS. // Decode parses a compact JWT string (header.payload.signature) into a JWS.
// //
// claims must be a pointer to the caller's pre-allocated claims struct // It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
// (e.g. &AppClaims{}). The JSON payload is unmarshaled directly into it, // Decode to populate a typed claims struct. The signature is not verified;
// and the same pointer is stored in jws.Claims. This means callers can // call [JWS.UnsafeVerify] before [JWS.Validate].
// access custom fields through their own variable without a type assertion: func Decode(tokenStr string) (*JWS, error) {
//
// var claims AppClaims
// jws, err := embeddedjwt.Decode(token, &claims)
// // claims.Email is already set; no type assertion needed
//
// The signature is not verified by Decode. Call [JWS.UnsafeVerify] first.
func Decode(tokenStr string, claims Claims) (*JWS, error) {
parts := strings.Split(tokenStr, ".") parts := strings.Split(tokenStr, ".")
if len(parts) != 3 { if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format") return nil, fmt.Errorf("invalid JWT format")
@ -127,24 +121,37 @@ func Decode(tokenStr string, claims Claims) (*JWS, error) {
return nil, fmt.Errorf("invalid header JSON: %v", err) return nil, fmt.Errorf("invalid header JSON: %v", err)
} }
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return nil, fmt.Errorf("invalid claims encoding: %v", err)
}
// Unmarshal into the concrete type behind the Claims interface.
// json.Unmarshal receives the concrete pointer via reflection.
if err := json.Unmarshal(payload, claims); err != nil {
return nil, fmt.Errorf("invalid claims JSON: %v", err)
}
if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil { if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil {
return nil, fmt.Errorf("invalid signature encoding: %v", err) return nil, fmt.Errorf("invalid signature encoding: %v", err)
} }
jws.Claims = claims
return &jws, nil return &jws, nil
} }
// UnmarshalClaims decodes the JWT payload into v and stores v in jws.Claims.
//
// v must be a pointer to a concrete type that satisfies [Claims] (e.g.
// *AppClaims). After this call, the caller's variable is populated and
// jws.Validate will use it — no type assertion needed:
//
// jws, _ := embeddedjwt.Decode(token)
// var claims AppClaims
// _ = jws.UnmarshalClaims(&claims)
// jws.UnsafeVerify(pubKey)
// jws.Validate(params)
// // claims.Email, claims.Roles, etc. are already set
func (jws *JWS) UnmarshalClaims(v Claims) error {
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return fmt.Errorf("invalid claims encoding: %v", err)
}
if err := json.Unmarshal(payload, v); err != nil {
return fmt.Errorf("invalid claims JSON: %v", err)
}
jws.Claims = v
return nil
}
// NewJWSFromClaims creates an unsigned JWS from the provided claims. // NewJWSFromClaims creates an unsigned JWS from the provided claims.
// //
// kid identifies the signing key. The "alg" header field is set automatically // kid identifies the signing key. The "alg" header field is set automatically
@ -255,7 +262,12 @@ func (jws *JWS) UnsafeVerify(pub Key) bool {
// Validate sets params.Now if zero, then delegates to jws.Claims.Validate and // Validate sets params.Now if zero, then delegates to jws.Claims.Validate and
// additionally enforces that the signature was verified (unless params.IgnoreSig). // additionally enforces that the signature was verified (unless params.IgnoreSig).
// Returns an error if [JWS.UnmarshalClaims] has not been called.
func (jws *JWS) Validate(params ValidateParams) ([]string, error) { func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
if jws.Claims == nil {
return []string{"claims not decoded: call UnmarshalClaims before Validate"}, fmt.Errorf("has errors")
}
if params.Now.IsZero() { if params.Now.IsZero() {
params.Now = time.Now() params.Now = time.Now()
} }

View File

@ -79,9 +79,9 @@ func goodParams() embeddedjwt.ValidateParams {
} }
} }
// TestRoundTrip is the primary happy path: sign, encode, decode, verify, // TestRoundTrip is the primary happy path: sign, encode, decode, unmarshal,
// validate — and confirm that custom fields are accessible without a type // verify, validate — and confirm that custom fields are accessible via the
// assertion via the local &claims pointer. // local claims pointer without a type assertion.
func TestRoundTrip(t *testing.T) { func TestRoundTrip(t *testing.T) {
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
@ -103,11 +103,14 @@ func TestRoundTrip(t *testing.T) {
token := jws.Encode() token := jws.Encode()
var decoded AppClaims jws2, err := embeddedjwt.Decode(token)
jws2, err := embeddedjwt.Decode(token, &decoded)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatal(err)
}
if !jws2.UnsafeVerify(&privKey.PublicKey) { if !jws2.UnsafeVerify(&privKey.PublicKey) {
t.Fatal("signature verification failed") t.Fatal("signature verification failed")
} }
@ -142,11 +145,14 @@ func TestRoundTripRS256(t *testing.T) {
token := jws.Encode() token := jws.Encode()
var decoded AppClaims jws2, err := embeddedjwt.Decode(token)
jws2, err := embeddedjwt.Decode(token, &decoded)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var decoded AppClaims
if err = jws2.UnmarshalClaims(&decoded); err != nil {
t.Fatal(err)
}
if !jws2.UnsafeVerify(&privKey.PublicKey) { if !jws2.UnsafeVerify(&privKey.PublicKey) {
t.Fatal("signature verification failed") t.Fatal("signature verification failed")
} }
@ -165,8 +171,9 @@ func TestPromotedValidate(t *testing.T) {
_, _ = jws.Sign(privKey) _, _ = jws.Sign(privKey)
token := jws.Encode() token := jws.Encode()
jws2, _ := embeddedjwt.Decode(token)
var decoded AppClaims var decoded AppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded) _ = jws2.UnmarshalClaims(&decoded)
jws2.UnsafeVerify(&privKey.PublicKey) jws2.UnsafeVerify(&privKey.PublicKey)
if errs, err := jws2.Validate(goodParams()); err != nil { if errs, err := jws2.Validate(goodParams()); err != nil {
@ -200,8 +207,9 @@ func TestOverriddenValidate(t *testing.T) {
_, _ = jws.Sign(privKey) _, _ = jws.Sign(privKey)
token := jws.Encode() token := jws.Encode()
jws2, _ := embeddedjwt.Decode(token)
var decoded StrictAppClaims var decoded StrictAppClaims
jws2, _ := embeddedjwt.Decode(token, &decoded) _ = jws2.UnmarshalClaims(&decoded)
jws2.UnsafeVerify(&privKey.PublicKey) jws2.UnsafeVerify(&privKey.PublicKey)
errs, err := jws2.Validate(goodParams()) errs, err := jws2.Validate(goodParams())
@ -230,8 +238,7 @@ func TestUnsafeVerifyWrongKey(t *testing.T) {
_, _ = jws.Sign(signingKey) _, _ = jws.Sign(signingKey)
token := jws.Encode() token := jws.Encode()
var decoded AppClaims jws2, _ := embeddedjwt.Decode(token)
jws2, _ := embeddedjwt.Decode(token, &decoded)
if jws2.UnsafeVerify(&wrongKey.PublicKey) { if jws2.UnsafeVerify(&wrongKey.PublicKey) {
t.Fatal("expected verification to fail with wrong key") t.Fatal("expected verification to fail with wrong key")
@ -248,8 +255,7 @@ func TestVerifyWrongKeyType(t *testing.T) {
_, _ = jws.Sign(ecKey) _, _ = jws.Sign(ecKey)
token := jws.Encode() token := jws.Encode()
var decoded AppClaims jws2, _ := embeddedjwt.Decode(token)
jws2, _ := embeddedjwt.Decode(token, &decoded)
if jws2.UnsafeVerify(&rsaKey.PublicKey) { if jws2.UnsafeVerify(&rsaKey.PublicKey) {
t.Fatal("expected verification to fail: RSA key for ES256 token") t.Fatal("expected verification to fail: RSA key for ES256 token")
@ -265,8 +271,7 @@ func TestVerifyUnknownAlg(t *testing.T) {
_, _ = jws.Sign(privKey) _, _ = jws.Sign(privKey)
token := jws.Encode() token := jws.Encode()
var decoded AppClaims jws2, _ := embeddedjwt.Decode(token)
jws2, _ := embeddedjwt.Decode(token, &decoded)
jws2.Header.Alg = "none" jws2.Header.Alg = "none"
if jws2.UnsafeVerify(&privKey.PublicKey) { if jws2.UnsafeVerify(&privKey.PublicKey) {
@ -285,8 +290,7 @@ func TestVerifyWithJWKSKey(t *testing.T) {
_, _ = jws.Sign(privKey) _, _ = jws.Sign(privKey)
token := jws.Encode() token := jws.Encode()
var decoded AppClaims jws2, _ := embeddedjwt.Decode(token)
jws2, _ := embeddedjwt.Decode(token, &decoded)
if !jws2.UnsafeVerify(jwksKey.Key) { if !jws2.UnsafeVerify(jwksKey.Key) {
t.Fatal("verification via PublicJWK.Key failed") t.Fatal("verification via PublicJWK.Key failed")