mirror of
				https://github.com/therootcompany/golib.git
				synced 2025-10-31 05:02:52 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			484 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			484 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
 | |
| //
 | |
| // This Source Code Form is subject to the terms of the Mozilla Public
 | |
| // License, v. 2.0. If a copy of the MPL was not distributed with this
 | |
| // file, You can obtain one at https://mozilla.org/MPL/2.0/.
 | |
| //
 | |
| // SPDX-License-Identifier: MPL-2.0
 | |
| 
 | |
| package jwt
 | |
| 
 | |
| import (
 | |
| 	"crypto/ecdsa"
 | |
| 	"crypto/elliptic"
 | |
| 	"crypto/rand"
 | |
| 	"crypto/sha256"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"math/big"
 | |
| 	"slices"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
// Keypair pairs an ECDSA private key with a thumbprint string identifying it.
//
// NOTE(review): presumably Thumbprint is the JWK thumbprint produced by
// JWK.Thumbprint for PrivateKey's public half (NewJWS uses a thumbprint as the
// header "kid"); nothing in this type enforces that — confirm with callers.
type Keypair struct {
	Thumbprint string
	PrivateKey *ecdsa.PrivateKey
}
 | |
| 
 | |
// JWK is an elliptic-curve JSON Web Key (RFC 7517) holding string members
// exactly as they appear in JSON. X, Y and D are unpadded base64url
// big-endian integers (UnmarshalJWK decodes them with RawURLEncoding).
//
// D carries the private scalar and has no omitempty, so marshaling a JWK
// that holds only public values still emits "d":"".
type JWK struct {
	Kty string `json:"kty"` // key type, "EC" for elliptic-curve keys (RFC 7518 §6.1)
	Crv string `json:"crv"` // curve name, e.g. "P-256"
	D   string `json:"d"`   // private scalar — secret
	X   string `json:"x"`   // public point X coordinate
	Y   string `json:"y"`   // public point Y coordinate
}
 | |
| 
 | |
// JWT is a token in JWS compact serialization: three base64url segments
// (protected header, payload, signature) joined by ".".
type JWT string
 | |
| 
 | |
| func (jwt JWT) Split() (string, string, string, error) {
 | |
| 	parts := strings.Split(string(jwt), ".")
 | |
| 	if len(parts) != 3 {
 | |
| 		return "", "", "", fmt.Errorf("invalid JWT format")
 | |
| 	}
 | |
| 
 | |
| 	rawHeader, rawPayload, rawSig := parts[0], parts[1], parts[2]
 | |
| 	return rawHeader, rawPayload, rawSig, nil
 | |
| }
 | |
| 
 | |
| func (jwt JWT) Decode() (JWS, error) {
 | |
| 	h64, p64, s64, err := jwt.Split()
 | |
| 	if err != nil {
 | |
| 		return JWS{}, err
 | |
| 	}
 | |
| 
 | |
| 	var jws JWS
 | |
| 	var sigEnc string
 | |
| 	jws.Protected, jws.Payload, sigEnc = h64, p64, s64
 | |
| 
 | |
| 	header, err := base64.RawURLEncoding.DecodeString(jws.Protected)
 | |
| 	if err != nil {
 | |
| 		return jws, fmt.Errorf("invalid header encoding: %v", err)
 | |
| 	}
 | |
| 	if err := json.Unmarshal(header, &jws.Header); err != nil {
 | |
| 		return jws, fmt.Errorf("invalid header JSON: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
 | |
| 	if err != nil {
 | |
| 		return jws, fmt.Errorf("invalid claims encoding: %v", err)
 | |
| 	}
 | |
| 	if err := json.Unmarshal(payload, &jws.Claims); err != nil {
 | |
| 		return jws, fmt.Errorf("invalid claims JSON: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil {
 | |
| 		return jws, fmt.Errorf("invalid signature encoding: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return jws, nil
 | |
| }
 | |
| 
 | |
// JWS is a decoded (or not-yet-signed) JSON Web Signature.
//
// Protected and Payload hold the base64url-encoded header and claims exactly
// as they appear in the compact token. "Protected.Payload" is the signing
// input used by Sign and UnsafeVerify. Header and Claims hold the decoded JSON
// values, and Signature holds the raw signature bytes.
//
// UnsafeVerify stores the outcome of its ECDSA check in Verified. Validate
// reports an error when Verified is false unless ValidateParams.IgnoreSig is
// set.
//
// When a JWS itself is marshaled to JSON, note that the header uses the key
// "headers" (plural), and the base64 fields and Verified are omitted.
type JWS struct {
	Protected string    `json:"-"` // base64
	Header    MyHeader  `json:"headers"`
	Payload   string    `json:"-"` // base64
	Claims    MyClaims  `json:"claims"`
	Signature URLBase64 `json:"signature"`
	Verified  bool      `json:"-"`
}
 | |
| 
 | |
// MyHeader is the JOSE header type decoded by JWT.Decode and emitted by
// NewJWS. It currently carries only the embedded standard fields;
// application-specific header fields would be added alongside them.
type MyHeader struct {
	StandardHeader
}
 | |
| 
 | |
// StandardHeader holds the registered JOSE header parameters used here.
//
// NOTE(review): nothing in this file checks Alg on decode. UnsafeVerify always
// performs ECDSA over SHA-256 regardless of the header's claimed algorithm.
type StandardHeader struct {
	Alg string `json:"alg"` // signing algorithm; NewJWS sets "ES256"
	Kid string `json:"kid"` // key ID; NewJWS sets the key thumbprint
	Typ string `json:"typ"` // token type; NewJWS sets "JWT"
}
 | |
| 
 | |
// MyClaims is the claim set carried in the JWS payload: the standard claims
// (embedded, so they are promoted into the same JSON object) plus
// application-specific user attributes populated by NewJWS.
//
// None of the extra fields use omitempty, so a nil Roles marshals as null.
type MyClaims struct {
	StandardClaims
	Email      string   `json:"email"`
	EmployeeID string   `json:"employee_id"`
	FamilyName string   `json:"family_name"`
	GivenName  string   `json:"given_name"`
	Roles      []string `json:"roles"`
}
 | |
| 
 | |
// StandardClaims holds the registered JWT / OpenID Connect ID Token claims
// that Validate checks. Exp, Iat and AuthTime are Unix times in seconds.
// Nonce and Azp are omitted from JSON when empty; every other field is always
// emitted.
//
// NOTE(review): RFC 7519 §4.1.3 allows "aud" to be a JSON array of strings.
// Because Aud is a plain string, JWT.Decode fails with "invalid claims JSON"
// on multi-audience tokens. Confirm whether any issuer in use sends one.
type StandardClaims struct {
	Iss      string   `json:"iss"`                 // issuer
	Sub      string   `json:"sub"`                 // subject (user identifier)
	Aud      string   `json:"aud"`                 // audience
	Exp      int64    `json:"exp"`                 // expiration time
	Iat      int64    `json:"iat"`                 // issued-at time
	AuthTime int64    `json:"auth_time"`           // time the user authenticated
	Nonce    string   `json:"nonce,omitempty"`     // replay-protection value
	Amr      []string `json:"amr"`                 // authentication methods used
	Azp      string   `json:"azp,omitempty"`       // authorized party
	Jti      string   `json:"jti"`                 // token identifier
}
 | |
| 
 | |
| func UnmarshalJWK(jwk JWK) (*ecdsa.PrivateKey, error) {
 | |
| 	x, err := base64.RawURLEncoding.DecodeString(jwk.X)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid JWK X: %v", err)
 | |
| 	}
 | |
| 	y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid JWK Y: %v", err)
 | |
| 	}
 | |
| 	d, err := base64.RawURLEncoding.DecodeString(jwk.D)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid JWK D: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return &ecdsa.PrivateKey{
 | |
| 		PublicKey: ecdsa.PublicKey{
 | |
| 			Curve: elliptic.P256(),
 | |
| 			X:     new(big.Int).SetBytes(x),
 | |
| 			Y:     new(big.Int).SetBytes(y),
 | |
| 		},
 | |
| 		D: new(big.Int).SetBytes(d),
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func NewJWS(email, employeeID, issuer, thumbprint string, roles []string) (JWS, error) {
 | |
| 	var jws JWS
 | |
| 
 | |
| 	jws.Header.StandardHeader = StandardHeader{
 | |
| 		Alg: "ES256",
 | |
| 		Kid: thumbprint,
 | |
| 		Typ: "JWT",
 | |
| 	}
 | |
| 	headerJSON, _ := json.Marshal(jws.Header)
 | |
| 	jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
 | |
| 
 | |
| 	now := time.Now().Unix()
 | |
| 	jtiBytes := make([]byte, 16)
 | |
| 	if _, err := rand.Read(jtiBytes); err != nil {
 | |
| 		return JWS{}, fmt.Errorf("failed to generate Jti: %v", err)
 | |
| 	}
 | |
| 	jti := base64.RawURLEncoding.EncodeToString(jtiBytes)
 | |
| 	emailName := strings.Split(email, "@")[0]
 | |
| 
 | |
| 	jws.Claims = MyClaims{
 | |
| 		StandardClaims: StandardClaims{
 | |
| 			AuthTime: now,
 | |
| 			Exp:      now + 15*60*37, // TODO remove
 | |
| 			Iat:      now,
 | |
| 			Iss:      issuer,
 | |
| 			Jti:      jti,
 | |
| 			Sub:      email,
 | |
| 			Amr:      []string{"pwd"},
 | |
| 		},
 | |
| 		Email:      email,
 | |
| 		EmployeeID: employeeID,
 | |
| 		FamilyName: "McTestface",
 | |
| 		GivenName:  strings.ToUpper(emailName),
 | |
| 		Roles:      roles,
 | |
| 	}
 | |
| 	claimsJSON, _ := json.Marshal(jws.Claims)
 | |
| 	jws.Payload = base64.RawURLEncoding.EncodeToString(claimsJSON)
 | |
| 
 | |
| 	return jws, nil
 | |
| }
 | |
| 
 | |
| func (jws *JWS) Sign(key *ecdsa.PrivateKey) ([]byte, error) {
 | |
| 	var err error
 | |
| 	jws.Signature, err = SignJWS(jws.Protected, jws.Payload, key)
 | |
| 	return jws.Signature, err
 | |
| }
 | |
| 
 | |
| // UnsafeVerify only checks the signature, use Validate to check all values
 | |
| func (jws *JWS) UnsafeVerify(pub *ecdsa.PublicKey) bool {
 | |
| 	hash := sha256.Sum256([]byte(jws.Protected + "." + jws.Payload))
 | |
| 	n := len(jws.Signature)
 | |
| 	if n != 64 {
 | |
| 		// return fmt.Errorf("expected a 64-byte signature consisting of two 32-byte r and s components, but got %d instead (perhaps ASN.1 or other format)", n)
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	r := new(big.Int).SetBytes(jws.Signature[:32])
 | |
| 	s := new(big.Int).SetBytes(jws.Signature[32:])
 | |
| 
 | |
| 	jws.Verified = ecdsa.Verify(pub, hash[:], r, s)
 | |
| 	return jws.Verified
 | |
| }
 | |
| 
 | |
// ValidateParams holds validation configuration.
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
//
// Each claim check is controlled by an expected value and/or an Ignore flag.
// Setting a non-empty Iss or Aud forces that comparison even when the
// matching Ignore flag is true. A positive MaxAge likewise forces the
// auth_time check. A zero Now makes Validate use time.Now().
type ValidateParams struct {
	Now            time.Time     // evaluation time; zero means time.Now()
	IgnoreIss      bool          // skip the iss check (ignored if Iss is set)
	Iss            string        // expected issuer
	IgnoreSub      bool          // do not require a sub claim
	Sub            string        // if set, a present sub must equal it
	IgnoreAud      bool          // skip the aud check (ignored if Aud is set)
	Aud            string        // expected audience
	IgnoreExp      bool          // skip the exp presence/expiry check
	IgnoreJti      bool          // tolerate a jti in the token when Jti is empty
	Jti            string        // if set, jti must equal it
	IgnoreIat      bool          // skip the iat presence/future check
	IgnoreAuthTime bool          // skip the auth_time check (unless MaxAge > 0)
	MaxAge         time.Duration // limit on how old auth_time may be; positive forces the check
	IgnoreNonce    bool          // tolerate a nonce in the token when Nonce is empty
	Nonce          string        // if set, nonce must equal it
	IgnoreAmr      bool          // skip all amr checks
	RequiredAmrs   []string      // each entry must appear in the token's amr
	IgnoreAzp      bool          // tolerate an azp in the token when Azp is empty
	Azp            string        // if set, azp is compared against it
	IgnoreSig      bool          // do not require jws.Verified
}
 | |
| 
 | |
| // Validate checks common JWS fields and issuer, collecting all errors.
 | |
| func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
 | |
| 	var errs []string
 | |
| 
 | |
| 	if params.Now.IsZero() {
 | |
| 		params.Now = time.Now()
 | |
| 	}
 | |
| 
 | |
| 	// Required to exist and match
 | |
| 	if len(params.Iss) > 0 || !params.IgnoreIss {
 | |
| 		if len(jws.Claims.Iss) == 0 {
 | |
| 			errs = append(errs, ("missing or malformed 'iss' (token issuer, identifier for public key)"))
 | |
| 		} else if jws.Claims.Iss != params.Iss {
 | |
| 			errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", jws.Claims.Iss, params.Iss))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Required to exist, optional match
 | |
| 	if len(jws.Claims.Sub) == 0 {
 | |
| 		if !params.IgnoreSub {
 | |
| 			errs = append(errs, ("missing or malformed 'sub' (subject, typically pairwise user id)"))
 | |
| 		}
 | |
| 	} else if len(params.Sub) > 0 {
 | |
| 		if params.Sub != jws.Claims.Sub {
 | |
| 			errs = append(errs, fmt.Sprintf("'sub' (subject, typically pairwise user id) mismatch: got %s, expected %s", jws.Claims.Sub, params.Sub))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Required to exist and match
 | |
| 	if len(params.Aud) > 0 || !params.IgnoreAud {
 | |
| 		if len(jws.Claims.Aud) == 0 {
 | |
| 			errs = append(errs, ("missing or malformed 'aud' (audience receiving token)"))
 | |
| 		} else if jws.Claims.Aud != params.Aud {
 | |
| 			errs = append(errs, fmt.Sprintf("'aud' (audience receiving token) mismatch: got %s, expected %s", jws.Claims.Aud, params.Aud))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Required to exist and not be in the past
 | |
| 	if !params.IgnoreExp {
 | |
| 		if jws.Claims.Exp <= 0 {
 | |
| 			errs = append(errs, ("missing or malformed 'exp' (expiration date in seconds)"))
 | |
| 		} else if jws.Claims.Exp < params.Now.Unix() {
 | |
| 			duration := time.Since(time.Unix(jws.Claims.Exp, 0))
 | |
| 			expTime := time.Unix(jws.Claims.Exp, 0).Format("2006-01-02 15:04:05 MST")
 | |
| 			errs = append(errs, fmt.Sprintf("token expired %s ago (%s)", formatDuration(duration), expTime))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Required to exist and not be in the future
 | |
| 	if !params.IgnoreIat {
 | |
| 		if jws.Claims.Iat <= 0 {
 | |
| 			errs = append(errs, ("missing or malformed 'iat' (issued at, when token was signed)"))
 | |
| 		} else if jws.Claims.Iat > params.Now.Unix() {
 | |
| 			duration := time.Unix(jws.Claims.Iat, 0).Sub(params.Now)
 | |
| 			iatTime := time.Unix(jws.Claims.Iat, 0).Format("2006-01-02 15:04:05 MST")
 | |
| 			errs = append(errs, fmt.Sprintf("'iat' (issued at, when token was signed) is %s in the future (%s)", formatDuration(duration), iatTime))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Should exist, in the past, with optional max age
 | |
| 	if params.MaxAge > 0 || !params.IgnoreAuthTime {
 | |
| 		if jws.Claims.AuthTime == 0 {
 | |
| 			errs = append(errs, ("missing or malformed 'auth_time' (time of real-world user authentication, in seconds)"))
 | |
| 		} else {
 | |
| 			authTime := time.Unix(jws.Claims.AuthTime, 0)
 | |
| 			authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
 | |
| 			age := params.Now.Sub(authTime)
 | |
| 			diff := age - params.MaxAge
 | |
| 			if jws.Claims.AuthTime > params.Now.Unix() {
 | |
| 				fromNow := time.Unix(jws.Claims.AuthTime, 0).Sub(params.Now)
 | |
| 				authTimeStr := time.Unix(jws.Claims.AuthTime, 0).Format("2006-01-02 15:04:05 MST")
 | |
| 				errs = append(errs, fmt.Sprintf(
 | |
| 					"'auth_time' (time of real-world user authentication) of %s is %s in the future (server time %s)",
 | |
| 					authTimeStr, formatDuration(fromNow), params.Now.Format("2006-01-02 15:04:05 MST")),
 | |
| 				)
 | |
| 			} else if age > params.MaxAge {
 | |
| 				errs = append(errs, fmt.Sprintf(
 | |
| 					"'auth_time' (time of real-world user authentication) of %s is %s old, which exceeds the max age of %s (%ds) by %s",
 | |
| 					authTimeStr, formatDuration(age), formatDuration(params.MaxAge), params.MaxAge/time.Second, formatDuration(diff)),
 | |
| 				)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Optional
 | |
| 	if params.Jti != jws.Claims.Jti {
 | |
| 		if len(params.Jti) > 0 {
 | |
| 			errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", jws.Claims.Jti, params.Jti))
 | |
| 		} else if !params.IgnoreJti {
 | |
| 			errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", jws.Claims.Jti))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Optional
 | |
| 	if params.Nonce != jws.Claims.Nonce {
 | |
| 		if len(params.Nonce) > 0 {
 | |
| 			errs = append(errs, fmt.Sprintf("'nonce' (one-time random salt, as string) mismatch: got %s, expected %s", jws.Claims.Nonce, params.Nonce))
 | |
| 		} else if !params.IgnoreNonce {
 | |
| 			errs = append(errs, fmt.Sprintf("unchecked 'nonce' (one-time random salt): %s", jws.Claims.Nonce))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Acr check not implemented because the use case is not yet clear
 | |
| 
 | |
| 	// Should exist, optional match
 | |
| 	if !params.IgnoreAmr {
 | |
| 		if len(jws.Claims.Amr) == 0 {
 | |
| 			errs = append(errs, ("missing or malformed 'amr' (authorization methods, as json list)"))
 | |
| 		} else {
 | |
| 			if len(params.RequiredAmrs) > 0 {
 | |
| 				for _, required := range params.RequiredAmrs {
 | |
| 					if !slices.Contains(jws.Claims.Amr, required) {
 | |
| 						errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr' (authorization methods, as json list)", required))
 | |
| 					}
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			// TODO specify multiple amrs in a tiered list (must have at least one from each list)
 | |
| 			// count := 0
 | |
| 			// if len(params.AcceptableAmrs) > 0 {
 | |
| 			// 	for _, amr := range jws.Claims.Amr {
 | |
| 			// 		if slices.Contains(params.AcceptableAmrs, amr) {
 | |
| 			// 			count += 1
 | |
| 			// 		}
 | |
| 			// 	}
 | |
| 			// }
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Optional, should match if exists
 | |
| 	if params.Azp != jws.Claims.Azp {
 | |
| 		if len(params.Azp) > 0 {
 | |
| 			errs = append(errs, ("missing or malformed 'azp' (authorized party which presents token)"))
 | |
| 		} else if !params.IgnoreAzp {
 | |
| 			errs = append(errs, fmt.Sprintf("'azp' mismatch (authorized party which presents token): got %s, expected %s", jws.Claims.Azp, params.Azp))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Must be checked
 | |
| 	if !params.IgnoreSig {
 | |
| 		if !jws.Verified {
 | |
| 			errs = append(errs, ("signature was not checked"))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if len(errs) > 0 {
 | |
| 		timeInfo := fmt.Sprintf("info: server time is %s", params.Now.Format("2006-01-02 15:04:05 MST"))
 | |
| 		if loc, err := time.LoadLocation("Local"); err == nil {
 | |
| 			timeInfo += fmt.Sprintf(" %s", loc)
 | |
| 		}
 | |
| 		errs = append(errs, timeInfo)
 | |
| 		return errs, fmt.Errorf("has errors")
 | |
| 	}
 | |
| 	return nil, nil
 | |
| }
 | |
| 
 | |
// SignJWS signs the JWS signing input header + "." + payload with key using
// ECDSA over SHA-256. header and payload must already be base64url-encoded.
//
// The signature is returned in the fixed-width r||s form required by JWS
// (RFC 7518 §3.4). Each of r and s is left-padded with zeros to the curve's
// byte length, giving exactly 64 bytes for P-256 (ES256).
func SignJWS(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) {
	if key == nil || key.Curve == nil {
		return nil, fmt.Errorf("failed to sign: missing private key")
	}

	hash := sha256.Sum256([]byte(header + "." + payload))
	r, s, err := ecdsa.Sign(rand.Reader, key, hash[:])
	if err != nil {
		return nil, fmt.Errorf("failed to sign: %w", err)
	}

	// big.Int.Bytes drops leading zero bytes, which would produce a short
	// (e.g. 63-byte) signature roughly once in 128 signings. FillBytes pads.
	size := (key.Curve.Params().BitSize + 7) / 8
	sig := make([]byte, 2*size)
	r.FillBytes(sig[:size])
	s.FillBytes(sig[size:])
	return sig, nil
}
 | |
| 
 | |
| func (jws JWS) Encode() string {
 | |
| 	sigEnc := base64.RawURLEncoding.EncodeToString(jws.Signature)
 | |
| 	return jws.Protected + "." + jws.Payload + "." + sigEnc
 | |
| }
 | |
| 
 | |
// EncodeToJWT appends "." and the base64url-encoded signature to an
// already-built signing input ("header.payload"), yielding a compact JWT.
func EncodeToJWT(signingInput string, signature []byte) string {
	var b strings.Builder
	b.WriteString(signingInput)
	b.WriteByte('.')
	b.WriteString(base64.RawURLEncoding.EncodeToString(signature))
	return b.String()
}
 | |
| 
 | |
| func (jwk JWK) Thumbprint() (string, error) {
 | |
| 	data := map[string]string{
 | |
| 		"crv": jwk.Crv,
 | |
| 		"kty": jwk.Kty,
 | |
| 		"x":   jwk.X,
 | |
| 		"y":   jwk.Y,
 | |
| 	}
 | |
| 	jsonData, err := json.Marshal(data)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	hash := sha256.Sum256(jsonData)
 | |
| 	return base64.RawURLEncoding.EncodeToString(hash[:]), nil
 | |
| }
 | |
| 
 | |
// URLBase64 unmarshals to bytes and marshals to a raw url base64 string
//
// In memory it holds the decoded bytes. Its text form (String, MarshalJSON) is
// unpadded URL-safe base64 (RFC 4648 §5, base64.RawURLEncoding).
type URLBase64 []byte
 | |
| 
 | |
| func (s URLBase64) String() string {
 | |
| 	encoded := base64.RawURLEncoding.EncodeToString(s)
 | |
| 	return encoded
 | |
| }
 | |
| 
 | |
| // MarshalJSON implements JSON marshaling to URL-safe base64.
 | |
| func (s URLBase64) MarshalJSON() ([]byte, error) {
 | |
| 	encoded := base64.RawURLEncoding.EncodeToString(s)
 | |
| 	return json.Marshal(encoded)
 | |
| }
 | |
| 
 | |
| // UnmarshalJSON implements JSON unmarshaling from URL-safe base64.
 | |
| func (s *URLBase64) UnmarshalJSON(data []byte) error {
 | |
| 	dst, err := base64.RawURLEncoding.AppendDecode([]byte{}, data)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("decode base64url signature: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	*s = dst
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// formatDuration renders d compactly for error messages, e.g. "1d 2h 3m 4s".
//
// The sign is dropped; callers describe the direction ("ago", "in the
// future") in the surrounding text. Zero-valued units are omitted. Durations
// under one second are shown in whole milliseconds ("250ms", or "0ms" for
// zero). For longer durations, the sub-second remainder is truncated.
func formatDuration(d time.Duration) string {
	if d < 0 {
		d = -d
		if d < 0 { // -math.MinInt64 overflows back to itself
			d = 1<<63 - 1
		}
	}
	if d < time.Second {
		return fmt.Sprintf("%dms", int64(d/time.Millisecond))
	}

	days := d / (24 * time.Hour)
	d -= days * 24 * time.Hour
	hours := d / time.Hour
	d -= hours * time.Hour
	minutes := d / time.Minute
	d -= minutes * time.Minute
	seconds := d / time.Second

	units := []struct {
		n      time.Duration
		suffix string
	}{
		{days, "d"},
		{hours, "h"},
		{minutes, "m"},
		{seconds, "s"},
	}

	// d >= 1s here, so at least one unit is non-zero.
	parts := make([]string, 0, len(units))
	for _, u := range units {
		if u.n > 0 {
			parts = append(parts, fmt.Sprintf("%d%s", int64(u.n), u.suffix))
		}
	}
	return strings.Join(parts, " ")
}
 |