From 1f0b36fc6ddcbf05569970fd5ed17969044c9a33 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Thu, 12 Mar 2026 19:40:18 -0600 Subject: [PATCH] feat(auth/ajwt): add first-principles JWT/JWS/JWK package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Design goals from first principles: - JWS holds only parsed structure (header, payload, sig) — no Claims interface, no Verified flag. Removes footguns from the simpler packages. - Issuer owns key management and verification. Verify does key lookup by kid, sig verification, and iss claim check — in that order, so sig is always authenticated before any payload data is trusted. - ValidateParams is a stable config object with Validate(StandardClaims, time.Time) as a method. Time is passed at the call site, not stored in the params struct, so the same config object can be reused across requests. - UnmarshalClaims(v any) accepts any type — no Claims interface to implement. Custom validation is a plain function call, not a method satisfying an interface. - Sign uses crypto.Signer, supporting ES256/ES384/ES512 (ECDSA), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519, RFC 8037). - PublicJWK uses crypto.PublicKey (not generics) since JWKS returns heterogeneous key types at runtime. Typed accessors ECDSA(), RSA(), and EdDSA() replace TypedKeys[K] filtering. - JWKS parsing handles kty: "EC", "RSA", and "OKP" (Ed25519). 10 tests: ES256/RS256/EdDSA round trips, custom validation, wrong key, unknown kid, iss mismatch, tampered alg, PublicJWK accessors, JWKS JSON. --- auth/ajwt/go.mod | 3 + auth/ajwt/jwt.go | 641 ++++++++++++++++++++++++++++++++++++++++++ auth/ajwt/jwt_test.go | 410 +++++++++++++++++++++++++++ auth/ajwt/pub.go | 235 ++++++++++++++++ 4 files changed, 1289 insertions(+) create mode 100644 auth/ajwt/go.mod create mode 100644 auth/ajwt/jwt.go create mode 100644 auth/ajwt/jwt_test.go create mode 100644 auth/ajwt/pub.go diff --git a/auth/ajwt/go.mod b/auth/ajwt/go.mod new file mode 100644 index 0000000..af286bc --- /dev/null +++ b/auth/ajwt/go.mod @@ -0,0 +1,3 @@ +module github.com/therootcompany/golib/auth/ajwt + +go 1.24.0 diff --git a/auth/ajwt/jwt.go b/auth/ajwt/jwt.go new file mode 100644 index 0000000..1f77d47 --- /dev/null +++ b/auth/ajwt/jwt.go @@ -0,0 +1,641 @@ +// Copyright 2025 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +// Package ajwt is a lightweight JWT/JWS/JWK library designed from first +// principles: +// +// - [JWS] is a parsed structure only — no Claims interface, no Verified flag. +// - [Issuer] owns key management and signature verification, centralizing +// the key lookup → sig verify → iss check sequence. +// - [ValidateParams] is a stable config object; time is passed at the call +// site so the same params can be reused across requests. +// - [JWS.UnmarshalClaims] accepts any type — no interface to implement. +// - [JWS.Sign] uses [crypto.Signer] for ES256 (P-256), ES384 (P-384), +// ES512 (P-521), RS256 (RSA PKCS#1 v1.5), and EdDSA (Ed25519/RFC 8037). +// +// Typical usage: +// +// // At startup: +// iss := ajwt.NewIssuer("https://accounts.example.com") +// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true} +// if err := iss.FetchKeys(ctx); err != nil { ... } +// +// // Per request: +// jws, err := ajwt.Decode(tokenStr) +// if err := iss.Verify(jws); err != nil { ... } // sig + iss check +// var claims AppClaims +// if err := jws.UnmarshalClaims(&claims); err != nil { ... } +// if errs, err := iss.Params.Validate(claims.StandardClaims, time.Now()); err != nil { ... } +package ajwt + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/sha512" + "encoding/asn1" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "net/http" + "slices" + "strings" + "time" +) + +// JWS is a decoded JSON Web Signature / JWT. +// +// It holds only the parsed structure — header, raw base64url fields, and +// decoded signature bytes. It carries no Claims interface and no Verified flag; +// use [Issuer.Verify] to authenticate the token and [JWS.UnmarshalClaims] to +// decode the payload into a typed struct. +type JWS struct { + Protected string // base64url-encoded header + Header StandardHeader + Payload string // base64url-encoded claims + Signature []byte +} + +// StandardHeader holds the standard JOSE header fields. +type StandardHeader struct { + Alg string `json:"alg"` + Kid string `json:"kid"` + Typ string `json:"typ"` +} + +// StandardClaims holds the registered JWT claim names defined in RFC 7519 +// and extended by OpenID Connect Core. +// +// Embed StandardClaims in your own claims struct: +// +// type AppClaims struct { +// ajwt.StandardClaims +// Email string `json:"email"` +// } +type StandardClaims struct { + Iss string `json:"iss"` + Sub string `json:"sub"` + Aud string `json:"aud"` + Exp int64 `json:"exp"` + Iat int64 `json:"iat"` + AuthTime int64 `json:"auth_time"` + Nonce string `json:"nonce,omitempty"` + Amr []string `json:"amr"` + Azp string `json:"azp,omitempty"` + Jti string `json:"jti"` +} + +// Decode parses a compact JWT string (header.payload.signature) into a JWS. +// +// It does not unmarshal the claims payload — call [JWS.UnmarshalClaims] after +// [Issuer.Verify] to populate a typed claims struct. +func Decode(tokenStr string) (*JWS, error) { + parts := strings.Split(tokenStr, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid JWT format") + } + + var jws JWS + jws.Protected, jws.Payload = parts[0], parts[1] + + header, err := base64.RawURLEncoding.DecodeString(jws.Protected) + if err != nil { + return nil, fmt.Errorf("invalid header encoding: %v", err) + } + if err := json.Unmarshal(header, &jws.Header); err != nil { + return nil, fmt.Errorf("invalid header JSON: %v", err) + } + + jws.Signature, err = base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return nil, fmt.Errorf("invalid signature encoding: %v", err) + } + + return &jws, nil +} + +// UnmarshalClaims decodes the JWT payload into v. +// +// v must be a pointer to a struct (e.g. *AppClaims). Always call +// [Issuer.Verify] before UnmarshalClaims to ensure the signature is +// authenticated before trusting the payload. +func (jws *JWS) UnmarshalClaims(v any) error { + payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) + if err != nil { + return fmt.Errorf("invalid claims encoding: %v", err) + } + if err := json.Unmarshal(payload, v); err != nil { + return fmt.Errorf("invalid claims JSON: %v", err) + } + return nil +} + +// NewJWSFromClaims creates an unsigned JWS from the provided claims. +// +// kid identifies the signing key. The "alg" header field is set automatically +// when [JWS.Sign] is called. Call [JWS.Encode] to produce the compact JWT +// string after signing. +func NewJWSFromClaims(claims any, kid string) (*JWS, error) { + var jws JWS + + jws.Header = StandardHeader{ + // Alg is set by Sign based on the key type. + Kid: kid, + Typ: "JWT", + } + headerJSON, _ := json.Marshal(jws.Header) + jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) + + claimsJSON, err := json.Marshal(claims) + if err != nil { + return nil, fmt.Errorf("marshal claims: %w", err) + } + jws.Payload = base64.RawURLEncoding.EncodeToString(claimsJSON) + + return &jws, nil +} + +// Sign signs the JWS in-place using the provided [crypto.Signer]. +// It sets the "alg" header field based on the public key type and re-encodes +// the protected header before signing, so the signed input is always +// consistent with the token header. +// +// Supported algorithms (inferred from key type): +// - *ecdsa.PublicKey P-256 → ES256 (SHA-256, raw r||s) +// - *ecdsa.PublicKey P-384 → ES384 (SHA-384, raw r||s) +// - *ecdsa.PublicKey P-521 → ES512 (SHA-512, raw r||s) +// - *rsa.PublicKey → RS256 (PKCS#1 v1.5 + SHA-256) +// - ed25519.PublicKey → EdDSA (Ed25519, RFC 8037) +func (jws *JWS) Sign(key crypto.Signer) ([]byte, error) { + switch pub := key.Public().(type) { + case *ecdsa.PublicKey: + alg, h, err := algForECKey(pub) + if err != nil { + return nil, err + } + jws.Header.Alg = alg + headerJSON, _ := json.Marshal(jws.Header) + jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) + + digest := digestFor(h, jws.Protected+"."+jws.Payload) + // crypto.Signer returns ASN.1 DER for ECDSA; convert to raw r||s for JWS. + derSig, err := key.Sign(rand.Reader, digest, h) + if err != nil { + return nil, fmt.Errorf("Sign %s: %w", alg, err) + } + jws.Signature, err = ecdsaDERToRaw(derSig, pub.Curve) + return jws.Signature, err + + case *rsa.PublicKey: + jws.Header.Alg = "RS256" + headerJSON, _ := json.Marshal(jws.Header) + jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) + + digest := digestFor(crypto.SHA256, jws.Protected+"."+jws.Payload) + // crypto.Signer returns raw PKCS#1 v1.5 bytes for RSA; use directly. + var err error + jws.Signature, err = key.Sign(rand.Reader, digest, crypto.SHA256) + return jws.Signature, err + + case ed25519.PublicKey: + jws.Header.Alg = "EdDSA" + headerJSON, _ := json.Marshal(jws.Header) + jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON) + + // Ed25519 signs the raw message with no pre-hashing; pass crypto.Hash(0). + signingInput := jws.Protected + "." + jws.Payload + var err error + jws.Signature, err = key.Sign(rand.Reader, []byte(signingInput), crypto.Hash(0)) + return jws.Signature, err + + default: + return nil, fmt.Errorf( + "Sign: unsupported public key type %T (supported: *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey)", + key.Public(), + ) + } +} + +// Encode produces the compact JWT string (header.payload.signature). +func (jws *JWS) Encode() string { + return jws.Protected + "." + jws.Payload + "." + base64.RawURLEncoding.EncodeToString(jws.Signature) +} + +// ValidateParams holds claim validation configuration. +// +// Configure once at startup; call [ValidateParams.Validate] per request, +// passing the current time. This keeps the config stable and makes the +// time dependency explicit at the call site. +// +// https://openid.net/specs/openid-connect-core-1_0.html#IDToken +type ValidateParams struct { + IgnoreIss bool + Iss string + IgnoreSub bool + Sub string + IgnoreAud bool + Aud string + IgnoreExp bool + IgnoreJti bool + Jti string + IgnoreIat bool + IgnoreAuthTime bool + MaxAge time.Duration + IgnoreNonce bool + Nonce string + IgnoreAmr bool + RequiredAmrs []string + IgnoreAzp bool + Azp string +} + +// Validate checks the standard JWT/OIDC claim fields against this config. +// +// now is typically time.Now() — passing it explicitly keeps the config stable +// across requests and avoids hidden time dependencies in the params struct. +func (p ValidateParams) Validate(claims StandardClaims, now time.Time) ([]string, error) { + return ValidateStandardClaims(claims, p, now) +} + +// ValidateStandardClaims checks the registered JWT/OIDC claim fields against params. +// +// Exported so callers can use it directly without a [ValidateParams] receiver: +// +// errs, err := ajwt.ValidateStandardClaims(claims.StandardClaims, params, time.Now()) +func ValidateStandardClaims(claims StandardClaims, params ValidateParams, now time.Time) ([]string, error) { + var errs []string + + // Required to exist and match + if len(params.Iss) > 0 || !params.IgnoreIss { + if len(claims.Iss) == 0 { + errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)") + } else if claims.Iss != params.Iss { + errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss)) + } + } + + // Required to exist, optional match + if len(claims.Sub) == 0 { + if !params.IgnoreSub { + errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)") + } + } else if len(params.Sub) > 0 { + if params.Sub != claims.Sub { + errs = append(errs, fmt.Sprintf("'sub' (subject) mismatch: got %s, expected %s", claims.Sub, params.Sub)) + } + } + + // Required to exist and match + if len(params.Aud) > 0 || !params.IgnoreAud { + if len(claims.Aud) == 0 { + errs = append(errs, "missing or malformed 'aud' (audience receiving token)") + } else if claims.Aud != params.Aud { + errs = append(errs, fmt.Sprintf("'aud' (audience) mismatch: got %s, expected %s", claims.Aud, params.Aud)) + } + } + + // Required to exist and not be in the past + if !params.IgnoreExp { + if claims.Exp <= 0 { + errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)") + } else if claims.Exp < now.Unix() { + duration := now.Sub(time.Unix(claims.Exp, 0)) + expTime := time.Unix(claims.Exp, 0).Format("2006-01-02 15:04:05 MST") + errs = append(errs, fmt.Sprintf("token expired %s ago (%s)", formatDuration(duration), expTime)) + } + } + + // Required to exist and not be in the future + if !params.IgnoreIat { + if claims.Iat <= 0 { + errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)") + } else if claims.Iat > now.Unix() { + duration := time.Unix(claims.Iat, 0).Sub(now) + iatTime := time.Unix(claims.Iat, 0).Format("2006-01-02 15:04:05 MST") + errs = append(errs, fmt.Sprintf("'iat' (issued at) is %s in the future (%s)", formatDuration(duration), iatTime)) + } + } + + // Should exist, in the past, with optional max age + if params.MaxAge > 0 || !params.IgnoreAuthTime { + if claims.AuthTime == 0 { + errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)") + } else { + authTime := time.Unix(claims.AuthTime, 0) + authTimeStr := authTime.Format("2006-01-02 15:04:05 MST") + age := now.Sub(authTime) + diff := age - params.MaxAge + if claims.AuthTime > now.Unix() { + fromNow := time.Unix(claims.AuthTime, 0).Sub(now) + errs = append(errs, fmt.Sprintf( + "'auth_time' of %s is %s in the future (server time %s)", + authTimeStr, formatDuration(fromNow), now.Format("2006-01-02 15:04:05 MST")), + ) + } else if params.MaxAge > 0 && age > params.MaxAge { + errs = append(errs, fmt.Sprintf( + "'auth_time' of %s is %s old, exceeding max age %s by %s", + authTimeStr, formatDuration(age), formatDuration(params.MaxAge), formatDuration(diff)), + ) + } + } + } + + // Optional exact match + if params.Jti != claims.Jti { + if len(params.Jti) > 0 { + errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti)) + } else if !params.IgnoreJti { + errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti)) + } + } + + // Optional exact match + if params.Nonce != claims.Nonce { + if len(params.Nonce) > 0 { + errs = append(errs, fmt.Sprintf("'nonce' mismatch: got %s, expected %s", claims.Nonce, params.Nonce)) + } else if !params.IgnoreNonce { + errs = append(errs, fmt.Sprintf("unchecked 'nonce': %s", claims.Nonce)) + } + } + + // Should exist, optional required-set check + if !params.IgnoreAmr { + if len(claims.Amr) == 0 { + errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)") + } else if len(params.RequiredAmrs) > 0 { + for _, required := range params.RequiredAmrs { + if !slices.Contains(claims.Amr, required) { + errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr'", required)) + } + } + } + } + + // Optional, match if present + if params.Azp != claims.Azp { + if len(params.Azp) > 0 { + errs = append(errs, fmt.Sprintf("'azp' (authorized party) mismatch: got %s, expected %s", claims.Azp, params.Azp)) + } else if !params.IgnoreAzp { + errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party): %s", claims.Azp)) + } + } + + if len(errs) > 0 { + timeInfo := fmt.Sprintf("info: server time is %s", now.Format("2006-01-02 15:04:05 MST")) + if loc, err := time.LoadLocation("Local"); err == nil { + timeInfo += fmt.Sprintf(" %s", loc) + } + errs = append(errs, timeInfo) + return errs, fmt.Errorf("has errors") + } + return nil, nil +} + +// Issuer holds public keys and validation config for a trusted token issuer. +// +// [Issuer.FetchKeys] loads keys from the issuer's JWKS endpoint. +// [Issuer.SetKeys] injects keys directly (useful in tests). +// [Issuer.Verify] authenticates the token: key lookup → sig verify → iss check. +// +// Typical setup: +// +// iss := ajwt.NewIssuer("https://accounts.example.com") +// iss.Params = ajwt.ValidateParams{Aud: "my-app", IgnoreIss: true} +// if err := iss.FetchKeys(ctx); err != nil { ... } +type Issuer struct { + URL string + JWKsURL string // optional; defaults to URL + "/.well-known/jwks.json" + Params ValidateParams + keys map[string]crypto.PublicKey // kid → key +} + +// NewIssuer creates an Issuer for the given base URL. +func NewIssuer(url string) *Issuer { + return &Issuer{ + URL: url, + keys: make(map[string]crypto.PublicKey), + } +} + +// SetKeys stores public keys by their KID, replacing any previously stored keys. +// Useful for injecting keys in tests without an HTTP round-trip. +func (iss *Issuer) SetKeys(keys []PublicJWK) { + m := make(map[string]crypto.PublicKey, len(keys)) + for _, k := range keys { + m[k.KID] = k.Key + } + iss.keys = m +} + +// FetchKeys retrieves and stores the JWKS from the issuer's endpoint. +// If JWKsURL is empty, it defaults to URL + "/.well-known/jwks.json". +func (iss *Issuer) FetchKeys(ctx context.Context) error { + url := iss.JWKsURL + if url == "" { + url = strings.TrimRight(iss.URL, "/") + "/.well-known/jwks.json" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("fetch JWKS: %w", err) + } + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("fetch JWKS: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("fetch JWKS: unexpected status %d", resp.StatusCode) + } + + keys, err := DecodePublicJWKs(resp.Body) + if err != nil { + return fmt.Errorf("parse JWKS: %w", err) + } + + iss.SetKeys(keys) + return nil +} + +// Verify authenticates jws against this issuer: +// 1. Looks up the signing key by jws.Header.Kid. +// 2. Verifies the signature before trusting any payload data. +// 3. Checks that the token's "iss" claim matches iss.URL. +// +// Call [JWS.UnmarshalClaims] after Verify to safely decode the payload into a +// typed struct, then [ValidateParams.Validate] to check claim values. +func (iss *Issuer) Verify(jws *JWS) error { + if jws.Header.Kid == "" { + return fmt.Errorf("missing 'kid' header") + } + key, ok := iss.keys[jws.Header.Kid] + if !ok { + return fmt.Errorf("unknown kid: %q", jws.Header.Kid) + } + + signingInput := jws.Protected + "." + jws.Payload + if err := verifyWith(signingInput, jws.Signature, jws.Header.Alg, key); err != nil { + return fmt.Errorf("signature verification failed: %w", err) + } + + // Signature verified — now safe to inspect the payload. + payload, err := base64.RawURLEncoding.DecodeString(jws.Payload) + if err != nil { + return fmt.Errorf("invalid claims encoding: %w", err) + } + var partial struct { + Iss string `json:"iss"` + } + if err := json.Unmarshal(payload, &partial); err != nil { + return fmt.Errorf("invalid claims JSON: %w", err) + } + if partial.Iss != iss.URL { + return fmt.Errorf("iss mismatch: got %q, want %q", partial.Iss, iss.URL) + } + + return nil +} + +// verifyWith checks a JWS signature using the given algorithm and public key. +// Returns nil on success, a descriptive error on failure. +func verifyWith(signingInput string, sig []byte, alg string, key crypto.PublicKey) error { + switch alg { + case "ES256", "ES384", "ES512": + k, ok := key.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("alg %s requires *ecdsa.PublicKey, got %T", alg, key) + } + expectedAlg, h, err := algForECKey(k) + if err != nil { + return err + } + if expectedAlg != alg { + return fmt.Errorf("key curve mismatch: key is %s, token alg is %s", expectedAlg, alg) + } + byteLen := (k.Curve.Params().BitSize + 7) / 8 + if len(sig) != 2*byteLen { + return fmt.Errorf("invalid %s signature length: got %d, want %d", alg, len(sig), 2*byteLen) + } + digest := digestFor(h, signingInput) + r := new(big.Int).SetBytes(sig[:byteLen]) + s := new(big.Int).SetBytes(sig[byteLen:]) + if !ecdsa.Verify(k, digest, r, s) { + return fmt.Errorf("%s signature invalid", alg) + } + return nil + + case "RS256": + k, ok := key.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("alg RS256 requires *rsa.PublicKey, got %T", key) + } + digest := digestFor(crypto.SHA256, signingInput) + if err := rsa.VerifyPKCS1v15(k, crypto.SHA256, digest, sig); err != nil { + return fmt.Errorf("RS256 signature invalid: %w", err) + } + return nil + + case "EdDSA": + k, ok := key.(ed25519.PublicKey) + if !ok { + return fmt.Errorf("alg EdDSA requires ed25519.PublicKey, got %T", key) + } + if !ed25519.Verify(k, []byte(signingInput), sig) { + return fmt.Errorf("EdDSA signature invalid") + } + return nil + + default: + return fmt.Errorf("unsupported alg: %q", alg) + } +} + +// --- Internal helpers --- + +func algForECKey(pub *ecdsa.PublicKey) (alg string, h crypto.Hash, err error) { + switch pub.Curve { + case elliptic.P256(): + return "ES256", crypto.SHA256, nil + case elliptic.P384(): + return "ES384", crypto.SHA384, nil + case elliptic.P521(): + return "ES512", crypto.SHA512, nil + default: + return "", 0, fmt.Errorf("unsupported EC curve: %s", pub.Curve.Params().Name) + } +} + +func digestFor(h crypto.Hash, data string) []byte { + switch h { + case crypto.SHA256: + d := sha256.Sum256([]byte(data)) + return d[:] + case crypto.SHA384: + d := sha512.Sum384([]byte(data)) + return d[:] + case crypto.SHA512: + d := sha512.Sum512([]byte(data)) + return d[:] + default: + panic(fmt.Sprintf("ajwt: unsupported hash %v", h)) + } +} + +func ecdsaDERToRaw(der []byte, curve elliptic.Curve) ([]byte, error) { + var sig struct{ R, S *big.Int } + if _, err := asn1.Unmarshal(der, &sig); err != nil { + return nil, fmt.Errorf("ecdsaDERToRaw: %w", err) + } + byteLen := (curve.Params().BitSize + 7) / 8 + out := make([]byte, 2*byteLen) + sig.R.FillBytes(out[:byteLen]) + sig.S.FillBytes(out[byteLen:]) + return out, nil +} + +func formatDuration(d time.Duration) string { + if d < 0 { + d = -d + } + days := int(d / (24 * time.Hour)) + d -= time.Duration(days) * 24 * time.Hour + hours := int(d / time.Hour) + d -= time.Duration(hours) * time.Hour + minutes := int(d / time.Minute) + d -= time.Duration(minutes) * time.Minute + seconds := int(d / time.Second) + + var parts []string + if days > 0 { + parts = append(parts, fmt.Sprintf("%dd", days)) + } + if hours > 0 { + parts = append(parts, fmt.Sprintf("%dh", hours)) + } + if minutes > 0 { + parts = append(parts, fmt.Sprintf("%dm", minutes)) + } + if seconds > 0 || len(parts) == 0 { + parts = append(parts, fmt.Sprintf("%ds", seconds)) + } + if seconds == 0 || len(parts) == 0 { + d -= time.Duration(seconds) * time.Second + millis := int(d / time.Millisecond) + parts = append(parts, fmt.Sprintf("%dms", millis)) + } + + return strings.Join(parts, " ") +} diff --git a/auth/ajwt/jwt_test.go b/auth/ajwt/jwt_test.go new file mode 100644 index 0000000..e8319a2 --- /dev/null +++ b/auth/ajwt/jwt_test.go @@ -0,0 +1,410 @@ +// Copyright 2025 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +package ajwt_test + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "fmt" + "testing" + "time" + + "github.com/therootcompany/golib/auth/ajwt" +) + +// AppClaims embeds StandardClaims and adds application-specific fields. +// +// Unlike embeddedjwt and bestjwt, AppClaims does NOT implement a Validate +// interface — there is none. Validation is explicit: call +// ValidateStandardClaims or ValidateParams.Validate at the call site. +type AppClaims struct { + ajwt.StandardClaims + Email string `json:"email"` + Roles []string `json:"roles"` +} + +// validateAppClaims is a plain function — not a method satisfying an interface. +// Custom validation logic lives here, calling ValidateStandardClaims directly. +func validateAppClaims(c AppClaims, params ajwt.ValidateParams, now time.Time) ([]string, error) { + errs, _ := ajwt.ValidateStandardClaims(c.StandardClaims, params, now) + if c.Email == "" { + errs = append(errs, "missing email claim") + } + if len(errs) > 0 { + return errs, fmt.Errorf("has errors") + } + return nil, nil +} + +func goodClaims() AppClaims { + now := time.Now() + return AppClaims{ + StandardClaims: ajwt.StandardClaims{ + Iss: "https://example.com", + Sub: "user123", + Aud: "myapp", + Exp: now.Add(time.Hour).Unix(), + Iat: now.Unix(), + AuthTime: now.Unix(), + Amr: []string{"pwd"}, + Jti: "abc123", + Azp: "myapp", + Nonce: "nonce1", + }, + Email: "user@example.com", + Roles: []string{"admin"}, + } +} + +// goodParams configures the validator. Iss is omitted because Issuer.Verify +// already enforces the iss claim — no need to check it twice. +func goodParams() ajwt.ValidateParams { + return ajwt.ValidateParams{ + IgnoreIss: true, // Issuer.Verify handles iss + Sub: "user123", + Aud: "myapp", + Jti: "abc123", + Nonce: "nonce1", + Azp: "myapp", + RequiredAmrs: []string{"pwd"}, + } +} + +func goodIssuer(pub ajwt.PublicJWK) *ajwt.Issuer { + iss := ajwt.NewIssuer("https://example.com") + iss.Params = goodParams() + iss.SetKeys([]ajwt.PublicJWK{pub}) + return iss +} + +// TestRoundTrip is the primary happy path using ES256. +// It demonstrates the full Issuer-based flow: +// +// Decode → Issuer.Verify → UnmarshalClaims → Params.Validate +// +// No Claims interface, no Verified flag, no type assertions on jws. +func TestRoundTrip(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := ajwt.NewJWSFromClaims(&claims, "key-1") + if err != nil { + t.Fatal(err) + } + + if _, err = jws.Sign(privKey); err != nil { + t.Fatal(err) + } + if jws.Header.Alg != "ES256" { + t.Fatalf("expected ES256, got %s", jws.Header.Alg) + } + + token := jws.Encode() + + iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"}) + + jws2, err := ajwt.Decode(token) + if err != nil { + t.Fatal(err) + } + if err = iss.Verify(jws2); err != nil { + t.Fatalf("Verify failed: %v", err) + } + + var decoded AppClaims + if err = jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatal(err) + } + if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil { + t.Fatalf("validation failed: %v", errs) + } + // Direct field access — no type assertion needed, no jws.Claims interface. + if decoded.Email != claims.Email { + t.Errorf("email: got %s, want %s", decoded.Email, claims.Email) + } +} + +// TestRoundTripRS256 exercises RSA PKCS#1 v1.5 / RS256. +func TestRoundTripRS256(t *testing.T) { + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := ajwt.NewJWSFromClaims(&claims, "key-1") + if err != nil { + t.Fatal(err) + } + + if _, err = jws.Sign(privKey); err != nil { + t.Fatal(err) + } + if jws.Header.Alg != "RS256" { + t.Fatalf("expected RS256, got %s", jws.Header.Alg) + } + + token := jws.Encode() + + iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "key-1"}) + + jws2, err := ajwt.Decode(token) + if err != nil { + t.Fatal(err) + } + if err = iss.Verify(jws2); err != nil { + t.Fatalf("Verify failed: %v", err) + } + + var decoded AppClaims + if err = jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatal(err) + } + if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil { + t.Fatalf("validation failed: %v", errs) + } +} + +// TestRoundTripEdDSA exercises Ed25519 / EdDSA (RFC 8037). +func TestRoundTripEdDSA(t *testing.T) { + pubKeyBytes, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := ajwt.NewJWSFromClaims(&claims, "key-1") + if err != nil { + t.Fatal(err) + } + + if _, err = jws.Sign(privKey); err != nil { + t.Fatal(err) + } + if jws.Header.Alg != "EdDSA" { + t.Fatalf("expected EdDSA, got %s", jws.Header.Alg) + } + + token := jws.Encode() + + iss := goodIssuer(ajwt.PublicJWK{Key: pubKeyBytes, KID: "key-1"}) + + jws2, err := ajwt.Decode(token) + if err != nil { + t.Fatal(err) + } + if err = iss.Verify(jws2); err != nil { + t.Fatalf("Verify failed: %v", err) + } + + var decoded AppClaims + if err = jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatal(err) + } + if errs, err := iss.Params.Validate(decoded.StandardClaims, time.Now()); err != nil { + t.Fatalf("validation failed: %v", errs) + } +} + +// TestCustomValidation demonstrates custom claim validation without any interface. +// The caller owns the validation logic and calls ValidateStandardClaims directly. +func TestCustomValidation(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + // Token with empty Email — our custom validator should reject it. + claims := goodClaims() + claims.Email = "" + jws, _ := ajwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(privKey) + token := jws.Encode() + + iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"}) + jws2, _ := ajwt.Decode(token) + _ = iss.Verify(jws2) + var decoded AppClaims + _ = jws2.UnmarshalClaims(&decoded) + + errs, err := validateAppClaims(decoded, goodParams(), time.Now()) + if err == nil { + t.Fatal("expected validation to fail: email is empty") + } + found := false + for _, e := range errs { + if e == "missing email claim" { + found = true + } + } + if !found { + t.Fatalf("expected 'missing email claim' in errors: %v", errs) + } +} + +// TestIssuerWrongKey confirms that a different key's public key is rejected. +func TestIssuerWrongKey(t *testing.T) { + signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + claims := goodClaims() + jws, _ := ajwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(signingKey) + token := jws.Encode() + + iss := goodIssuer(ajwt.PublicJWK{Key: &wrongKey.PublicKey, KID: "k"}) + jws2, _ := ajwt.Decode(token) + + if err := iss.Verify(jws2); err == nil { + t.Fatal("expected Verify to fail with wrong key") + } +} + +// TestIssuerUnknownKid confirms that an unknown kid is rejected. +func TestIssuerUnknownKid(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + claims := goodClaims() + jws, _ := ajwt.NewJWSFromClaims(&claims, "unknown-kid") + _, _ = jws.Sign(privKey) + token := jws.Encode() + + iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "known-kid"}) + jws2, _ := ajwt.Decode(token) + + if err := iss.Verify(jws2); err == nil { + t.Fatal("expected Verify to fail for unknown kid") + } +} + +// TestIssuerIssMismatch confirms that a token with a mismatched iss is rejected +// even if the signature is valid. +func TestIssuerIssMismatch(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + claims := goodClaims() + claims.Iss = "https://evil.example.com" // not the issuer URL + jws, _ := ajwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(privKey) + token := jws.Encode() + + // Issuer expects "https://example.com" but token says "https://evil.example.com" + iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"}) + jws2, _ := ajwt.Decode(token) + + if err := iss.Verify(jws2); err == nil { + t.Fatal("expected Verify to fail: iss mismatch") + } +} + +// TestVerifyTamperedAlg confirms that a tampered alg header is rejected. +func TestVerifyTamperedAlg(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + claims := goodClaims() + jws, _ := ajwt.NewJWSFromClaims(&claims, "k") + _, _ = jws.Sign(privKey) + token := jws.Encode() + + iss := goodIssuer(ajwt.PublicJWK{Key: &privKey.PublicKey, KID: "k"}) + jws2, _ := ajwt.Decode(token) + jws2.Header.Alg = "none" // tamper + + if err := iss.Verify(jws2); err == nil { + t.Fatal("expected Verify to fail for tampered alg") + } +} + +// TestPublicJWKAccessors confirms the ECDSA, RSA, and EdDSA typed accessor methods. +func TestPublicJWKAccessors(t *testing.T) { + ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) + edPub, _, _ := ed25519.GenerateKey(rand.Reader) + + ecJWK := ajwt.PublicJWK{Key: &ecKey.PublicKey, KID: "ec-1"} + rsaJWK := ajwt.PublicJWK{Key: &rsaKey.PublicKey, KID: "rsa-1"} + edJWK := ajwt.PublicJWK{Key: edPub, KID: "ed-1"} + + if k, ok := ecJWK.ECDSA(); !ok || k == nil { + t.Error("expected ECDSA() to succeed for EC key") + } + if _, ok := ecJWK.RSA(); ok { + t.Error("expected RSA() to fail for EC key") + } + if _, ok := ecJWK.EdDSA(); ok { + t.Error("expected EdDSA() to fail for EC key") + } + + if k, ok := rsaJWK.RSA(); !ok || k == nil { + t.Error("expected RSA() to succeed for RSA key") + } + if _, ok := rsaJWK.ECDSA(); ok { + t.Error("expected ECDSA() to fail for RSA key") + } + if _, ok := rsaJWK.EdDSA(); ok { + t.Error("expected EdDSA() to fail for RSA key") + } + + if k, ok := edJWK.EdDSA(); !ok || k == nil { + t.Error("expected EdDSA() to succeed for Ed25519 key") + } + if _, ok := edJWK.ECDSA(); ok { + t.Error("expected ECDSA() to fail for Ed25519 key") + } + if _, ok := edJWK.RSA(); ok { + t.Error("expected RSA() to fail for Ed25519 key") + } +} + +// TestDecodePublicJWKJSON verifies JWKS JSON parsing with real base64url-encoded +// key material from RFC 7517 / OIDC examples. +func TestDecodePublicJWKJSON(t *testing.T) { + jwksJSON := []byte(`{"keys":[ + {"kty":"EC","crv":"P-256", + "x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", + "y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", + "kid":"ec-256","use":"sig"}, + {"kty":"RSA", + "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", + "e":"AQAB","kid":"rsa-2048","use":"sig"} + ]}`) + + keys, err := ajwt.UnmarshalPublicJWKs(jwksJSON) + if err != nil { + t.Fatal(err) + } + if len(keys) != 2 { + t.Fatalf("expected 2 keys, got %d", len(keys)) + } + + var ecCount, rsaCount int + for _, k := range keys { + if _, ok := k.ECDSA(); ok { + ecCount++ + if k.KID != "ec-256" { + t.Errorf("unexpected EC kid: %s", k.KID) + } + } + if _, ok := k.RSA(); ok { + rsaCount++ + if k.KID != "rsa-2048" { + t.Errorf("unexpected RSA kid: %s", k.KID) + } + } + } + if ecCount != 1 { + t.Errorf("expected 1 EC key, got %d", ecCount) + } + if rsaCount != 1 { + t.Errorf("expected 1 RSA key, got %d", rsaCount) + } +} diff --git a/auth/ajwt/pub.go b/auth/ajwt/pub.go new file mode 100644 index 0000000..0114556 --- /dev/null +++ b/auth/ajwt/pub.go @@ -0,0 +1,235 @@ +// Copyright 2025 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +package ajwt + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "os" + "time" +) + +// PublicJWK wraps a parsed public key with its JWKS metadata. +// +// Key is [crypto.PublicKey] (= any) since a JWKS endpoint returns a +// heterogeneous mix of key types determined at runtime by the "kty" field. +// Use the typed accessor methods [PublicJWK.ECDSA], [PublicJWK.RSA], and +// [PublicJWK.EdDSA] to assert the underlying type without a raw type switch. +type PublicJWK struct { + Key crypto.PublicKey + KID string + Use string +} + +// ECDSA returns the key as *ecdsa.PublicKey if it is one, else (nil, false). +func (k PublicJWK) ECDSA() (*ecdsa.PublicKey, bool) { + key, ok := k.Key.(*ecdsa.PublicKey) + return key, ok +} + +// RSA returns the key as *rsa.PublicKey if it is one, else (nil, false). +func (k PublicJWK) RSA() (*rsa.PublicKey, bool) { + key, ok := k.Key.(*rsa.PublicKey) + return key, ok +} + +// EdDSA returns the key as ed25519.PublicKey if it is one, else (nil, false). +func (k PublicJWK) EdDSA() (ed25519.PublicKey, bool) { + key, ok := k.Key.(ed25519.PublicKey) + return key, ok +} + +// PublicJWKJSON is the JSON representation of a single key in a JWKS document. +type PublicJWKJSON struct { + Kty string `json:"kty"` + KID string `json:"kid"` + Crv string `json:"crv,omitempty"` // EC / OKP curve + X string `json:"x,omitempty"` // EC / OKP public key x (or Ed25519 key bytes) + Y string `json:"y,omitempty"` // EC public key y + N string `json:"n,omitempty"` // RSA modulus + E string `json:"e,omitempty"` // RSA exponent + Use string `json:"use,omitempty"` +} + +// JWKsJSON is the JSON representation of a JWKS document. +type JWKsJSON struct { + Keys []PublicJWKJSON `json:"keys"` +} + +// FetchPublicJWKs retrieves and parses a JWKS document from url. +// +// For issuer-scoped key management with context support, use +// [Issuer.FetchKeys] instead. +func FetchPublicJWKs(url string) ([]PublicJWK, error) { + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to fetch JWKS: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + return DecodePublicJWKs(resp.Body) +} + +// ReadPublicJWKs reads and parses a JWKS document from a file path. +func ReadPublicJWKs(filePath string) ([]PublicJWK, error) { + file, err := os.Open(filePath) + if err != nil { + return nil, fmt.Errorf("failed to open JWKS file %q: %w", filePath, err) + } + defer func() { _ = file.Close() }() + return DecodePublicJWKs(file) +} + +// UnmarshalPublicJWKs parses a JWKS document from raw JSON bytes. +func UnmarshalPublicJWKs(data []byte) ([]PublicJWK, error) { + var jwks JWKsJSON + if err := json.Unmarshal(data, &jwks); err != nil { + return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err) + } + return DecodePublicJWKsJSON(jwks) +} + +// DecodePublicJWKs parses a JWKS document from an [io.Reader]. +func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) { + var jwks JWKsJSON + if err := json.NewDecoder(r).Decode(&jwks); err != nil { + return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err) + } + return DecodePublicJWKsJSON(jwks) +} + +// DecodePublicJWKsJSON converts a parsed [JWKsJSON] into typed public keys. +func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) { + var keys []PublicJWK + for _, jwk := range jwks.Keys { + key, err := DecodePublicJWK(jwk) + if err != nil { + return nil, fmt.Errorf("failed to parse public jwk %q: %w", jwk.KID, err) + } + keys = append(keys, *key) + } + if len(keys) == 0 { + return nil, fmt.Errorf("no valid keys found in JWKS") + } + return keys, nil +} + +// DecodePublicJWK parses a single [PublicJWKJSON] into a [PublicJWK]. +// +// Supported key types: +// - "RSA" — minimum 1024-bit (RS256) +// - "EC" — P-256, P-384, P-521 (ES256, ES384, ES512) +// - "OKP" — Ed25519 crv (EdDSA / RFC 8037) +func DecodePublicJWK(jwk PublicJWKJSON) (*PublicJWK, error) { + switch jwk.Kty { + case "RSA": + key, err := decodeRSAPublicJWK(jwk) + if err != nil { + return nil, fmt.Errorf("failed to parse RSA key %q: %w", jwk.KID, err) + } + if key.Size() < 128 { // 1024 bits minimum + return nil, fmt.Errorf("RSA key %q too small: %d bytes", jwk.KID, key.Size()) + } + return &PublicJWK{Key: key, KID: jwk.KID, Use: jwk.Use}, nil + + case "EC": + key, err := decodeECPublicJWK(jwk) + if err != nil { + return nil, fmt.Errorf("failed to parse EC key %q: %w", jwk.KID, err) + } + return &PublicJWK{Key: key, KID: jwk.KID, Use: jwk.Use}, nil + + case "OKP": + key, err := decodeOKPPublicJWK(jwk) + if err != nil { + return nil, fmt.Errorf("failed to parse OKP key %q: %w", jwk.KID, err) + } + return &PublicJWK{Key: key, KID: jwk.KID, Use: jwk.Use}, nil + + default: + return nil, fmt.Errorf("unsupported key type %q for kid %q", jwk.Kty, jwk.KID) + } +} + +func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) { + n, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return nil, fmt.Errorf("invalid RSA modulus: %w", err) + } + e, err := base64.RawURLEncoding.DecodeString(jwk.E) + if err != nil { + return nil, fmt.Errorf("invalid RSA exponent: %w", err) + } + + eInt := new(big.Int).SetBytes(e).Int64() + if eInt > int64(^uint(0)>>1) || eInt < 0 { + return nil, fmt.Errorf("RSA exponent too large or negative") + } + + return &rsa.PublicKey{ + N: new(big.Int).SetBytes(n), + E: int(eInt), + }, nil +} + +func decodeECPublicJWK(jwk PublicJWKJSON) (*ecdsa.PublicKey, error) { + x, err := base64.RawURLEncoding.DecodeString(jwk.X) + if err != nil { + return nil, fmt.Errorf("invalid ECDSA X: %w", err) + } + y, err := base64.RawURLEncoding.DecodeString(jwk.Y) + if err != nil { + return nil, fmt.Errorf("invalid ECDSA Y: %w", err) + } + + var curve elliptic.Curve + switch jwk.Crv { + case "P-256": + curve = elliptic.P256() + case "P-384": + curve = elliptic.P384() + case "P-521": + curve = elliptic.P521() + default: + return nil, fmt.Errorf("unsupported EC curve: %s", jwk.Crv) + } + + return &ecdsa.PublicKey{ + Curve: curve, + X: new(big.Int).SetBytes(x), + Y: new(big.Int).SetBytes(y), + }, nil +} + +func decodeOKPPublicJWK(jwk PublicJWKJSON) (ed25519.PublicKey, error) { + if jwk.Crv != "Ed25519" { + return nil, fmt.Errorf("unsupported OKP curve: %q (only Ed25519 supported)", jwk.Crv) + } + x, err := base64.RawURLEncoding.DecodeString(jwk.X) + if err != nil { + return nil, fmt.Errorf("invalid OKP X: %w", err) + } + if len(x) != ed25519.PublicKeySize { + return nil, fmt.Errorf("invalid Ed25519 key size: got %d bytes, want %d", len(x), ed25519.PublicKeySize) + } + return ed25519.PublicKey(x), nil +}