diff --git a/auth/embeddedjwt/jwt.go b/auth/embeddedjwt/jwt.go index ff6c70f..5aa2442 100644 --- a/auth/embeddedjwt/jwt.go +++ b/auth/embeddedjwt/jwt.go @@ -38,13 +38,14 @@ type Claims interface { // JWS is a decoded JSON Web Signature / JWT. // // Claims is stored as the [Claims] interface so that any embedded-struct type -// can be used without generics. Access the concrete type via type assertion or, -// more conveniently, via the pointer you passed to [Decode]. +// can be used without generics. Access the concrete type via the pointer you +// passed to [JWS.UnmarshalClaims]. // // Typical usage: // +// jws, err := embeddedjwt.Decode(tokenString) // var claims AppClaims -// jws, err := embeddedjwt.Decode(tokenString, &claims) +// err = jws.UnmarshalClaims(&claims) // jws.UnsafeVerify(pubKey) // errs, err := jws.Validate(params) // // claims.Email, claims.Roles, etc. are already populated @@ -99,17 +100,10 @@ func (c StandardClaims) Validate(params ValidateParams) ([]string, error) { // Decode parses a compact JWT string (header.payload.signature) into a JWS. // -// claims must be a pointer to the caller's pre-allocated claims struct -// (e.g. &AppClaims{}). The JSON payload is unmarshaled directly into it, -// and the same pointer is stored in jws.Claims. This means callers can -// access custom fields through their own variable without a type assertion: -// -// var claims AppClaims -// jws, err := embeddedjwt.Decode(token, &claims) -// // claims.Email is already set; no type assertion needed -// -// The signature is not verified by Decode. Call [JWS.UnsafeVerify] first. -func Decode(tokenStr string, claims Claims) (*JWS, error) { +// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after +// Decode to populate a typed claims struct. The signature is not verified; +// call [JWS.UnsafeVerify] before [JWS.Validate]. +func Decode(tokenStr string) (*JWS, error) { parts := strings.Split(tokenStr, ".") if len(parts) != 3 { return nil, fmt.Errorf("invalid JWT format") @@ -127,24 +121,37 @@ func Decode(tokenStr string, claims Claims) (*JWS, error) { return nil, fmt.Errorf("invalid header JSON: %v", err) } - payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) - if err != nil { - return nil, fmt.Errorf("invalid claims encoding: %v", err) - } - // Unmarshal into the concrete type behind the Claims interface. - // json.Unmarshal receives the concrete pointer via reflection. - if err := json.Unmarshal(payload, claims); err != nil { - return nil, fmt.Errorf("invalid claims JSON: %v", err) - } - if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil { return nil, fmt.Errorf("invalid signature encoding: %v", err) } - jws.Claims = claims return &jws, nil } +// UnmarshalClaims decodes the JWT payload into v and stores v in jws.Claims. +// +// v must be a pointer to a concrete type that satisfies [Claims] (e.g. +// *AppClaims). After this call, the caller's variable is populated and +// jws.Validate will use it — no type assertion needed: +// +// jws, _ := embeddedjwt.Decode(token) +// var claims AppClaims +// _ = jws.UnmarshalClaims(&claims) +// jws.UnsafeVerify(pubKey) +// jws.Validate(params) +// // claims.Email, claims.Roles, etc. are already set +func (jws *JWS) UnmarshalClaims(v Claims) error { + payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) + if err != nil { + return fmt.Errorf("invalid claims encoding: %v", err) + } + if err := json.Unmarshal(payload, v); err != nil { + return fmt.Errorf("invalid claims JSON: %v", err) + } + jws.Claims = v + return nil +} + // NewJWSFromClaims creates an unsigned JWS from the provided claims. // // kid identifies the signing key. The "alg" header field is set automatically @@ -255,7 +262,12 @@ func (jws *JWS) UnsafeVerify(pub Key) bool { // Validate sets params.Now if zero, then delegates to jws.Claims.Validate and // additionally enforces that the signature was verified (unless params.IgnoreSig). +// Returns an error if [JWS.UnmarshalClaims] has not been called. func (jws *JWS) Validate(params ValidateParams) ([]string, error) { + if jws.Claims == nil { + return []string{"claims not decoded: call UnmarshalClaims before Validate"}, fmt.Errorf("has errors") + } + if params.Now.IsZero() { params.Now = time.Now() } diff --git a/auth/embeddedjwt/jwt_test.go b/auth/embeddedjwt/jwt_test.go index a692e89..8586b19 100644 --- a/auth/embeddedjwt/jwt_test.go +++ b/auth/embeddedjwt/jwt_test.go @@ -79,9 +79,9 @@ func goodParams() embeddedjwt.ValidateParams { } } -// TestRoundTrip is the primary happy path: sign, encode, decode, verify, -// validate — and confirm that custom fields are accessible without a type -// assertion via the local &claims pointer. +// TestRoundTrip is the primary happy path: sign, encode, decode, unmarshal, +// verify, validate — and confirm that custom fields are accessible via the +// local claims pointer without a type assertion. func TestRoundTrip(t *testing.T) { privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -103,11 +103,14 @@ func TestRoundTrip(t *testing.T) { token := jws.Encode() - var decoded AppClaims - jws2, err := embeddedjwt.Decode(token, &decoded) + jws2, err := embeddedjwt.Decode(token) if err != nil { t.Fatal(err) } + var decoded AppClaims + if err = jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatal(err) + } if !jws2.UnsafeVerify(&privKey.PublicKey) { t.Fatal("signature verification failed") } @@ -142,11 +145,14 @@ func TestRoundTripRS256(t *testing.T) { token := jws.Encode() - var decoded AppClaims - jws2, err := embeddedjwt.Decode(token, &decoded) + jws2, err := embeddedjwt.Decode(token) if err != nil { t.Fatal(err) } + var decoded AppClaims + if err = jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatal(err) + } if !jws2.UnsafeVerify(&privKey.PublicKey) { t.Fatal("signature verification failed") } @@ -165,8 +171,9 @@ func TestPromotedValidate(t *testing.T) { _, _ = jws.Sign(privKey) token := jws.Encode() + jws2, _ := embeddedjwt.Decode(token) var decoded AppClaims - jws2, _ := embeddedjwt.Decode(token, &decoded) + _ = jws2.UnmarshalClaims(&decoded) jws2.UnsafeVerify(&privKey.PublicKey) if errs, err := jws2.Validate(goodParams()); err != nil { @@ -200,8 +207,9 @@ func TestOverriddenValidate(t *testing.T) { _, _ = jws.Sign(privKey) token := jws.Encode() + jws2, _ := embeddedjwt.Decode(token) var decoded StrictAppClaims - jws2, _ := embeddedjwt.Decode(token, &decoded) + _ = jws2.UnmarshalClaims(&decoded) jws2.UnsafeVerify(&privKey.PublicKey) errs, err := jws2.Validate(goodParams()) @@ -230,8 +238,7 @@ func TestUnsafeVerifyWrongKey(t *testing.T) { _, _ = jws.Sign(signingKey) token := jws.Encode() - var decoded AppClaims - jws2, _ := embeddedjwt.Decode(token, &decoded) + jws2, _ := embeddedjwt.Decode(token) if jws2.UnsafeVerify(&wrongKey.PublicKey) { t.Fatal("expected verification to fail with wrong key") @@ -248,8 +255,7 @@ func TestVerifyWrongKeyType(t *testing.T) { _, _ = jws.Sign(ecKey) token := jws.Encode() - var decoded AppClaims - jws2, _ := embeddedjwt.Decode(token, &decoded) + jws2, _ := embeddedjwt.Decode(token) if jws2.UnsafeVerify(&rsaKey.PublicKey) { t.Fatal("expected verification to fail: RSA key for ES256 token") @@ -265,8 +271,7 @@ func TestVerifyUnknownAlg(t *testing.T) { _, _ = jws.Sign(privKey) token := jws.Encode() - var decoded AppClaims - jws2, _ := embeddedjwt.Decode(token, &decoded) + jws2, _ := embeddedjwt.Decode(token) jws2.Header.Alg = "none" if jws2.UnsafeVerify(&privKey.PublicKey) { @@ -285,8 +290,7 @@ func TestVerifyWithJWKSKey(t *testing.T) { _, _ = jws.Sign(privKey) token := jws.Encode() - var decoded AppClaims - jws2, _ := embeddedjwt.Decode(token, &decoded) + jws2, _ := embeddedjwt.Decode(token) if !jws2.UnsafeVerify(jwksKey.Key) { t.Fatal("verification via PublicJWK.Key failed")