// Mirror of https://github.com/therootcompany/golib.git
// (synced 2025-10-31 05:02:52 +00:00; 204 lines, 5.0 KiB, Go).

| package jwt
 | |
| 
 | |
| import (
 | |
| 	"crypto"
 | |
| 	"crypto/ecdsa"
 | |
| 	"crypto/elliptic"
 | |
| 	"crypto/rsa"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"math/big"
 | |
| 	"net/http"
 | |
| 	"os"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
// PublicKey is the behavior required of a key stored in a PublicJWK: it must
// be able to compare itself against another public key. The *rsa.PublicKey and
// *ecdsa.PublicKey values produced by this package both provide this method.
type PublicKey interface {
	// Equal reports whether other is the same key as the receiver.
	Equal(other crypto.PublicKey) bool
}
 | |
| 
 | |
| // PublicJWK represents a parsed public key (RSA or ECDSA)
 | |
| type PublicJWK struct {
 | |
| 	PublicKey
 | |
| 	KID string
 | |
| 	Use string
 | |
| }
 | |
| 
 | |
// PublicJWKJSON is the wire (JSON) form of a single public JSON Web Key.
// Only the members needed for RSA and EC public keys are modelled; the
// binary members are carried as unpadded base64url strings and are decoded
// by DecodePublicJWK.
type PublicJWKJSON struct {
	// Kty is the key type; DecodePublicJWK accepts "RSA" and "EC".
	Kty string `json:"kty"`
	// KID is the key ID.
	KID string `json:"kid"`
	// N is the RSA modulus.
	N string `json:"n,omitempty"`
	// E is the RSA public exponent.
	E string `json:"e,omitempty"`
	// Crv is the EC curve name, e.g. "P-256".
	Crv string `json:"crv,omitempty"`
	// X is the EC point's x coordinate.
	X string `json:"x,omitempty"`
	// Y is the EC point's y coordinate.
	Y string `json:"y,omitempty"`
	// Use is the JWK "use" member.
	Use string `json:"use,omitempty"`
}
 | |
| 
 | |
| type JWKsJSON struct {
 | |
| 	Keys []PublicJWKJSON `json:"keys"`
 | |
| }
 | |
| 
 | |
| func UnmarshalPublicJWKs(data []byte) ([]PublicJWK, error) {
 | |
| 	var jwks JWKsJSON
 | |
| 	if err := json.Unmarshal(data, &jwks); err != nil {
 | |
| 		return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	pubkeys, err := DecodePublicJWKsJSON(jwks)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return pubkeys, nil
 | |
| }
 | |
| 
 | |
| func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) {
 | |
| 	var jwks JWKsJSON
 | |
| 
 | |
| 	if err := json.NewDecoder(r).Decode(&jwks); err != nil {
 | |
| 		return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	pubkeys, err := DecodePublicJWKsJSON(jwks)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return pubkeys, nil
 | |
| }
 | |
| 
 | |
| // DecodePublicJWKsJSON parses JWKS from a Reader
 | |
| func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
 | |
| 	// Process keys
 | |
| 	var publicKeys []PublicJWK
 | |
| 	for _, jwk := range jwks.Keys {
 | |
| 		publicKey, err := DecodePublicJWK(jwk)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed to parse public jwk '%s': %w", jwk.KID, err)
 | |
| 		}
 | |
| 		publicKeys = append(publicKeys, *publicKey)
 | |
| 	}
 | |
| 
 | |
| 	if len(publicKeys) == 0 {
 | |
| 		return nil, fmt.Errorf("no valid RSA or ECDSA keys found")
 | |
| 	}
 | |
| 
 | |
| 	return publicKeys, nil
 | |
| }
 | |
| 
 | |
| // DecodePublicJWK parses JWKS from a Reader
 | |
