golib/auth/jwt/jwt.go

484 lines
14 KiB
Go

// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package jwt
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
"slices"
"strings"
"time"
)
// Keypair pairs an ECDSA private key with an identifier for it.
type Keypair struct {
	// Thumbprint identifies the key. NOTE(review): presumably the RFC 7638
	// thumbprint of the public key (as produced by JWK.Thumbprint) and used
	// as the JWS "kid" — confirm against callers.
	Thumbprint string
	PrivateKey *ecdsa.PrivateKey
}

// JWK is an elliptic-curve JSON Web Key. X, Y and D hold unpadded base64url
// big-endian integers (UnmarshalJWK decodes them with RawURLEncoding).
type JWK struct {
	Kty string `json:"kty"` // key type ("EC" for elliptic-curve keys)
	Crv string `json:"crv"` // curve name
	D   string `json:"d"`   // private scalar; secret. Not omitempty, so an empty value is still emitted.
	X   string `json:"x"`   // public point X coordinate
	Y   string `json:"y"`   // public point Y coordinate
}
// JWT is a token in compact serialization: three base64url segments joined
// by dots, "<header>.<payload>.<signature>".
type JWT string

// Split returns the raw, still-encoded header, payload and signature
// segments of jwt. Nothing is decoded or validated; the only failure is a
// token that does not have exactly three dot-separated segments.
func (jwt JWT) Split() (string, string, string, error) {
	// Asking for at most four pieces is enough to tell "exactly three"
	// apart from "too many" without splitting the whole string.
	segments := strings.SplitN(string(jwt), ".", 4)
	if len(segments) == 3 {
		return segments[0], segments[1], segments[2], nil
	}
	return "", "", "", fmt.Errorf("invalid JWT format")
}
// Decode splits jwt into its three segments, parses the header and claims
// JSON, and decodes the signature bytes. It does NOT verify the signature or
// any claim; use UnsafeVerify and Validate for that.
//
// Protected and Payload keep the original encoded segments so the signing
// input can be rebuilt exactly. If parsing fails part-way, the partially
// filled JWS is returned together with the error.
func (jwt JWT) Decode() (JWS, error) {
	protected, payload, signature, err := jwt.Split()
	if err != nil {
		return JWS{}, err
	}
	jws := JWS{Protected: protected, Payload: payload}

	rawHeader, err := base64.RawURLEncoding.DecodeString(protected)
	if err != nil {
		return jws, fmt.Errorf("invalid header encoding: %v", err)
	}
	if err = json.Unmarshal(rawHeader, &jws.Header); err != nil {
		return jws, fmt.Errorf("invalid header JSON: %v", err)
	}

	rawClaims, err := base64.RawURLEncoding.DecodeString(payload)
	if err != nil {
		return jws, fmt.Errorf("invalid claims encoding: %v", err)
	}
	if err = json.Unmarshal(rawClaims, &jws.Claims); err != nil {
		return jws, fmt.Errorf("invalid claims JSON: %v", err)
	}

	// The signature segment is bare base64url text (no JSON quotes), which
	// URLBase64.UnmarshalJSON decodes directly.
	if err = jws.Signature.UnmarshalJSON([]byte(signature)); err != nil {
		return jws, fmt.Errorf("invalid signature encoding: %v", err)
	}
	return jws, nil
}
// JWS is a JSON Web Signature in compact form. It keeps both the raw
// base64url segments, which form the signing input, and their decoded values.
type JWS struct {
	Protected string `json:"-"` // base64url header segment, exactly as signed
	// Header is the decoded form of Protected.
	// NOTE(review): the JSON tag is "headers" (plural); it only affects
	// JSON-encoding a JWS value itself — confirm this is intended.
	Header    MyHeader  `json:"headers"`
	Payload   string    `json:"-"`         // base64url claims segment, exactly as signed
	Claims    MyClaims  `json:"claims"`    // decoded form of Payload
	Signature URLBase64 `json:"signature"` // raw signature bytes (R || S for ES256)
	// Verified is set by UnsafeVerify. Validate reports an error while it
	// is false unless ValidateParams.IgnoreSig is set.
	Verified bool `json:"-"`
}

// MyHeader is the JOSE header this package reads and writes. At present it
// holds only the embedded StandardHeader fields.
type MyHeader struct {
	StandardHeader
}

// StandardHeader holds the registered JOSE header parameters used here.
// NewJWS sets Alg to "ES256", Typ to "JWT" and Kid to the key thumbprint.
type StandardHeader struct {
	Alg string `json:"alg"` // signature algorithm
	Kid string `json:"kid"` // key ID
	Typ string `json:"typ"` // token type
}

