// Mirror of https://github.com/therootcompany/golib.git
// (synced 2025-10-08 10:08:17 +00:00; 484 lines, 14 KiB, Go).
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
//
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package jwt
|
|
|
|
import (
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math/big"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Keypair pairs an ECDSA private key with a thumbprint string for it.
type Keypair struct {
	// Thumbprint is the identifier kept alongside the key.
	// NOTE(review): presumably the value produced by JWK.Thumbprint — confirm against callers.
	Thumbprint string
	// PrivateKey is the ECDSA private key (its embedded PublicKey is the public half).
	PrivateKey *ecdsa.PrivateKey
}
|
|
|
|
// JWK holds the JSON members of an elliptic-curve JSON Web Key. D, X and Y
// are kept in their encoded string form; UnmarshalJWK decodes them as
// unpadded base64url.
type JWK struct {
	// Kty is the "kty" member; it is part of the Thumbprint input.
	Kty string `json:"kty"`
	// Crv is the "crv" member; it is part of the Thumbprint input.
	Crv string `json:"crv"`
	// D is the encoded private scalar.
	D string `json:"d"`
	// X is the encoded x coordinate of the public point.
	X string `json:"x"`
	// Y is the encoded y coordinate of the public point.
	Y string `json:"y"`
}
|
|
|
|
// JWT is a token in compact serialization: three dot-separated segments
// (header, payload, signature).
type JWT string

// Split returns the three raw (still encoded) segments of the token in
// order: header, payload, signature. It reports an error when the token
// does not consist of exactly three dot-separated segments; the segments
// themselves are not inspected.
func (jwt JWT) Split() (string, string, string, error) {
	segments := strings.Split(string(jwt), ".")
	if len(segments) == 3 {
		return segments[0], segments[1], segments[2], nil
	}
	return "", "", "", fmt.Errorf("invalid JWT format")
}
|
|
|
|
func (jwt JWT) Decode() (JWS, error) {
|
|
h64, p64, s64, err := jwt.Split()
|
|
if err != nil {
|
|
return JWS{}, err
|
|
}
|
|
|
|
var jws JWS
|
|
var sigEnc string
|
|
jws.Protected, jws.Payload, sigEnc = h64, p64, s64
|
|
|
|
header, err := base64.RawURLEncoding.DecodeString(jws.Protected)
|
|
if err != nil {
|
|
return jws, fmt.Errorf("invalid header encoding: %v", err)
|
|
}
|
|
if err := json.Unmarshal(header, &jws.Header); err != nil {
|
|
return jws, fmt.Errorf("invalid header JSON: %v", err)
|
|
}
|
|
|
|
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
|
if err != nil {
|
|
return jws, fmt.Errorf("invalid claims encoding: %v", err)
|
|
}
|
|
if err := json.Unmarshal(payload, &jws.Claims); err != nil {
|
|
return jws, fmt.Errorf("invalid claims JSON: %v", err)
|
|
}
|
|
|
|
if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil {
|
|
return jws, fmt.Errorf("invalid signature encoding: %v", err)
|
|
}
|
|
|
|
return jws, nil
|
|
}
|
|
|
|
// JWS is a decoded (or to-be-encoded) token: the parsed header and claims
// together with the encoded segments they came from and the raw signature.
type JWS struct {
	// Protected is the encoded header segment; part of the signing input.
	Protected string `json:"-"` // base64
	// Header is the parsed form of Protected.
	Header MyHeader `json:"headers"`
	// Payload is the encoded claims segment; part of the signing input.
	Payload string `json:"-"` // base64
	// Claims is the parsed form of Payload.
	Claims MyClaims `json:"claims"`
	// Signature is the decoded signature bytes.
	Signature URLBase64 `json:"signature"`
	// Verified records the result of the last UnsafeVerify call; Validate
	// requires it to be true unless IgnoreSig is set.
	Verified bool `json:"-"`
}
|
|
|
|
// MyHeader is the token header used by JWS. It currently carries only the
// embedded StandardHeader fields.
type MyHeader struct {
	StandardHeader
}
|
|
|
|
// StandardHeader holds the "alg", "kid" and "typ" header members.
type StandardHeader struct {
	// Alg is the signature algorithm name; NewJWS sets it to "ES256".
	Alg string `json:"alg"`
	// Kid is the key identifier; NewJWS sets it to the key thumbprint.
	Kid string `json:"kid"`
	// Typ is the token type; NewJWS sets it to "JWT".
	Typ string `json:"typ"`
}
|
|
|
|
// MyClaims is the claim set used by JWS: the embedded StandardClaims plus
// application-specific profile fields.
type MyClaims struct {
	StandardClaims
	// Email is the "email" claim.
	Email string `json:"email"`
	// EmployeeID is the "employee_id" claim.
	EmployeeID string `json:"employee_id"`
	// FamilyName is the "family_name" claim.
	FamilyName string `json:"family_name"`
	// GivenName is the "given_name" claim.
	GivenName string `json:"given_name"`
	// Roles is the "roles" claim.
	Roles []string `json:"roles"`
}
|
|
|
|
// StandardClaims holds the claims that JWS.Validate checks. Time values are
// seconds since the Unix epoch.
type StandardClaims struct {
	// Iss is the token issuer.
	Iss string `json:"iss"`
	// Sub is the subject (typically a pairwise user id).
	Sub string `json:"sub"`
	// Aud is the audience receiving the token.
	// NOTE(review): declared as a single string, so a token whose "aud" is a
	// JSON array presumably fails to decode — confirm whether that matters.
	Aud string `json:"aud"`
	// Exp is the expiration time.
	Exp int64 `json:"exp"`
	// Iat is the time the token was issued (signed).
	Iat int64 `json:"iat"`
	// AuthTime is the time of real-world user authentication.
	AuthTime int64 `json:"auth_time"`
	// Nonce is a one-time random value, omitted from JSON when empty.
	Nonce string `json:"nonce,omitempty"`
	// Amr lists the authentication methods used.
	Amr []string `json:"amr"`
	// Azp is the authorized party presenting the token, omitted when empty.
	Azp string `json:"azp,omitempty"`
	// Jti is the token id.
	Jti string `json:"jti"`
}
|
|
|
|
func UnmarshalJWK(jwk JWK) (*ecdsa.PrivateKey, error) {
|
|
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid JWK X: %v", err)
|
|
}
|
|
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid JWK Y: %v", err)
|
|
}
|
|
d, err := base64.RawURLEncoding.DecodeString(jwk.D)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid JWK D: %v", err)
|
|
}
|
|
|
|
return &ecdsa.PrivateKey{
|
|
PublicKey: ecdsa.PublicKey{
|
|
Curve: elliptic.P256(),
|
|
X: new(big.Int).SetBytes(x),
|
|
Y: new(big.Int).SetBytes(y),
|
|
},
|
|
D: new(big.Int).SetBytes(d),
|
|
}, nil
|
|
}
|
|
|
|
func NewJWS(email, employeeID, issuer, thumbprint string, roles []string) (JWS, error) {
|
|
var jws JWS
|
|
|
|
jws.Header.StandardHeader = StandardHeader{
|
|
Alg: "ES256",
|
|
Kid: thumbprint,
|
|
Typ: "JWT",
|
|
}
|
|
headerJSON, _ := json.Marshal(jws.Header)
|
|
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
|
|
|
|
now := time.Now().Unix()
|
|
jtiBytes := make([]byte, 16)
|
|
if _, err := rand.Read(jtiBytes); err != nil {
|
|
return JWS{}, fmt.Errorf("failed to generate Jti: %v", err)
|
|
}
|
|
jti := base64.RawURLEncoding.EncodeToString(jtiBytes)
|
|
emailName := strings.Split(email, "@")[0]
|
|
|
|
jws.Claims = MyClaims{
|
|
StandardClaims: StandardClaims{
|
|
AuthTime: now,
|
|
Exp: now + 15*60*37, // TODO remove
|
|
Iat: now,
|
|
Iss: issuer,
|
|
Jti: jti,
|
|
Sub: email,
|
|
Amr: []string{"pwd"},
|
|
},
|
|
Email: email,
|
|
EmployeeID: employeeID,
|
|
FamilyName: "McTestface",
|
|
GivenName: strings.ToUpper(emailName),
|
|
Roles: roles,
|
|
}
|
|
claimsJSON, _ := json.Marshal(jws.Claims)
|
|
jws.Payload = base64.RawURLEncoding.EncodeToString(claimsJSON)
|
|
|
|
return jws, nil
|
|
}
|
|
|
|
func (jws *JWS) Sign(key *ecdsa.PrivateKey) ([]byte, error) {
|
|
var err error
|
|
jws.Signature, err = SignJWS(jws.Protected, jws.Payload, key)
|
|
return jws.Signature, err
|
|
}
|
|
|
|
// UnsafeVerify only checks the signature, use Validate to check all values
|
|
func (jws *JWS) UnsafeVerify(pub *ecdsa.PublicKey) bool {
|
|
hash := sha256.Sum256([]byte(jws.Protected + "." + jws.Payload))
|
|
n := len(jws.Signature)
|
|
if n != 64 {
|
|
// return fmt.Errorf("expected a 64-byte signature consisting of two 32-byte r and s components, but got %d instead (perhaps ASN.1 or other format)", n)
|
|
return false
|
|
}
|
|
|
|
r := new(big.Int).SetBytes(jws.Signature[:32])
|
|
s := new(big.Int).SetBytes(jws.Signature[32:])
|
|
|
|
jws.Verified = ecdsa.Verify(pub, hash[:], r, s)
|
|
return jws.Verified
|
|
}
|
|
|
|
// ValidateParams holds validation configuration for JWS.Validate.
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
//
// Each check is on by default; the Ignore* flags switch individual checks
// off. For several claims, supplying an expected value forces the check even
// when the matching Ignore* flag is set.
type ValidateParams struct {
	// Now is the reference time for all time checks; the zero value means
	// time.Now() at the moment Validate is called.
	Now time.Time
	// IgnoreIss skips the 'iss' check, unless Iss is non-empty.
	IgnoreIss bool
	// Iss is the expected 'iss' claim.
	Iss string
	// IgnoreSub suppresses the error for a missing 'sub' claim.
	IgnoreSub bool
	// Sub, when non-empty, must equal a present 'sub' claim.
	Sub string
	// IgnoreAud skips the 'aud' check, unless Aud is non-empty.
	IgnoreAud bool
	// Aud is the expected 'aud' claim.
	Aud string
	// IgnoreExp skips the 'exp' check.
	IgnoreExp bool
	// IgnoreJti suppresses the "unchecked 'jti'" error reported when Jti is
	// empty but the token carries a 'jti' claim.
	IgnoreJti bool
	// Jti, when non-empty, is the expected 'jti' claim.
	Jti string
	// IgnoreIat skips the 'iat' check.
	IgnoreIat bool
	// IgnoreAuthTime skips the 'auth_time' check, unless MaxAge is positive.
	IgnoreAuthTime bool
	// MaxAge is the maximum allowed age of 'auth_time' relative to Now.
	MaxAge time.Duration
	// IgnoreNonce suppresses the "unchecked 'nonce'" error reported when
	// Nonce is empty but the token carries a 'nonce' claim.
	IgnoreNonce bool
	// Nonce, when non-empty, is the expected 'nonce' claim.
	Nonce string
	// IgnoreAmr skips the 'amr' check entirely.
	IgnoreAmr bool
	// RequiredAmrs lists values that must each appear in the 'amr' claim.
	RequiredAmrs []string
	// IgnoreAzp suppresses the error reported when Azp is empty but the token
	// carries an 'azp' claim.
	IgnoreAzp bool
	// Azp, when non-empty, is the expected 'azp' claim.
	Azp string
	// IgnoreSig skips the requirement that the JWS is marked Verified.
	IgnoreSig bool
}
|
|
|
|
// Validate checks common JWS fields and issuer, collecting all errors.
|
|
func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
|
|
var errs []string
|
|
|
|
if params.Now.IsZero() {
|
|
params.Now = time.Now()
|
|
}
|
|
|
|
// Required to exist and match
|
|
if len(params.Iss) > 0 || !params.IgnoreIss {
|
|
if len(jws.Claims.Iss) == 0 {
|
|
errs = append(errs, ("missing or malformed 'iss' (token issuer, identifier for public key)"))
|
|
} else if jws.Claims.Iss != params.Iss {
|
|
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", jws.Claims.Iss, params.Iss))
|
|
}
|
|
}
|
|
|
|
// Required to exist, optional match
|
|
if len(jws.Claims.Sub) == 0 {
|
|
if !params.IgnoreSub {
|
|
errs = append(errs, ("missing or malformed 'sub' (subject, typically pairwise user id)"))
|
|
}
|
|
} else if len(params.Sub) > 0 {
|
|
if params.Sub != jws.Claims.Sub {
|
|
errs = append(errs, fmt.Sprintf("'sub' (subject, typically pairwise user id) mismatch: got %s, expected %s", jws.Claims.Sub, params.Sub))
|
|
}
|
|
}
|
|
|
|
// Required to exist and match
|
|
if len(params.Aud) > 0 || !params.IgnoreAud {
|
|
if len(jws.Claims.Aud) == 0 {
|
|
errs = append(errs, ("missing or malformed 'aud' (audience receiving token)"))
|
|
} else if jws.Claims.Aud != params.Aud {
|
|
errs = append(errs, fmt.Sprintf("'aud' (audience receiving token) mismatch: got %s, expected %s", jws.Claims.Aud, params.Aud))
|
|
}
|
|
}
|
|
|
|
// Required to exist and not be in the past
|
|
if !params.IgnoreExp {
|
|
if jws.Claims.Exp <= 0 {
|
|
errs = append(errs, ("missing or malformed 'exp' (expiration date in seconds)"))
|
|
} else if jws.Claims.Exp < params.Now.Unix() {
|
|
duration := time.Since(time.Unix(jws.Claims.Exp, 0))
|
|
expTime := time.Unix(jws.Claims.Exp, 0).Format("2006-01-02 15:04:05 MST")
|
|
errs = append(errs, fmt.Sprintf("token expired %s ago (%s)", formatDuration(duration), expTime))
|
|
}
|
|
}
|
|
|
|
// Required to exist and not be in the future
|
|
if !params.IgnoreIat {
|
|
if jws.Claims.Iat <= 0 {
|
|
errs = append(errs, ("missing or malformed 'iat' (issued at, when token was signed)"))
|
|
} else if jws.Claims.Iat > params.Now.Unix() {
|
|
duration := time.Unix(jws.Claims.Iat, 0).Sub(params.Now)
|
|
iatTime := time.Unix(jws.Claims.Iat, 0).Format("2006-01-02 15:04:05 MST")
|
|
errs = append(errs, fmt.Sprintf("'iat' (issued at, when token was signed) is %s in the future (%s)", formatDuration(duration), iatTime))
|
|
}
|
|
}
|
|
|
|
// Should exist, in the past, with optional max age
|
|
if params.MaxAge > 0 || !params.IgnoreAuthTime {
|
|
if jws.Claims.AuthTime == 0 {
|
|
errs = append(errs, ("missing or malformed 'auth_time' (time of real-world user authentication, in seconds)"))
|
|
} else {
|
|
authTime := time.Unix(jws.Claims.AuthTime, 0)
|
|
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
|
|
age := params.Now.Sub(authTime)
|
|
diff := age - params.MaxAge
|
|
if jws.Claims.AuthTime > params.Now.Unix() {
|
|
fromNow := time.Unix(jws.Claims.AuthTime, 0).Sub(params.Now)
|
|
authTimeStr := time.Unix(jws.Claims.AuthTime, 0).Format("2006-01-02 15:04:05 MST")
|
|
errs = append(errs, fmt.Sprintf(
|
|
"'auth_time' (time of real-world user authentication) of %s is %s in the future (server time %s)",
|
|
authTimeStr, formatDuration(fromNow), params.Now.Format("2006-01-02 15:04:05 MST")),
|
|
)
|
|
} else if age > params.MaxAge {
|
|
errs = append(errs, fmt.Sprintf(
|
|
"'auth_time' (time of real-world user authentication) of %s is %s old, which exceeds the max age of %s (%ds) by %s",
|
|
authTimeStr, formatDuration(age), formatDuration(params.MaxAge), params.MaxAge/time.Second, formatDuration(diff)),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Optional
|
|
if params.Jti != jws.Claims.Jti {
|
|
if len(params.Jti) > 0 {
|
|
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", jws.Claims.Jti, params.Jti))
|
|
} else if !params.IgnoreJti {
|
|
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", jws.Claims.Jti))
|
|
}
|
|
}
|
|
|
|
// Optional
|
|
if params.Nonce != jws.Claims.Nonce {
|
|
if len(params.Nonce) > 0 {
|
|
errs = append(errs, fmt.Sprintf("'nonce' (one-time random salt, as string) mismatch: got %s, expected %s", jws.Claims.Nonce, params.Nonce))
|
|
} else if !params.IgnoreNonce {
|
|
errs = append(errs, fmt.Sprintf("unchecked 'nonce' (one-time random salt): %s", jws.Claims.Nonce))
|
|
}
|
|
}
|
|
|
|
// Acr check not implemented because the use case is not yet clear
|
|
|
|
// Should exist, optional match
|
|
if !params.IgnoreAmr {
|
|
if len(jws.Claims.Amr) == 0 {
|
|
errs = append(errs, ("missing or malformed 'amr' (authorization methods, as json list)"))
|
|
} else {
|
|
if len(params.RequiredAmrs) > 0 {
|
|
for _, required := range params.RequiredAmrs {
|
|
if !slices.Contains(jws.Claims.Amr, required) {
|
|
errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr' (authorization methods, as json list)", required))
|
|
}
|
|
}
|
|
}
|
|
|
|
// TODO specify multiple amrs in a tiered list (must have at least one from each list)
|
|
// count := 0
|
|
// if len(params.AcceptableAmrs) > 0 {
|
|
// for _, amr := range jws.Claims.Amr {
|
|
// if slices.Contains(params.AcceptableAmrs, amr) {
|
|
// count += 1
|
|
// }
|
|
// }
|
|
// }
|
|
}
|
|
}
|
|
|
|
// Optional, should match if exists
|
|
if params.Azp != jws.Claims.Azp {
|
|
if len(params.Azp) > 0 {
|
|
errs = append(errs, ("missing or malformed 'azp' (authorized party which presents token)"))
|
|
} else if !params.IgnoreAzp {
|
|
errs = append(errs, fmt.Sprintf("'azp' mismatch (authorized party which presents token): got %s, expected %s", jws.Claims.Azp, params.Azp))
|
|
}
|
|
}
|
|
|
|
// Must be checked
|
|
if !params.IgnoreSig {
|
|
if !jws.Verified {
|
|
errs = append(errs, ("signature was not checked"))
|
|
}
|
|
}
|
|
|
|
if len(errs) > 0 {
|
|
timeInfo := fmt.Sprintf("info: server time is %s", params.Now.Format("2006-01-02 15:04:05 MST"))
|
|
if loc, err := time.LoadLocation("Local"); err == nil {
|
|
timeInfo += fmt.Sprintf(" %s", loc)
|
|
}
|
|
errs = append(errs, timeInfo)
|
|
return errs, fmt.Errorf("has errors")
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// SignJWS signs header + "." + payload (both already base64url-encoded) with
// key and returns the raw signature: r followed by s, each as a fixed-width
// big-endian integer as wide as the curve order (32 bytes each for P-256,
// 64 bytes total).
//
// The message is hashed with SHA-256. Each integer is left-padded with
// zeros to the fixed width; without that padding a value with leading zero
// bytes would shorten the signature and verifiers expecting exactly 64 bytes
// would reject it.
//
// An error is returned when key is nil or has no curve, or when signing
// fails; the signing error is wrapped.
func SignJWS(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) {
	if key == nil || key.Curve == nil {
		return nil, fmt.Errorf("failed to sign: missing private key")
	}

	hash := sha256.Sum256([]byte(header + "." + payload))
	r, s, err := ecdsa.Sign(rand.Reader, key, hash[:])
	if err != nil {
		return nil, fmt.Errorf("failed to sign: %w", err)
	}

	// r and s are both below the curve order, so they always fit the width.
	size := (key.Curve.Params().BitSize + 7) / 8
	sig := make([]byte, 2*size)
	r.FillBytes(sig[:size])
	s.FillBytes(sig[size:])
	return sig, nil
}
|
|
|
|
func (jws JWS) Encode() string {
|
|
sigEnc := base64.RawURLEncoding.EncodeToString(jws.Signature)
|
|
return jws.Protected + "." + jws.Payload + "." + sigEnc
|
|
}
|
|
|
|
// EncodeToJWT appends signature, encoded as unpadded base64url, to
// signingInput with a "." separator and returns the result. signingInput is
// used verbatim.
func EncodeToJWT(signingInput string, signature []byte) string {
	var token strings.Builder
	token.WriteString(signingInput)
	token.WriteByte('.')
	token.WriteString(base64.RawURLEncoding.EncodeToString(signature))
	return token.String()
}
|
|
|
|
func (jwk JWK) Thumbprint() (string, error) {
|
|
data := map[string]string{
|
|
"crv": jwk.Crv,
|
|
"kty": jwk.Kty,
|
|
"x": jwk.X,
|
|
"y": jwk.Y,
|
|
}
|
|
jsonData, err := json.Marshal(data)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
hash := sha256.Sum256(jsonData)
|
|
return base64.RawURLEncoding.EncodeToString(hash[:]), nil
|
|
}
|
|
|
|
// URLBase64 unmarshals to bytes and marshals to a raw url base64 string
|
|
type URLBase64 []byte
|
|
|
|
func (s URLBase64) String() string {
|
|
encoded := base64.RawURLEncoding.EncodeToString(s)
|
|
return encoded
|
|
}
|
|
|
|
// MarshalJSON implements JSON marshaling to URL-safe base64.
|
|
func (s URLBase64) MarshalJSON() ([]byte, error) {
|
|
encoded := base64.RawURLEncoding.EncodeToString(s)
|
|
return json.Marshal(encoded)
|
|
}
|
|
|
|
// UnmarshalJSON implements JSON unmarshaling from URL-safe base64.
|
|
func (s *URLBase64) UnmarshalJSON(data []byte) error {
|
|
dst, err := base64.RawURLEncoding.AppendDecode([]byte{}, data)
|
|
if err != nil {
|
|
return fmt.Errorf("decode base64url signature: %w", err)
|
|
}
|
|
|
|
*s = dst
|
|
return nil
|
|
}
|
|
|
|
// formatDuration renders the magnitude of d as space-separated day, hour,
// minute and second components, omitting components that are zero, e.g.
// "1d 2h", "1m 30s", "5m".
//
// The sign of d is dropped. Durations shorter than one second are rendered
// in milliseconds ("250ms"); anything that rounds down to zero milliseconds
// is rendered as "0s". Sub-second remainders of longer durations are
// discarded.
func formatDuration(d time.Duration) string {
	if d < 0 {
		d = -d
	}

	units := []struct {
		size   time.Duration
		suffix string
	}{
		{24 * time.Hour, "d"},
		{time.Hour, "h"},
		{time.Minute, "m"},
		{time.Second, "s"},
	}

	var parts []string
	for _, u := range units {
		if n := d / u.size; n > 0 {
			parts = append(parts, fmt.Sprintf("%d%s", int64(n), u.suffix))
			d -= n * u.size
		}
	}

	// Milliseconds are shown only when no larger unit was printed;
	// otherwise they would be noise such as "5m 0ms".
	if len(parts) == 0 {
		if millis := int64(d / time.Millisecond); millis > 0 {
			return fmt.Sprintf("%dms", millis)
		}
		return "0s"
	}

	return strings.Join(parts, " ")
}
|