diff --git a/auth/jwt/jwt.go b/auth/jwt/jwt.go new file mode 100644 index 0000000..103b614 --- /dev/null +++ b/auth/jwt/jwt.go @@ -0,0 +1,483 @@ +// Copyright 2025 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +package jwt + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "slices" + "strings" + "time" +) + +type Keypair struct { + Thumbprint string + PrivateKey *ecdsa.PrivateKey +} + +type JWK struct { + Kty string `json:"kty"` + Crv string `json:"crv"` + D string `json:"d"` + X string `json:"x"` + Y string `json:"y"` +} + +type JWT string + +func (jwt JWT) Split() (string, string, string, error) { + parts := strings.Split(string(jwt), ".") + if len(parts) != 3 { + return "", "", "", fmt.Errorf("invalid JWT format") + } + + rawHeader, rawPayload, rawSig := parts[0], parts[1], parts[2] + return rawHeader, rawPayload, rawSig, nil +} + +func (jwt JWT) Decode() (JWS, error) { + h64, p64, s64, err := jwt.Split() + if err != nil { + return JWS{}, err + } + + var jws JWS + var sigEnc string + jws.Protected, jws.Payload, sigEnc = h64, p64, s64 + + header, err := base64.RawURLEncoding.DecodeString(jws.Protected) + if err != nil { + return jws, fmt.Errorf("invalid header encoding: %v", err) + } + if err := json.Unmarshal(header, &jws.Header); err != nil { + return jws, fmt.Errorf("invalid header JSON: %v", err) + } + + payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) + if err != nil { + return jws, fmt.Errorf("invalid claims encoding: %v", err) + } + if err := json.Unmarshal(payload, &jws.Claims); err != nil { + return jws, fmt.Errorf("invalid claims JSON: %v", err) + } + + if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil { + return jws, fmt.Errorf("invalid signature encoding: %v", err) + } + + return jws, nil +} + +type JWS struct { + Protected string `json:"-"` // base64 + Header MyHeader `json:"headers"` + Payload string `json:"-"` // base64 + Claims MyClaims `json:"claims"` + Signature URLBase64 `json:"signature"` + Verified bool `json:"-"` +} + +type MyHeader struct { + StandardHeader +} + +type StandardHeader struct { + Alg string `json:"alg"` + Kid string `json:"kid"` + Typ string `json:"typ"` +} + +type MyClaims struct { + StandardClaims + Email string `json:"email"` + EmployeeID string `json:"employee_id"` + FamilyName string `json:"family_name"` + GivenName string `json:"given_name"` + Roles []string `json:"roles"` +} + +type StandardClaims struct { + Iss string `json:"iss"` + Sub string `json:"sub"` + Aud string `json:"aud"` + Exp int64 `json:"exp"` + Iat int64 `json:"iat"` + AuthTime int64 `json:"auth_time"` + Nonce string `json:"nonce,omitempty"` + Amr []string `json:"amr"` + Azp string `json:"azp,omitempty"` + Jti string `json:"jti"` +} + +func UnmarshalJWK(jwk JWK) (*ecdsa.PrivateKey, error) { + x, err := base64.RawURLEncoding.DecodeString(jwk.X) + if err != nil { + return nil, fmt.Errorf("invalid JWK X: %v", err) + } + y, err := base64.RawURLEncoding.DecodeString(jwk.Y) + if err != nil { + return nil, fmt.Errorf("invalid JWK Y: %v", err) + } + d, err := base64.RawURLEncoding.DecodeString(jwk.D) + if err != nil { + return nil, fmt.Errorf("invalid JWK D: %v", err) + } + + return &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: new(big.Int).SetBytes(x), + Y: new(big.Int).SetBytes(y), + }, + D: new(big.Int).SetBytes(d), + }, nil +} + +func NewJWS(email, employeeID, issuer, thumbprint string, roles []string) (JWS, error) { + var jws JWS + + jws.Header.StandardHeader = StandardHeader{ + Alg: "ES256", + Kid: thumbprint, + Typ: "JWT", + } + headerJSON, _ := json.Marshal(jws.Header) + jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) + + now := time.Now().Unix() + jtiBytes := make([]byte, 16) + if _, err := rand.Read(jtiBytes); err != nil { + return JWS{}, fmt.Errorf("failed to generate Jti: %v", err) + } + jti := base64.RawURLEncoding.EncodeToString(jtiBytes) + emailName := strings.Split(email, "@")[0] + + jws.Claims = MyClaims{ + StandardClaims: StandardClaims{ + AuthTime: now, + Exp: now + 15*60*37, // TODO remove + Iat: now, + Iss: issuer, + Jti: jti, + Sub: email, + Amr: []string{"pwd"}, + }, + Email: email, + EmployeeID: employeeID, + FamilyName: "McTestface", + GivenName: strings.ToUpper(emailName), + Roles: roles, + } + claimsJSON, _ := json.Marshal(jws.Claims) + jws.Payload = base64.RawURLEncoding.EncodeToString(claimsJSON) + + return jws, nil +} + +func (jws *JWS) Sign(key *ecdsa.PrivateKey) ([]byte, error) { + var err error + jws.Signature, err = SignJWS(jws.Protected, jws.Payload, key) + return jws.Signature, err +} + +// UnsafeVerify only checks the signature, use Validate to check all values +func (jws *JWS) UnsafeVerify(pub *ecdsa.PublicKey) bool { + hash := sha256.Sum256([]byte(jws.Protected + "." + jws.Payload)) + n := len(jws.Signature) + if n != 64 { + // return fmt.Errorf("expected a 64-byte signature consisting of two 32-byte r and s components, but got %d instead (perhaps ASN.1 or other format)", n) + return false + } + + r := new(big.Int).SetBytes(jws.Signature[:32]) + s := new(big.Int).SetBytes(jws.Signature[32:]) + + jws.Verified = ecdsa.Verify(pub, hash[:], r, s) + return jws.Verified +} + +// ValidateParams holds validation configuration. +// https://openid.net/specs/openid-connect-core-1_0.html#IDToken +type ValidateParams struct { + Now time.Time + IgnoreIss bool + Iss string + IgnoreSub bool + Sub string + IgnoreAud bool + Aud string + IgnoreExp bool + IgnoreJti bool + Jti string + IgnoreIat bool + IgnoreAuthTime bool + MaxAge time.Duration + IgnoreNonce bool + Nonce string + IgnoreAmr bool + RequiredAmrs []string + IgnoreAzp bool + Azp string + IgnoreSig bool +} + +// Validate checks common JWS fields and issuer, collecting all errors. +func (jws *JWS) Validate(params ValidateParams) ([]string, error) { + var errs []string + + if params.Now.IsZero() { + params.Now = time.Now() + } + + // Required to exist and match + if len(params.Iss) > 0 || !params.IgnoreIss { + if len(jws.Claims.Iss) == 0 { + errs = append(errs, ("missing or malformed 'iss' (token issuer, identifier for public key)")) + } else if jws.Claims.Iss != params.Iss { + errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", jws.Claims.Iss, params.Iss)) + } + } + + // Required to exist, optional match + if len(jws.Claims.Sub) == 0 { + if !params.IgnoreSub { + errs = append(errs, ("missing or malformed 'sub' (subject, typically pairwise user id)")) + } + } else if len(params.Sub) > 0 { + if params.Sub != jws.Claims.Sub { + errs = append(errs, fmt.Sprintf("'sub' (subject, typically pairwise user id) mismatch: got %s, expected %s", jws.Claims.Sub, params.Sub)) + } + } + + // Required to exist and match + if len(params.Aud) > 0 || !params.IgnoreAud { + if len(jws.Claims.Aud) == 0 { + errs = append(errs, ("missing or malformed 'aud' (audience receiving token)")) + } else if jws.Claims.Aud != params.Aud { + errs = append(errs, fmt.Sprintf("'aud' (audience receiving token) mismatch: got %s, expected %s", jws.Claims.Aud, params.Aud)) + } + } + + // Required to exist and not be in the past + if !params.IgnoreExp { + if jws.Claims.Exp <= 0 { + errs = append(errs, ("missing or malformed 'exp' (expiration date in seconds)")) + } else if jws.Claims.Exp < params.Now.Unix() { + duration := time.Since(time.Unix(jws.Claims.Exp, 0)) + expTime := time.Unix(jws.Claims.Exp, 0).Format("2006-01-02 15:04:05 MST") + errs = append(errs, fmt.Sprintf("token expired %s ago (%s)", formatDuration(duration), expTime)) + } + } + + // Required to exist and not be in the future + if !params.IgnoreIat { + if jws.Claims.Iat <= 0 { + errs = append(errs, ("missing or malformed 'iat' (issued at, when token was signed)")) + } else if jws.Claims.Iat > params.Now.Unix() { + duration := time.Unix(jws.Claims.Iat, 0).Sub(params.Now) + iatTime := time.Unix(jws.Claims.Iat, 0).Format("2006-01-02 15:04:05 MST") + errs = append(errs, fmt.Sprintf("'iat' (issued at, when token was signed) is %s in the future (%s)", formatDuration(duration), iatTime)) + } + } + + // Should exist, in the past, with optional max age + if params.MaxAge > 0 || !params.IgnoreAuthTime { + if jws.Claims.AuthTime == 0 { + errs = append(errs, ("missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")) + } else { + authTime := time.Unix(jws.Claims.AuthTime, 0) + authTimeStr := authTime.Format("2006-01-02 15:04:05 MST") + age := params.Now.Sub(authTime) + diff := age - params.MaxAge + if jws.Claims.AuthTime > params.Now.Unix() { + fromNow := time.Unix(jws.Claims.AuthTime, 0).Sub(params.Now) + authTimeStr := time.Unix(jws.Claims.AuthTime, 0).Format("2006-01-02 15:04:05 MST") + errs = append(errs, fmt.Sprintf( + "'auth_time' (time of real-world user authentication) of %s is %s in the future (server time %s)", + authTimeStr, formatDuration(fromNow), params.Now.Format("2006-01-02 15:04:05 MST")), + ) + } else if age > params.MaxAge { + errs = append(errs, fmt.Sprintf( + "'auth_time' (time of real-world user authentication) of %s is %s old, which exceeds the max age of %s (%ds) by %s", + authTimeStr, formatDuration(age), formatDuration(params.MaxAge), params.MaxAge/time.Second, formatDuration(diff)), + ) + } + } + } + + // Optional + if params.Jti != jws.Claims.Jti { + if len(params.Jti) > 0 { + errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", jws.Claims.Jti, params.Jti)) + } else if !params.IgnoreJti { + errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", jws.Claims.Jti)) + } + } + + // Optional + if params.Nonce != jws.Claims.Nonce { + if len(params.Nonce) > 0 { + errs = append(errs, fmt.Sprintf("'nonce' (one-time random salt, as string) mismatch: got %s, expected %s", jws.Claims.Nonce, params.Nonce)) + } else if !params.IgnoreNonce { + errs = append(errs, fmt.Sprintf("unchecked 'nonce' (one-time random salt): %s", jws.Claims.Nonce)) + } + } + + // Acr check not implemented because the use case is not yet clear + + // Should exist, optional match + if !params.IgnoreAmr { + if len(jws.Claims.Amr) == 0 { + errs = append(errs, ("missing or malformed 'amr' (authorization methods, as json list)")) + } else { + if len(params.RequiredAmrs) > 0 { + for _, required := range params.RequiredAmrs { + if !slices.Contains(jws.Claims.Amr, required) { + errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr' (authorization methods, as json list)", required)) + } + } + } + + // TODO specify multiple amrs in a tiered list (must have at least one from each list) + // count := 0 + // if len(params.AcceptableAmrs) > 0 { + // for _, amr := range jws.Claims.Amr { + // if slices.Contains(params.AcceptableAmrs, amr) { + // count += 1 + // } + // } + // } + } + } + + // Optional, should match if exists + if params.Azp != jws.Claims.Azp { + if len(params.Azp) > 0 { + errs = append(errs, ("missing or malformed 'azp' (authorized party which presents token)")) + } else if !params.IgnoreAzp { + errs = append(errs, fmt.Sprintf("'azp' mismatch (authorized party which presents token): got %s, expected %s", jws.Claims.Azp, params.Azp)) + } + } + + // Must be checked + if !params.IgnoreSig { + if !jws.Verified { + errs = append(errs, ("signature was not checked")) + } + } + + if len(errs) > 0 { + timeInfo := fmt.Sprintf("info: server time is %s", params.Now.Format("2006-01-02 15:04:05 MST")) + if loc, err := time.LoadLocation("Local"); err == nil { + timeInfo += fmt.Sprintf(" %s", loc) + } + errs = append(errs, timeInfo) + return errs, fmt.Errorf("has errors") + } + return nil, nil +} + +func SignJWS(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) { + hash := sha256.Sum256([]byte(header + "." + payload)) + r, s, err := ecdsa.Sign(rand.Reader, key, hash[:]) + if err != nil { + return nil, fmt.Errorf("failed to sign: %v", err) + } + return append(r.Bytes(), s.Bytes()...), nil +} + +func (jws JWS) Encode() string { + sigEnc := base64.RawURLEncoding.EncodeToString(jws.Signature) + return jws.Protected + "." + jws.Payload + "." + sigEnc +} + +func EncodeToJWT(signingInput string, signature []byte) string { + sigEnc := base64.RawURLEncoding.EncodeToString(signature) + return signingInput + "." + sigEnc +} + +func (jwk JWK) Thumbprint() (string, error) { + data := map[string]string{ + "crv": jwk.Crv, + "kty": jwk.Kty, + "x": jwk.X, + "y": jwk.Y, + } + jsonData, err := json.Marshal(data) + if err != nil { + return "", err + } + hash := sha256.Sum256(jsonData) + return base64.RawURLEncoding.EncodeToString(hash[:]), nil +} + +// URLBase64 unmarshals to bytes and marshals to a raw url base64 string +type URLBase64 []byte + +func (s URLBase64) String() string { + encoded := base64.RawURLEncoding.EncodeToString(s) + return encoded +} + +// MarshalJSON implements JSON marshaling to URL-safe base64. +func (s URLBase64) MarshalJSON() ([]byte, error) { + encoded := base64.RawURLEncoding.EncodeToString(s) + return json.Marshal(encoded) +} + +// UnmarshalJSON implements JSON unmarshaling from URL-safe base64. +func (s *URLBase64) UnmarshalJSON(data []byte) error { + dst, err := base64.RawURLEncoding.AppendDecode([]byte{}, data) + if err != nil { + return fmt.Errorf("decode base64url signature: %w", err) + } + + *s = dst + return nil +} + +func formatDuration(d time.Duration) string { + if d < 0 { + d = -d + } + days := int(d / (24 * time.Hour)) + d -= time.Duration(days) * 24 * time.Hour + hours := int(d / time.Hour) + d -= time.Duration(hours) * time.Hour + minutes := int(d / time.Minute) + d -= time.Duration(minutes) * time.Minute + seconds := int(d / time.Second) + + var parts []string + if days > 0 { + parts = append(parts, fmt.Sprintf("%dd", days)) + } + if hours > 0 { + parts = append(parts, fmt.Sprintf("%dh", hours)) + } + if minutes > 0 { + parts = append(parts, fmt.Sprintf("%dm", minutes)) + } + if seconds > 0 || len(parts) == 0 { + parts = append(parts, fmt.Sprintf("%ds", seconds)) + } + if seconds == 0 || len(parts) == 0 { + d -= time.Duration(seconds) * time.Second + millis := int(d / time.Millisecond) + parts = append(parts, fmt.Sprintf("%dms", millis)) + } + + return strings.Join(parts, " ") +} diff --git a/auth/jwt/pub.go b/auth/jwt/pub.go new file mode 100644 index 0000000..c4d14f9 --- /dev/null +++ b/auth/jwt/pub.go @@ -0,0 +1,203 @@ +package jwt + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "os" + "time" +) + +type PublicKey interface { + Equal(x crypto.PublicKey) bool +} + +// PublicJWK represents a parsed public key (RSA or ECDSA) +type PublicJWK struct { + PublicKey + KID string + Use string +} + +// PublicJWKJSON represents a JSON Web Key as defined in the provided code +type PublicJWKJSON struct { + Kty string `json:"kty"` + KID string `json:"kid"` + N string `json:"n,omitempty"` // RSA modulus + E string `json:"e,omitempty"` // RSA exponent + Crv string `json:"crv,omitempty"` + X string `json:"x,omitempty"` + Y string `json:"y,omitempty"` + Use string `json:"use,omitempty"` +} + +type JWKsJSON struct { + Keys []PublicJWKJSON `json:"keys"` +} + +func UnmarshalPublicJWKs(data []byte) ([]PublicJWK, error) { + var jwks JWKsJSON + if err := json.Unmarshal(data, &jwks); err != nil { + return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err) + } + + pubkeys, err := DecodePublicJWKsJSON(jwks) + if err != nil { + return nil, err + } + + return pubkeys, nil +} + +func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) { + var jwks JWKsJSON + + if err := json.NewDecoder(r).Decode(&jwks); err != nil { + return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err) + } + + pubkeys, err := DecodePublicJWKsJSON(jwks) + if err != nil { + return nil, err + } + + return pubkeys, nil +} + +// DecodePublicJWKsJSON parses JWKS from a Reader +func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) { + // Process keys + var publicKeys []PublicJWK + for _, jwk := range jwks.Keys { + publicKey, err := DecodePublicJWK(jwk) + if err != nil { + return nil, fmt.Errorf("failed to parse public jwk '%s': %w", jwk.KID, err) + } + publicKeys = append(publicKeys, *publicKey) + } + + if len(publicKeys) == 0 { + return nil, fmt.Errorf("no valid RSA or ECDSA keys found") + } + + return publicKeys, nil +} + +// DecodePublicJWK parses JWKS from a Reader +func DecodePublicJWK(jwk PublicJWKJSON) (*PublicJWK, error) { + switch jwk.Kty { + case "RSA": + key, err := decodeRSAPublicJWK(jwk) + if err != nil { + return nil, fmt.Errorf("failed to parse RSA key '%s': %w", jwk.KID, err) + } + // Ensure RSA key meets minimum size requirement + if key.Size() < 128 { // 1024 bits / 8 = 128 bytes + return nil, fmt.Errorf("RSA key '%s' too small: %d bytes", jwk.KID, key.Size()) + } + return &PublicJWK{PublicKey: key, KID: jwk.KID, Use: jwk.Use}, nil + + case "EC": + key, err := decodeECDSAPublicJWK(jwk) + if err != nil { + return nil, fmt.Errorf("failed to parse EC key '%s': %w", jwk.KID, err) + } + return &PublicJWK{KID: jwk.KID, PublicKey: key, Use: jwk.Use}, nil + + default: + return nil, fmt.Errorf("failed to parse unknown key type '%s': %s", jwk.Kty, jwk.KID) + } +} + +// ReadPublicJWKs reads and parses JWKS from a file +func ReadPublicJWKs(filePath string) ([]PublicJWK, error) { + file, err := os.Open(filePath) + if err != nil { + return nil, fmt.Errorf("failed to open JWKS file '%s': %w", filePath, err) + } + defer func() { _ = file.Close() }() + + return DecodePublicJWKs(file) +} + +// FetchPublicJWKs retrieves and parses JWKS from a given URL +func FetchPublicJWKs(url string) ([]PublicJWK, error) { + // Set up HTTP client with timeout + client := &http.Client{ + Timeout: 10 * time.Second, + } + + // Make HTTP request + resp, err := client.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to fetch JWKS: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // Check response status + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return DecodePublicJWKs(resp.Body) +} + +// decodeRSAPublicJWK parses an RSA public key from a JWK +func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) { + n, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return nil, fmt.Errorf("invalid RSA modulus: %w", err) + } + e, err := base64.RawURLEncoding.DecodeString(jwk.E) + if err != nil { + return nil, fmt.Errorf("invalid RSA exponent: %w", err) + } + + // Convert exponent to int + eInt := new(big.Int).SetBytes(e).Int64() + if eInt > int64(^uint(0)>>1) || eInt < 0 { + return nil, fmt.Errorf("RSA exponent too large or negative") + } + + return &rsa.PublicKey{ + N: new(big.Int).SetBytes(n), + E: int(eInt), + }, nil +} + +// decodeECDSAPublicJWK parses an ECDSA public key from a JWK +func decodeECDSAPublicJWK(jwk PublicJWKJSON) (*ecdsa.PublicKey, error) { + x, err := base64.RawURLEncoding.DecodeString(jwk.X) + if err != nil { + return nil, fmt.Errorf("invalid ECDSA X: %w", err) + } + y, err := base64.RawURLEncoding.DecodeString(jwk.Y) + if err != nil { + return nil, fmt.Errorf("invalid ECDSA Y: %w", err) + } + + var curve elliptic.Curve + switch jwk.Crv { + case "P-256": + curve = elliptic.P256() + case "P-384": + curve = elliptic.P384() + case "P-521": + curve = elliptic.P521() + default: + return nil, fmt.Errorf("unsupported ECDSA curve: %s", jwk.Crv) + } + + return &ecdsa.PublicKey{ + Curve: curve, + X: new(big.Int).SetBytes(x), + Y: new(big.Int).SetBytes(y), + }, nil +} diff --git a/auth/jwt/pub_test.go b/auth/jwt/pub_test.go new file mode 100644 index 0000000..2069ffc --- /dev/null +++ b/auth/jwt/pub_test.go @@ -0,0 +1,51 @@ +package jwt + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "encoding/base64" + "math/big" + "testing" +) + +// TestDecodeJWKsJSON tests parsing a specific set of ECDSA P-256 JWKS +func TestDecodeJWKJSON(t *testing.T) { + // Create a temporary file with the test JWKS + kid := "KGx1KSmDRd_dwuwmZmWiEsl9Dh4c5dQtFLLtTl-UvlI" + jwkX := "WVBcjUpllgeGbGavZ9Bbq4ps3Zk73mgRRPpbfebkC3U" + jwkY := "aTmrRia2eiJsJwzuj7DIUVmMVGrjEzQJkxxiQMgVLOw" + jwkUse := "sig" + jwksJSON := []byte(`{"keys":[{"kty":"EC","crv":"P-256","x":"` + jwkX + `","y":"` + jwkY + `","kid":"` + kid + `","use":"` + jwkUse + `"}]}`) + + // Decode from bytes to JSON to Public JWKs + keys, err := UnmarshalPublicJWKs(jwksJSON) + if err != nil { + t.Fatalf("ReadJWKs failed: %v", err) + } + + // Verify results + if len(keys) != 1 { + t.Errorf("Expected 1 key, got %d", len(keys)) + } + + key := keys[0] + if key.KID != kid { + t.Errorf("Expected KID '%s', got '%s'", kid, key.KID) + } + if key.Use != jwkUse { + t.Errorf("Expected Use 'sig', got '%s'", key.Use) + } + + expectedX, _ := base64.RawURLEncoding.DecodeString(jwkX) + expectedY, _ := base64.RawURLEncoding.DecodeString(jwkY) + + // Verify Equal method + sameKey := &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: new(big.Int).SetBytes(expectedX), + Y: new(big.Int).SetBytes(expectedY), + } + if !key.Equal(sameKey) { + t.Errorf("Equal method failed: key should equal itself") + } +}