// MyClaims is the claim set this package issues and validates: the standard
// claims plus application-specific identity fields.
type MyClaims struct {
	StandardClaims
	Email      string   `json:"email"`
	EmployeeID string   `json:"employee_id"`
	FamilyName string   `json:"family_name"`
	GivenName  string   `json:"given_name"`
	Roles      []string `json:"roles"`
}

// StandardClaims holds registered JWT claims and OpenID Connect ID Token
// claims. Exp, Iat and AuthTime are Unix times in seconds.
type StandardClaims struct {
	Iss string `json:"iss"` // issuer
	Sub string `json:"sub"` // subject
	// Aud is the audience.
	// NOTE(review): OIDC allows "aud" to be a JSON array. Because this is a
	// string field, json.Unmarshal (and therefore JWT.Decode) rejects such
	// tokens — confirm whether that matters for your issuers.
	Aud      string   `json:"aud"`
	Exp      int64    `json:"exp"`             // expiration time
	Iat      int64    `json:"iat"`             // issued-at time
	AuthTime int64    `json:"auth_time"`       // when the user actually authenticated
	Nonce    string   `json:"nonce,omitempty"` // replay-protection value echoed from the request
	Amr      []string `json:"amr"`             // authentication methods references
	Azp      string   `json:"azp,omitempty"`   // authorized party
	Jti      string   `json:"jti"`             // unique token ID
}
// UnmarshalJWK converts an elliptic-curve private-key JWK into an
// *ecdsa.PrivateKey on the P-256 curve (the curve behind ES256).
//
// "kty" and "crv" are optional. When present they must be "EC" and "P-256";
// other curves are rejected instead of being silently reinterpreted as P-256.
// "x", "y" and "d" are unpadded base64url big-endian integers.
//
// The key is validated before it is returned:
//   - d must be a valid P-256 scalar;
//   - (x, y) must lie on the curve;
//   - (x, y) must be the public point derived from d.
//
// A truncated, corrupted or mismatched JWK therefore produces an error
// instead of a key whose signatures can never verify.
func UnmarshalJWK(jwk JWK) (*ecdsa.PrivateKey, error) {
	const size = 32 // byte length of P-256 coordinates and scalars

	if jwk.Kty != "" && jwk.Kty != "EC" {
		return nil, fmt.Errorf("invalid JWK kty: expected %q, got %q", "EC", jwk.Kty)
	}
	if jwk.Crv != "" && jwk.Crv != "P-256" {
		return nil, fmt.Errorf("invalid JWK crv: expected %q, got %q", "P-256", jwk.Crv)
	}

	x, err := base64.RawURLEncoding.DecodeString(jwk.X)
	if err != nil {
		return nil, fmt.Errorf("invalid JWK X: %v", err)
	}
	y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
	if err != nil {
		return nil, fmt.Errorf("invalid JWK Y: %v", err)
	}
	d, err := base64.RawURLEncoding.DecodeString(jwk.D)
	if err != nil {
		return nil, fmt.Errorf("invalid JWK D: %v", err)
	}
	if len(x) > size || len(y) > size || len(d) > size {
		return nil, fmt.Errorf("invalid JWK: x, y and d must each be at most %d bytes for P-256", size)
	}

	key := &ecdsa.PrivateKey{
		PublicKey: ecdsa.PublicKey{
			Curve: elliptic.P256(),
			X:     new(big.Int).SetBytes(x),
			Y:     new(big.Int).SetBytes(y),
		},
		D: new(big.Int).SetBytes(d),
	}

	// Converting to crypto/ecdh performs the validation. PrivateKey.ECDH
	// range-checks d (non-zero and below the group order). PublicKey.ECDH
	// rejects points that are not on the curve.
	privECDH, err := key.ECDH()
	if err != nil {
		return nil, fmt.Errorf("invalid JWK D: %v", err)
	}
	pubECDH, err := key.PublicKey.ECDH()
	if err != nil {
		return nil, fmt.Errorf("invalid JWK X/Y: %v", err)
	}
	if !privECDH.PublicKey().Equal(pubECDH) {
		return nil, fmt.Errorf("invalid JWK: x/y do not match the public key derived from d")
	}
	return key, nil
}
// NewJWS assembles an unsigned ES256 JWS for the given user.
//
// Header: the protected header names thumbprint as its "kid".
//
// Claims:
//   - jti is 16 random bytes from crypto/rand;
//   - iat and auth_time are the current Unix time;
//   - amr is ["pwd"];
//   - sub and email are both set to email.
//
// Protected and Payload are filled with the base64url JSON of Header and
// Claims. Call Sign to add the signature.
//
// NOTE(review): FamilyName is the fixed placeholder "McTestface", and
// GivenName is the upper-cased part of email before "@". This looks like
// demo/test data — confirm before relying on it.
func NewJWS(email, employeeID, issuer, thumbprint string, roles []string) (JWS, error) {
	var jws JWS
	jws.Header.StandardHeader = StandardHeader{Alg: "ES256", Kid: thumbprint, Typ: "JWT"}
	// The error is ignored: json.Marshal cannot fail on a struct of plain
	// string fields.
	protected, _ := json.Marshal(jws.Header)
	jws.Protected = base64.RawURLEncoding.EncodeToString(protected)

	issuedAt := time.Now().Unix()

	var tokenID [16]byte
	if _, err := rand.Read(tokenID[:]); err != nil {
		return JWS{}, fmt.Errorf("failed to generate Jti: %v", err)
	}

	localPart, _, _ := strings.Cut(email, "@")

	jws.Claims = MyClaims{
		StandardClaims: StandardClaims{
			Iss:      issuer,
			Sub:      email,
			Exp:      issuedAt + 15*60*37, // TODO remove: 33300s (9h15m); confirm the intended lifetime
			Iat:      issuedAt,
			AuthTime: issuedAt,
			Amr:      []string{"pwd"},
			Jti:      base64.RawURLEncoding.EncodeToString(tokenID[:]),
		},
		Email:      email,
		EmployeeID: employeeID,
		FamilyName: "McTestface",
		GivenName:  strings.ToUpper(localPart),
		Roles:      roles,
	}
	// The error is ignored: the claims hold only strings, ints and string
	// slices, which json.Marshal always encodes.
	payload, _ := json.Marshal(jws.Claims)
	jws.Payload = base64.RawURLEncoding.EncodeToString(payload)
	return jws, nil
}
// Sign signs the signing input jws.Protected + "." + jws.Payload with key
// (via SignJWS), stores the result in jws.Signature, and returns it together
// with any error from SignJWS. jws.Signature is overwritten even when
// signing fails.
func (jws *JWS) Sign(key *ecdsa.PrivateKey) ([]byte, error) {
	sig, err := SignJWS(jws.Protected, jws.Payload, key)
	jws.Signature = sig
	return sig, err
}
// UnsafeVerify only checks the signature, use Validate to check all values.
//
// It verifies jws.Signature against pub over the signing input
// jws.Protected + "." + jws.Payload (SHA-256). The signature must be the
// 64-byte R || S form used by ES256: two 32-byte big-endian integers.
//
// The outcome is stored in jws.Verified, which Validate consults, and also
// returned. Verified is reset to false before any check, so a stale true from
// an earlier call cannot survive a rejected signature. A nil key is reported
// as unverified rather than panicking.
func (jws *JWS) UnsafeVerify(pub *ecdsa.PublicKey) bool {
	jws.Verified = false
	if pub == nil || pub.Curve == nil {
		return false
	}
	if len(jws.Signature) != 64 {
		// Not the fixed-width R || S form (e.g. ASN.1/DER or a truncated value).
		return false
	}
	hash := sha256.Sum256([]byte(jws.Protected + "." + jws.Payload))
	r := new(big.Int).SetBytes(jws.Signature[:32])
	s := new(big.Int).SetBytes(jws.Signature[32:])
	jws.Verified = ecdsa.Verify(pub, hash[:], r, s)
	return jws.Verified
}
// ValidateParams holds validation configuration.
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
//
// Each Ignore* flag switches off one check in Validate. The neighbouring value
// field, when set, is the value the claim is compared against. For iss, aud,
// auth_time (MaxAge), jti, nonce and azp, setting the value keeps the check
// active even if the Ignore* flag is true.
type ValidateParams struct {
	// Now is the reference time for the exp, iat and auth_time checks.
	// Validate uses time.Now() when it is the zero time.
	Now time.Time

	// iss must be present and equal Iss. The check is skipped only when
	// IgnoreIss is true and Iss is empty.
	IgnoreIss bool
	Iss       string

	// sub must be present unless IgnoreSub is true. A non-empty Sub must
	// equal a present sub.
	IgnoreSub bool
	Sub       string

	// aud must be present and equal Aud. The check is skipped only when
	// IgnoreAud is true and Aud is empty.
	IgnoreAud bool
	Aud       string

	// IgnoreExp skips the exp check.
	IgnoreExp bool

	// A non-empty Jti must equal jti. With Jti empty, a jti in the token is
	// reported unless IgnoreJti is true.
	IgnoreJti bool
	Jti       string

	// IgnoreIat skips the iat check.
	IgnoreIat bool

	// auth_time is checked unless IgnoreAuthTime is true and MaxAge is 0.
	// MaxAge is the permitted age of auth_time relative to Now; see Validate
	// for how a zero MaxAge is treated.
	IgnoreAuthTime bool
	MaxAge         time.Duration

	// A non-empty Nonce must equal nonce. With Nonce empty, a nonce in the
	// token is reported unless IgnoreNonce is true.
	IgnoreNonce bool
	Nonce       string

	// amr must be non-empty unless IgnoreAmr is true. When checked, it must
	// contain every entry of RequiredAmrs.
	IgnoreAmr    bool
	RequiredAmrs []string

	// A non-empty Azp is compared with azp. With Azp empty, an azp in the
	// token is reported unless IgnoreAzp is true.
	IgnoreAzp bool
	Azp       string

	// IgnoreSig skips the requirement that jws.Verified is already true,
	// i.e. that the signature was checked beforehand.
	IgnoreSig bool
}
// Validate checks common JWS fields and issuer, collecting all errors rather
// than stopping at the first one.
//
// Checks (each can be switched off with the matching Ignore* field of
// ValidateParams):
//   - iss and aud must be present and equal params.Iss / params.Aud.
//   - sub must be present, and must equal params.Sub when that is set.
//   - exp must be present and strictly after params.Now (RFC 7519 §4.1.4:
//     the token must not be accepted on or after exp).
//   - iat must be present and not after params.Now.
//   - auth_time must be present and not after params.Now. When
//     params.MaxAge > 0 it must also be no older than MaxAge; a zero MaxAge
//     means no age limit.
//   - jti, nonce and azp must equal the expected value when one is given.
//     When none is given, a value present in the token is reported as
//     unchecked.
//   - amr must be present and contain every entry of params.RequiredAmrs.
//   - jws.Verified must already be true. Validate does not check the
//     signature itself; call UnsafeVerify first.
//
// params.Now defaults to time.Now() when zero. On success Validate returns
// (nil, nil). Otherwise it returns the problems, followed by a final line
// with the server time, together with a non-nil error.
func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
	const timeLayout = "2006-01-02 15:04:05 MST"

	var errs []string
	if params.Now.IsZero() {
		params.Now = time.Now()
	}
	nowUnix := params.Now.Unix()
	claims := &jws.Claims

	// iss: required to exist and match.
	if len(params.Iss) > 0 || !params.IgnoreIss {
		if len(claims.Iss) == 0 {
			errs = append(errs, "missing or malformed 'iss' (token issuer, identifier for public key)")
		} else if claims.Iss != params.Iss {
			errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", claims.Iss, params.Iss))
		}
	}

	// sub: required to exist, optional match.
	if len(claims.Sub) == 0 {
		if !params.IgnoreSub {
			errs = append(errs, "missing or malformed 'sub' (subject, typically pairwise user id)")
		}
	} else if len(params.Sub) > 0 && params.Sub != claims.Sub {
		errs = append(errs, fmt.Sprintf("'sub' (subject, typically pairwise user id) mismatch: got %s, expected %s", claims.Sub, params.Sub))
	}

	// aud: required to exist and match.
	if len(params.Aud) > 0 || !params.IgnoreAud {
		if len(claims.Aud) == 0 {
			errs = append(errs, "missing or malformed 'aud' (audience receiving token)")
		} else if claims.Aud != params.Aud {
			errs = append(errs, fmt.Sprintf("'aud' (audience receiving token) mismatch: got %s, expected %s", claims.Aud, params.Aud))
		}
	}

	// exp: required to exist, and Now must be strictly before it.
	if !params.IgnoreExp {
		if claims.Exp <= 0 {
			errs = append(errs, "missing or malformed 'exp' (expiration date in seconds)")
		} else if exp := time.Unix(claims.Exp, 0); !params.Now.Before(exp) {
			// Measure against params.Now (not the wall clock) so the message
			// agrees with the comparison when a fixed Now is supplied.
			errs = append(errs, fmt.Sprintf("token expired %s ago (%s)", formatDuration(params.Now.Sub(exp)), exp.Format(timeLayout)))
		}
	}

	// iat: required to exist and not be in the future.
	if !params.IgnoreIat {
		if claims.Iat <= 0 {
			errs = append(errs, "missing or malformed 'iat' (issued at, when token was signed)")
		} else if claims.Iat > nowUnix {
			iat := time.Unix(claims.Iat, 0)
			errs = append(errs, fmt.Sprintf("'iat' (issued at, when token was signed) is %s in the future (%s)", formatDuration(iat.Sub(params.Now)), iat.Format(timeLayout)))
		}
	}

	// auth_time: should exist and be in the past; a positive MaxAge also
	// bounds its age. A zero MaxAge imposes no age limit.
	if params.MaxAge > 0 || !params.IgnoreAuthTime {
		if claims.AuthTime <= 0 {
			errs = append(errs, "missing or malformed 'auth_time' (time of real-world user authentication, in seconds)")
		} else {
			authTime := time.Unix(claims.AuthTime, 0)
			authTimeStr := authTime.Format(timeLayout)
			age := params.Now.Sub(authTime)
			switch {
			case claims.AuthTime > nowUnix:
				errs = append(errs, fmt.Sprintf(
					"'auth_time' (time of real-world user authentication) of %s is %s in the future (server time %s)",
					authTimeStr, formatDuration(authTime.Sub(params.Now)), params.Now.Format(timeLayout)))
			case params.MaxAge > 0 && age > params.MaxAge:
				errs = append(errs, fmt.Sprintf(
					"'auth_time' (time of real-world user authentication) of %s is %s old, which exceeds the max age of %s (%ds) by %s",
					authTimeStr, formatDuration(age), formatDuration(params.MaxAge), params.MaxAge/time.Second, formatDuration(age-params.MaxAge)))
			}
		}
	}

	// jti: optional. It must match when expected; otherwise it must be
	// explicitly ignored.
	if params.Jti != claims.Jti {
		if len(params.Jti) > 0 {
			errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", claims.Jti, params.Jti))
		} else if !params.IgnoreJti {
			errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", claims.Jti))
		}
	}

	// nonce: optional. It must match when expected; otherwise it must be
	// explicitly ignored.
	if params.Nonce != claims.Nonce {
		if len(params.Nonce) > 0 {
			errs = append(errs, fmt.Sprintf("'nonce' (one-time random salt, as string) mismatch: got %s, expected %s", claims.Nonce, params.Nonce))
		} else if !params.IgnoreNonce {
			errs = append(errs, fmt.Sprintf("unchecked 'nonce' (one-time random salt): %s", claims.Nonce))
		}
	}

	// Acr check not implemented because the use case is not yet clear.

	// amr: should exist; every required method must be present.
	if !params.IgnoreAmr {
		if len(claims.Amr) == 0 {
			errs = append(errs, "missing or malformed 'amr' (authorization methods, as json list)")
		} else {
			for _, required := range params.RequiredAmrs {
				if !slices.Contains(claims.Amr, required) {
					errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr' (authorization methods, as json list)", required))
				}
			}
			// TODO: support tiered acceptable amrs (at least one from each list).
		}
	}

	// azp: optional. It must match when expected; otherwise it must be
	// explicitly ignored.
	if params.Azp != claims.Azp {
		if len(params.Azp) > 0 {
			if len(claims.Azp) == 0 {
				errs = append(errs, "missing or malformed 'azp' (authorized party which presents token)")
			} else {
				errs = append(errs, fmt.Sprintf("'azp' mismatch (authorized party which presents token): got %s, expected %s", claims.Azp, params.Azp))
			}
		} else if !params.IgnoreAzp {
			errs = append(errs, fmt.Sprintf("unchecked 'azp' (authorized party which presents token): %s", claims.Azp))
		}
	}

	// Signature: must have been checked beforehand.
	if !params.IgnoreSig && !jws.Verified {
		errs = append(errs, "signature was not checked")
	}

	if len(errs) > 0 {
		// time.Local prints as its location name ("Local").
		errs = append(errs, fmt.Sprintf("info: server time is %s %s", params.Now.Format(timeLayout), time.Local))
		return errs, fmt.Errorf("has errors")
	}
	return nil, nil
}
// SignJWS signs the JWS signing input header + "." + payload with key and
// returns the raw JWS signature. Both header and payload are expected to be
// base64url segments already.
//
// The input is hashed with SHA-256. The result is the fixed-width
// concatenation R || S required by RFC 7518 §3.4. Each integer is left-padded
// to the curve's byte size, so the result is 32+32 = 64 bytes for P-256.
// big.Int.Bytes drops leading zero bytes, so without the padding about 1 in
// 128 signatures would be too short to parse or verify.
func SignJWS(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) {
	if key == nil || key.Curve == nil {
		return nil, fmt.Errorf("failed to sign: missing private key")
	}
	hash := sha256.Sum256([]byte(header + "." + payload))
	r, s, err := ecdsa.Sign(rand.Reader, key, hash[:])
	if err != nil {
		return nil, fmt.Errorf("failed to sign: %w", err)
	}
	size := (key.Curve.Params().BitSize + 7) / 8
	sig := make([]byte, 2*size)
	r.FillBytes(sig[:size])
	s.FillBytes(sig[size:])
	return sig, nil
}
// Encode returns jws in compact serialization. The stored Protected and
// Payload segments are used verbatim, followed by the unpadded base64url
// signature, all joined by ".".
func (jws JWS) Encode() string {
	segments := []string{
		jws.Protected,
		jws.Payload,
		base64.RawURLEncoding.EncodeToString(jws.Signature),
	}
	return strings.Join(segments, ".")
}
// EncodeToJWT completes a compact JWT. It appends "." and the unpadded
// base64url encoding of signature to signingInput, which is expected to
// already be "<header>.<payload>".
func EncodeToJWT(signingInput string, signature []byte) string {
	return strings.Join([]string{signingInput, base64.RawURLEncoding.EncodeToString(signature)}, ".")
}
// Thumbprint returns the RFC 7638 thumbprint of jwk. This is the unpadded
// base64url SHA-256 digest of a JSON object containing only the required EC
// members "crv", "kty", "x" and "y"; the private "d" is never included.
// encoding/json writes map keys in sorted order with no whitespace, which
// gives the canonical member ordering the thumbprint requires.
func (jwk JWK) Thumbprint() (string, error) {
	canonical, err := json.Marshal(map[string]string{
		"kty": jwk.Kty,
		"crv": jwk.Crv,
		"y":   jwk.Y,
		"x":   jwk.X,
	})
	if err != nil {
		return "", err
	}
	digest := sha256.Sum256(canonical)
	return base64.RawURLEncoding.EncodeToString(digest[:]), nil
}
// URLBase64 unmarshals to bytes and marshals to a raw url base64 string
type URLBase64 []byte
func (s URLBase64) String() string {
encoded := base64.RawURLEncoding.EncodeToString(s)
return encoded
}
// MarshalJSON implements JSON marshaling to URL-safe base64.
func (s URLBase64) MarshalJSON() ([]byte, error) {
encoded := base64.RawURLEncoding.EncodeToString(s)
return json.Marshal(encoded)
}
// UnmarshalJSON implements JSON unmarshaling from URL-safe base64.
func (s *URLBase64) UnmarshalJSON(data []byte) error {
dst, err := base64.RawURLEncoding.AppendDecode([]byte{}, data)
if err != nil {
return fmt.Errorf("decode base64url signature: %w", err)
}
*s = dst
return nil
}
// formatDuration renders d as a compact, human-readable span such as
// "1d 2h 3m 4s", for use in validation messages.
//
// The sign is dropped, so negative durations are shown by magnitude, and
// zero-valued units are omitted. Anything below one second is truncated,
// unless the whole span is under a second. In that case it is shown in
// milliseconds instead, e.g. "450ms", or "0ms" for zero. This avoids both a
// bare "0s" and odd mixes like "1h 0ms".
func formatDuration(d time.Duration) string {
	if d < 0 {
		d = -d
	}
	const day = 24 * time.Hour
	days := int(d / day)
	d %= day
	hours := int(d / time.Hour)
	d %= time.Hour
	minutes := int(d / time.Minute)
	d %= time.Minute
	seconds := int(d / time.Second)
	d %= time.Second

	var parts []string
	if days > 0 {
		parts = append(parts, fmt.Sprintf("%dd", days))
	}
	if hours > 0 {
		parts = append(parts, fmt.Sprintf("%dh", hours))
	}
	if minutes > 0 {
		parts = append(parts, fmt.Sprintf("%dm", minutes))
	}
	if seconds > 0 {
		parts = append(parts, fmt.Sprintf("%ds", seconds))
	}
	if len(parts) == 0 {
		// The whole span is under one second, and d now holds all of it.
		parts = append(parts, fmt.Sprintf("%dms", int(d/time.Millisecond)))
	}
	return strings.Join(parts, " ")
}