mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-13 12:27:59 +00:00
refactor(auth/embeddedjwt): split Decode from claims unmarshaling
Remove the claims parameter from Decode — it now only parses the JWS structure (header, payload, signature). Add UnmarshalClaims(v Claims) to decode the payload into a typed struct as a separate step. Guard Validate against nil Claims with a clear error message.
This commit is contained in:
parent
ab898e4444
commit
8ac858128e
@ -38,13 +38,14 @@ type Claims interface {
|
|||||||
// JWS is a decoded JSON Web Signature / JWT.
|
// JWS is a decoded JSON Web Signature / JWT.
|
||||||
//
|
//
|
||||||
// Claims is stored as the [Claims] interface so that any embedded-struct type
|
// Claims is stored as the [Claims] interface so that any embedded-struct type
|
||||||
// can be used without generics. Access the concrete type via type assertion or,
|
// can be used without generics. Access the concrete type via the pointer you
|
||||||
// more conveniently, via the pointer you passed to [Decode].
|
// passed to [JWS.UnmarshalClaims].
|
||||||
//
|
//
|
||||||
// Typical usage:
|
// Typical usage:
|
||||||
//
|
//
|
||||||
|
// jws, err := embeddedjwt.Decode(tokenString)
|
||||||
// var claims AppClaims
|
// var claims AppClaims
|
||||||
// jws, err := embeddedjwt.Decode(tokenString, &claims)
|
// err = jws.UnmarshalClaims(&claims)
|
||||||
// jws.UnsafeVerify(pubKey)
|
// jws.UnsafeVerify(pubKey)
|
||||||
// errs, err := jws.Validate(params)
|
// errs, err := jws.Validate(params)
|
||||||
// // claims.Email, claims.Roles, etc. are already populated
|
// // claims.Email, claims.Roles, etc. are already populated
|
||||||
@ -99,17 +100,10 @@ func (c StandardClaims) Validate(params ValidateParams) ([]string, error) {
|
|||||||
|
|
||||||
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
|
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
|
||||||
//
|
//
|
||||||
// claims must be a pointer to the caller's pre-allocated claims struct
|
// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after
|
||||||
// (e.g. &AppClaims{}). The JSON payload is unmarshaled directly into it,
|
// Decode to populate a typed claims struct. The signature is not verified;
|
||||||
// and the same pointer is stored in jws.Claims. This means callers can
|
// call [JWS.UnsafeVerify] before [JWS.Validate].
|
||||||
// access custom fields through their own variable without a type assertion:
|
func Decode(tokenStr string) (*JWS, error) {
|
||||||
//
|
|
||||||
// var claims AppClaims
|
|
||||||
// jws, err := embeddedjwt.Decode(token, &claims)
|
|
||||||
// // claims.Email is already set; no type assertion needed
|
|
||||||
//
|
|
||||||
// The signature is not verified by Decode. Call [JWS.UnsafeVerify] first.
|
|
||||||
func Decode(tokenStr string, claims Claims) (*JWS, error) {
|
|
||||||
parts := strings.Split(tokenStr, ".")
|
parts := strings.Split(tokenStr, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return nil, fmt.Errorf("invalid JWT format")
|
return nil, fmt.Errorf("invalid JWT format")
|
||||||
@ -127,24 +121,37 @@ func Decode(tokenStr string, claims Claims) (*JWS, error) {
|
|||||||
return nil, fmt.Errorf("invalid header JSON: %v", err)
|
return nil, fmt.Errorf("invalid header JSON: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid claims encoding: %v", err)
|
|
||||||
}
|
|
||||||
// Unmarshal into the concrete type behind the Claims interface.
|
|
||||||
// json.Unmarshal receives the concrete pointer via reflection.
|
|
||||||
if err := json.Unmarshal(payload, claims); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid claims JSON: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil {
|
if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil {
|
||||||
return nil, fmt.Errorf("invalid signature encoding: %v", err)
|
return nil, fmt.Errorf("invalid signature encoding: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jws.Claims = claims
|
|
||||||
return &jws, nil
|
return &jws, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalClaims decodes the JWT payload into v and stores v in jws.Claims.
|
||||||
|
//
|
||||||
|
// v must be a pointer to a concrete type that satisfies [Claims] (e.g.
|
||||||
|
// *AppClaims). After this call, the caller's variable is populated and
|
||||||
|
// jws.Validate will use it — no type assertion needed:
|
||||||
|
//
|
||||||
|
// jws, _ := embeddedjwt.Decode(token)
|
||||||
|
// var claims AppClaims
|
||||||
|
// _ = jws.UnmarshalClaims(&claims)
|
||||||
|
// jws.UnsafeVerify(pubKey)
|
||||||
|
// jws.Validate(params)
|
||||||
|
// // claims.Email, claims.Roles, etc. are already set
|
||||||
|
func (jws *JWS) UnmarshalClaims(v Claims) error {
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid claims encoding: %v", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(payload, v); err != nil {
|
||||||
|
return fmt.Errorf("invalid claims JSON: %v", err)
|
||||||
|
}
|
||||||
|
jws.Claims = v
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// NewJWSFromClaims creates an unsigned JWS from the provided claims.
|
// NewJWSFromClaims creates an unsigned JWS from the provided claims.
|
||||||
//
|
//
|
||||||
// kid identifies the signing key. The "alg" header field is set automatically
|
// kid identifies the signing key. The "alg" header field is set automatically
|
||||||
@ -255,7 +262,12 @@ func (jws *JWS) UnsafeVerify(pub Key) bool {
|
|||||||
|
|
||||||
// Validate sets params.Now if zero, then delegates to jws.Claims.Validate and
|
// Validate sets params.Now if zero, then delegates to jws.Claims.Validate and
|
||||||
// additionally enforces that the signature was verified (unless params.IgnoreSig).
|
// additionally enforces that the signature was verified (unless params.IgnoreSig).
|
||||||
|
// Returns an error if [JWS.UnmarshalClaims] has not been called.
|
||||||
func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
|
func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
|
||||||
|
if jws.Claims == nil {
|
||||||
|
return []string{"claims not decoded: call UnmarshalClaims before Validate"}, fmt.Errorf("has errors")
|
||||||
|
}
|
||||||
|
|
||||||
if params.Now.IsZero() {
|
if params.Now.IsZero() {
|
||||||
params.Now = time.Now()
|
params.Now = time.Now()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -79,9 +79,9 @@ func goodParams() embeddedjwt.ValidateParams {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRoundTrip is the primary happy path: sign, encode, decode, verify,
|
// TestRoundTrip is the primary happy path: sign, encode, decode, unmarshal,
|
||||||
// validate — and confirm that custom fields are accessible without a type
|
// verify, validate — and confirm that custom fields are accessible via the
|
||||||
// assertion via the local &claims pointer.
|
// local claims pointer without a type assertion.
|
||||||
func TestRoundTrip(t *testing.T) {
|
func TestRoundTrip(t *testing.T) {
|
||||||
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -103,11 +103,14 @@ func TestRoundTrip(t *testing.T) {
|
|||||||
|
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
var decoded AppClaims
|
jws2, err := embeddedjwt.Decode(token)
|
||||||
jws2, err := embeddedjwt.Decode(token, &decoded)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
var decoded AppClaims
|
||||||
|
if err = jws2.UnmarshalClaims(&decoded); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||||
t.Fatal("signature verification failed")
|
t.Fatal("signature verification failed")
|
||||||
}
|
}
|
||||||
@ -142,11 +145,14 @@ func TestRoundTripRS256(t *testing.T) {
|
|||||||
|
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
var decoded AppClaims
|
jws2, err := embeddedjwt.Decode(token)
|
||||||
jws2, err := embeddedjwt.Decode(token, &decoded)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
var decoded AppClaims
|
||||||
|
if err = jws2.UnmarshalClaims(&decoded); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
if !jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||||
t.Fatal("signature verification failed")
|
t.Fatal("signature verification failed")
|
||||||
}
|
}
|
||||||
@ -165,8 +171,9 @@ func TestPromotedValidate(t *testing.T) {
|
|||||||
_, _ = jws.Sign(privKey)
|
_, _ = jws.Sign(privKey)
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
|
jws2, _ := embeddedjwt.Decode(token)
|
||||||
var decoded AppClaims
|
var decoded AppClaims
|
||||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
_ = jws2.UnmarshalClaims(&decoded)
|
||||||
jws2.UnsafeVerify(&privKey.PublicKey)
|
jws2.UnsafeVerify(&privKey.PublicKey)
|
||||||
|
|
||||||
if errs, err := jws2.Validate(goodParams()); err != nil {
|
if errs, err := jws2.Validate(goodParams()); err != nil {
|
||||||
@ -200,8 +207,9 @@ func TestOverriddenValidate(t *testing.T) {
|
|||||||
_, _ = jws.Sign(privKey)
|
_, _ = jws.Sign(privKey)
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
|
jws2, _ := embeddedjwt.Decode(token)
|
||||||
var decoded StrictAppClaims
|
var decoded StrictAppClaims
|
||||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
_ = jws2.UnmarshalClaims(&decoded)
|
||||||
jws2.UnsafeVerify(&privKey.PublicKey)
|
jws2.UnsafeVerify(&privKey.PublicKey)
|
||||||
|
|
||||||
errs, err := jws2.Validate(goodParams())
|
errs, err := jws2.Validate(goodParams())
|
||||||
@ -230,8 +238,7 @@ func TestUnsafeVerifyWrongKey(t *testing.T) {
|
|||||||
_, _ = jws.Sign(signingKey)
|
_, _ = jws.Sign(signingKey)
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
var decoded AppClaims
|
jws2, _ := embeddedjwt.Decode(token)
|
||||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
|
||||||
|
|
||||||
if jws2.UnsafeVerify(&wrongKey.PublicKey) {
|
if jws2.UnsafeVerify(&wrongKey.PublicKey) {
|
||||||
t.Fatal("expected verification to fail with wrong key")
|
t.Fatal("expected verification to fail with wrong key")
|
||||||
@ -248,8 +255,7 @@ func TestVerifyWrongKeyType(t *testing.T) {
|
|||||||
_, _ = jws.Sign(ecKey)
|
_, _ = jws.Sign(ecKey)
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
var decoded AppClaims
|
jws2, _ := embeddedjwt.Decode(token)
|
||||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
|
||||||
|
|
||||||
if jws2.UnsafeVerify(&rsaKey.PublicKey) {
|
if jws2.UnsafeVerify(&rsaKey.PublicKey) {
|
||||||
t.Fatal("expected verification to fail: RSA key for ES256 token")
|
t.Fatal("expected verification to fail: RSA key for ES256 token")
|
||||||
@ -265,8 +271,7 @@ func TestVerifyUnknownAlg(t *testing.T) {
|
|||||||
_, _ = jws.Sign(privKey)
|
_, _ = jws.Sign(privKey)
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
var decoded AppClaims
|
jws2, _ := embeddedjwt.Decode(token)
|
||||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
|
||||||
jws2.Header.Alg = "none"
|
jws2.Header.Alg = "none"
|
||||||
|
|
||||||
if jws2.UnsafeVerify(&privKey.PublicKey) {
|
if jws2.UnsafeVerify(&privKey.PublicKey) {
|
||||||
@ -285,8 +290,7 @@ func TestVerifyWithJWKSKey(t *testing.T) {
|
|||||||
_, _ = jws.Sign(privKey)
|
_, _ = jws.Sign(privKey)
|
||||||
token := jws.Encode()
|
token := jws.Encode()
|
||||||
|
|
||||||
var decoded AppClaims
|
jws2, _ := embeddedjwt.Decode(token)
|
||||||
jws2, _ := embeddedjwt.Decode(token, &decoded)
|
|
||||||
|
|
||||||
if !jws2.UnsafeVerify(jwksKey.Key) {
|
if !jws2.UnsafeVerify(jwksKey.Key) {
|
||||||
t.Fatal("verification via PublicJWK.Key failed")
|
t.Fatal("verification via PublicJWK.Key failed")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user