| func DecodePublicJWK(jwk PublicJWKJSON) (*PublicJWK, error) {
 | |
| 	switch jwk.Kty {
 | |
| 	case "RSA":
 | |
| 		key, err := decodeRSAPublicJWK(jwk)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed to parse RSA key '%s': %w", jwk.KID, err)
 | |
| 		}
 | |
| 		// Ensure RSA key meets minimum size requirement
 | |
| 		if key.Size() < 128 { // 1024 bits / 8 = 128 bytes
 | |
| 			return nil, fmt.Errorf("RSA key '%s' too small: %d bytes", jwk.KID, key.Size())
 | |
| 		}
 | |
| 		return &PublicJWK{PublicKey: key, KID: jwk.KID, Use: jwk.Use}, nil
 | |
| 
 | |
| 	case "EC":
 | |
| 		key, err := decodeECDSAPublicJWK(jwk)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed to parse EC key '%s': %w", jwk.KID, err)
 | |
| 		}
 | |
| 		return &PublicJWK{KID: jwk.KID, PublicKey: key, Use: jwk.Use}, nil
 | |
| 
 | |
| 	default:
 | |
| 		return nil, fmt.Errorf("failed to parse unknown key type '%s': %s", jwk.Kty, jwk.KID)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // ReadPublicJWKs reads and parses JWKS from a file
 | |
| func ReadPublicJWKs(filePath string) ([]PublicJWK, error) {
 | |
| 	file, err := os.Open(filePath)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to open JWKS file '%s': %w", filePath, err)
 | |
| 	}
 | |
| 	defer func() { _ = file.Close() }()
 | |
| 
 | |
| 	return DecodePublicJWKs(file)
 | |
| }
 | |
| 
 | |
| // FetchPublicJWKs retrieves and parses JWKS from a given URL
 | |
| func FetchPublicJWKs(url string) ([]PublicJWK, error) {
 | |
| 	// Set up HTTP client with timeout
 | |
| 	client := &http.Client{
 | |
| 		Timeout: 10 * time.Second,
 | |
| 	}
 | |
| 
 | |
| 	// Make HTTP request
 | |
| 	resp, err := client.Get(url)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
 | |
| 	}
 | |
| 	defer func() { _ = resp.Body.Close() }()
 | |
| 
 | |
| 	// Check response status
 | |
| 	if resp.StatusCode != http.StatusOK {
 | |
| 		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
 | |
| 	}
 | |
| 
 | |
| 	return DecodePublicJWKs(resp.Body)
 | |
| }
 | |
| 
 | |
| // decodeRSAPublicJWK parses an RSA public key from a JWK
 | |
| func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) {
 | |
| 	n, err := base64.RawURLEncoding.DecodeString(jwk.N)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid RSA modulus: %w", err)
 | |
| 	}
 | |
| 	e, err := base64.RawURLEncoding.DecodeString(jwk.E)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid RSA exponent: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	// Convert exponent to int
 | |
| 	eInt := new(big.Int).SetBytes(e).Int64()
 | |
| 	if eInt > int64(^uint(0)>>1) || eInt < 0 {
 | |
| 		return nil, fmt.Errorf("RSA exponent too large or negative")
 | |
| 	}
 | |
| 
 | |
| 	return &rsa.PublicKey{
 | |
| 		N: new(big.Int).SetBytes(n),
 | |
| 		E: int(eInt),
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // decodeECDSAPublicJWK parses an ECDSA public key from a JWK
 | |
| func decodeECDSAPublicJWK(jwk PublicJWKJSON) (*ecdsa.PublicKey, error) {
 | |
| 	x, err := base64.RawURLEncoding.DecodeString(jwk.X)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid ECDSA X: %w", err)
 | |
| 	}
 | |
| 	y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid ECDSA Y: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	var curve elliptic.Curve
 | |
| 	switch jwk.Crv {
 | |
| 	case "P-256":
 | |
| 		curve = elliptic.P256()
 | |
| 	case "P-384":
 | |
| 		curve = elliptic.P384()
 | |
| 	case "P-521":
 | |
| 		curve = elliptic.P521()
 | |
| 	default:
 | |
| 		return nil, fmt.Errorf("unsupported ECDSA curve: %s", jwk.Crv)
 | |
| 	}
 | |
| 
 | |
| 	return &ecdsa.PublicKey{
 | |
| 		Curve: curve,
 | |
| 		X:     new(big.Int).SetBytes(x),
 | |
| 		Y:     new(big.Int).SetBytes(y),
 | |
| 	}, nil
 | |
| }
 |