mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-13 12:27:59 +00:00
refactor(auth/embeddedjwt): split Decode from claims unmarshaling
Remove the claims parameter from Decode — it now only parses the JWS structure (header, payload, signature). Add UnmarshalClaims(v Claims) to decode the payload into a typed struct as a separate step. Guard Validate against nil Claims with a clear error message.
This commit is contained in:
parent
ab898e4444
commit
8ac858128e
@ -38,13 +38,14 @@ type Claims interface {
|
||||
// JWS is a decoded JSON Web Signature / JWT.
|
||||
//
|
||||
// Claims is stored as the [Claims] interface so that any embedded-struct type
|
||||
// can be used without generics. Access the concrete type via type assertion or,
|
||||
// more conveniently, via the pointer you passed to [Decode].
|
||||
// can be used without generics. Access the concrete type via the pointer you
|
||||
// passed to [JWS.UnmarshalClaims].
|
||||
//
|
||||
// Typical usage:
|
||||
//
|
||||
// jws, err := embeddedjwt.Decode(tokenString)
|
||||
// var claims AppClaims
|
||||
// jws, err := embeddedjwt.Decode(tokenString, &claims)
|
||||
// err = jws.UnmarshalClaims(&claims)
|
||||
// jws.UnsafeVerify(pubKey)
|
||||
// errs, err := jws.Validate(params)
|
||||
// // claims.Email, claims.Roles, etc. are already populated
|
||||
@ -99,17 +100,10 @@ func (c StandardClaims) Validate(params ValidateParams) ([]string, error) {
|
||||
|
||||
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
|
||||
//
|
||||
// claims must be a pointer to the caller's pre-allocated claims struct
|
||||
// (e.g. &AppClaims{}). The JSON payload is unmarshaled directly into it,
|
||||
// and the same pointer is stored in jws.Claims. This means callers can
|
||||
// access custom fields through their own variable without a type assertion:
|
||||
//
|
||||
// var claims AppClaims
|
||||
// jws, err := embeddedjwt.Decode(token, &claims)
|
||||
// // claims.Email is already set; no type assertion needed
|
||||
//
|
||||
// The signature is not verified by Decode. Call [JWS.UnsafeVerify] first.
|
||||
func Decode(tokenStr string, claims Claims) (*JWS, error) {
|
||||
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
|
||||
// Decode to populate a typed claims struct. The signature is not verified;
|
||||
// call [JWS.UnsafeVerify] before [JWS.Validate].
|
||||
func Decode(tokenStr string) (*JWS, error) {
|
||||
parts := strings.Split(tokenStr, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format")
|
||||
@ -127,24 +121,37 @@ func Decode(tokenStr string, claims Claims) (*JWS, error) {
|
||||
return nil, fmt.Errorf("invalid header JSON: %v", err)
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid claims encoding: %v", err)
|
||||
}
|
||||
// Unmarshal into the concrete type behind the Claims interface.
|
||||
// json.Unmarshal receives the concrete pointer via reflection.
|
||||
if err := json.Unmarshal(payload, claims); err != nil {
|
||||
return nil, fmt.Errorf("invalid claims JSON: %v", err)
|
||||
}
|
||||
|
||||
if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil {
|
||||
return nil, fmt.Errorf("invalid signature encoding: %v", err)
|
||||
}
|
||||
|
||||
jws.Claims = claims
|
||||
return &jws, nil
|
||||
}
|
||||
|
||||
// UnmarshalClaims decodes the JWT payload into v and stores v in jws.Claims.
|
||||
//
|
||||
// v must be a pointer to a concrete type that satisfies [Claims] (e.g.
|
||||
// *AppClaims). After this call, the caller's variable is populated and
|
||||
// jws.Validate will use it — no type assertion needed:
|
||||
//
|
||||
// jws, _ := embeddedjwt.Decode(token)
|
||||
// var claims AppClaims
|
||||
// _ = jws.UnmarshalClaims(&claims)
|
||||
// jws.UnsafeVerify(pubKey)
|
||||
// jws.Validate(params)
|
||||
// // claims.Email, claims.Roles, etc. are already set
|
||||
func (jws *JWS) UnmarshalClaims(v Claims) error {
|
||||
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid claims encoding: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(payload, v); err != nil {
|
||||
return fmt.Errorf("invalid claims JSON: %v", err)
|
||||
}
|
||||
jws.Claims = v
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewJWSFromClaims creates an unsigned JWS from the provided claims.
|
||||
//
|
||||
// kid identifies the signing key. The "alg" header field is set automatically
|
||||
@ -255,7 +262,12 @@ func (jws *JWS) UnsafeVerify(pub Key) bool {
|
||||
|
||||
// Validate sets params.Now if zero, then delegates to jws.Claims.Validate and
|
||||
// additionally enforces that the signature was verified (unless params.IgnoreSig).
|
||||
// Returns an error if [JWS.UnmarshalClaims] has not been called.
|
||||
func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
|
||||
if jws.Claims == nil {
|
||||
return []string{"claims not decoded: call UnmarshalClaims before Validate"}, fmt.Errorf("has errors")
|
||||
}
|
||||
|
||||
if params.Now.IsZero() {
|
||||
params.Now = time.Now()
|
||||
}
|
||||
|
||||
@ -79,9 +79,9 @@ func goodParams() embeddedjwt.ValidateParams {
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoundTrip is the primary happy path: sign, encode, decode, verify,
|
||||
// validate — and confirm that custom fields are accessible without a type
|
||||
// assertion via the local &claims pointer.
|
||||
// TestRoundTrip is the primary happy path: sign, encode, decode, unmarshal,
|
||||
// verify, validate — and confirm that custom fields are accessible via the
|
||||
// local claims pointer without a type assertion.
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
@ -103,11 +103,14 @@ func TestRoundTrip(t *testing.T) {
|
||||
|
||||
token := jws.Encode()
|
||||
|
||||
var decoded AppClaims
|
||||
jws2, err := embeddedjwt.Decode(token, &decoded)
|
||||
jws2, err := embeddedjwt.Decode(token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var decoded AppClaims
|
||||
if err = jws2.UnmarshalClaims(&decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||
t.Fatal("signature verification failed")
|
||||
}
|
||||
@ -142,11 +145,14 @@ func TestRoundTripRS256(t *testing.T) {
|
||||
|
||||
token := jws.Encode()
|
||||
|
||||
var decoded AppClaims
|
||||
jws2, err := embeddedjwt.Decode(token, &decoded)
|
||||
jws2, err := embeddedjwt.Decode(token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var decoded AppClaims
|
||||
if err = jws2.UnmarshalClaims(&decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||
t.Fatal("signature verification failed")
|
||||
}
|
||||
@ -165,8 +171,9 @@ func TestPromotedValidate(t *testing.T) {
|
||||
_, _ = jws.Sign(privKey)
|
||||
token := jws.Encode()
|
||||
|
||||
jws2, _ := embeddedjwt.Decode(token)
|
||||
var decoded AppClaims
|
||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
||||
_ = jws2.UnmarshalClaims(&decoded)
|
||||
jws2.UnsafeVerify(&privKey.PublicKey)
|
||||
|
||||
if errs, err := jws2.Validate(goodParams()); err != nil {
|
||||
@ -200,8 +207,9 @@ func TestOverriddenValidate(t *testing.T) {
|
||||
_, _ = jws.Sign(privKey)
|
||||
token := jws.Encode()
|
||||
|
||||
jws2, _ := embeddedjwt.Decode(token)
|
||||
var decoded StrictAppClaims
|
||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
||||
_ = jws2.UnmarshalClaims(&decoded)
|
||||
jws2.UnsafeVerify(&privKey.PublicKey)
|
||||
|
||||
errs, err := jws2.Validate(goodParams())
|
||||
@ -230,8 +238,7 @@ func TestUnsafeVerifyWrongKey(t *testing.T) {
|
||||
_, _ = jws.Sign(signingKey)
|
||||
token := jws.Encode()
|
||||
|
||||
var decoded AppClaims
|
||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
||||
jws2, _ := embeddedjwt.Decode(token)
|
||||
|
||||
if jws2.UnsafeVerify(&wrongKey.PublicKey) {
|
||||
t.Fatal("expected verification to fail with wrong key")
|
||||
@ -248,8 +255,7 @@ func TestVerifyWrongKeyType(t *testing.T) {
|
||||
_, _ = jws.Sign(ecKey)
|
||||
token := jws.Encode()
|
||||
|
||||
var decoded AppClaims
|
||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
||||
jws2, _ := embeddedjwt.Decode(token)
|
||||
|
||||
if jws2.UnsafeVerify(&rsaKey.PublicKey) {
|
||||
t.Fatal("expected verification to fail: RSA key for ES256 token")
|
||||
@ -265,8 +271,7 @@ func TestVerifyUnknownAlg(t *testing.T) {
|
||||
_, _ = jws.Sign(privKey)
|
||||
token := jws.Encode()
|
||||
|
||||
var decoded AppClaims
|
||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
||||
jws2, _ := embeddedjwt.Decode(token)
|
||||
jws2.Header.Alg = "none"
|
||||
|
||||
if jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||
@ -285,8 +290,7 @@ func TestVerifyWithJWKSKey(t *testing.T) {
|
||||
_, _ = jws.Sign(privKey)
|
||||
token := jws.Encode()
|
||||
|
||||
var decoded AppClaims
|
||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
||||
jws2, _ := embeddedjwt.Decode(token)
|
||||
|
||||
if !jws2.UnsafeVerify(jwksKey.Key) {
|
||||
t.Fatal("verification via PublicJWK.Key failed")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user