add sign subcommand
This commit is contained in:
parent
a2aa6b5411
commit
22ba73fa12
|
@ -7,54 +7,192 @@ import (
|
|||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.rootprojects.org/root/keypairs"
|
||||
)
|
||||
|
||||
var (
|
||||
name = "keypairs"
|
||||
version = "0.0.0"
|
||||
date = "0001-01-01T00:00:00Z"
|
||||
commit = "0000000"
|
||||
)
|
||||
|
||||
func usage() {
|
||||
ver()
|
||||
fmt.Println("Usage")
|
||||
fmt.Printf(" %s <command> [flags] args...\n", name)
|
||||
fmt.Println("")
|
||||
fmt.Printf("See usage: %s help <command>\n", name)
|
||||
fmt.Println("")
|
||||
fmt.Println("Commands:")
|
||||
fmt.Println(" version")
|
||||
fmt.Println(" gen")
|
||||
fmt.Println(" sign")
|
||||
fmt.Println("")
|
||||
fmt.Println("Examples:")
|
||||
fmt.Println(" keypairs gen -o key.jwk.json [--pub <public-key>]")
|
||||
fmt.Println(" keypairs sign --exp 15m key.jwk.json payload.json")
|
||||
fmt.Println(" keypairs sign --exp 15m key.jwk.json '{ \"sub\": \"xxxx\" }'")
|
||||
fmt.Println("")
|
||||
//fmt.Println(" verify")
|
||||
}
|
||||
|
||||
func ver() {
|
||||
fmt.Printf("%s v%s %s (%s)\n", name, version, commit[:7], date)
|
||||
}
|
||||
|
||||
func main() {
|
||||
if 1 == len(os.Args) || "gen" != os.Args[1] {
|
||||
fmt.Fprintln(os.Stderr, "Usage: keypairs gen -o <filename> [--pub <filename>]")
|
||||
args := os.Args[:]
|
||||
|
||||
if "help" == args[1] {
|
||||
// top-level help
|
||||
if 2 == len(args) {
|
||||
usage()
|
||||
os.Exit(0)
|
||||
return
|
||||
}
|
||||
// move help to subcommand argument
|
||||
self := args[0]
|
||||
args = append([]string{self}, args[2:]...)
|
||||
args = append(args, "--help")
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "version":
|
||||
ver()
|
||||
os.Exit(0)
|
||||
return
|
||||
case "gen":
|
||||
gen(args)
|
||||
case "sign":
|
||||
sign(args)
|
||||
default:
|
||||
usage()
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func gen(args []string) {
|
||||
var keyname string
|
||||
var pubname string
|
||||
flags := flag.NewFlagSet("gen", flag.ExitOnError)
|
||||
flags.StringVar(&keyname, "o", "", "private key file (ex: key.jwk.json or key.pem)")
|
||||
flags.StringVar(&pubname, "pub", "", "public key file (ex: pub.jwk.json or pub.pem)")
|
||||
flags.Parse(args)
|
||||
|
||||
key := keypairs.NewDefaultPrivateKey()
|
||||
marshalPriv(key, keyname)
|
||||
pub := keypairs.NewPublicKey(key.Public())
|
||||
marshalPub(pub, pubname)
|
||||
}
|
||||
|
||||
func sign(args []string) {
|
||||
var exp time.Duration
|
||||
flags := flag.NewFlagSet("sign", flag.ExitOnError)
|
||||
flags.DurationVar(&exp, "exp", 0, "duration until token expires (Default 15m)")
|
||||
flags.Parse(args)
|
||||
if len(flags.Args()) <= 3 {
|
||||
fmt.Fprintf(os.Stderr, "Usage: keypairs sign --exp 1h <private PEM or JWK> ./payload.json\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
keyname := flags.Args()[2]
|
||||
payload := flags.Args()[3]
|
||||
|
||||
var key keypairs.PrivateKey = nil
|
||||
b, err := ioutil.ReadFile(keyname)
|
||||
if nil != err {
|
||||
var err2 error
|
||||
key, err2 = keypairs.ParsePrivateKey([]byte(keyname))
|
||||
if nil != err2 {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"could not read private key as file (or parse as string) %q: %s\n", keyname, err)
|
||||
}
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
if nil == key {
|
||||
var err3 error
|
||||
key, err3 = keypairs.ParsePrivateKey(b)
|
||||
if nil != err3 {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"could not parse private key from file %q: %s\n", keyname, err3)
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if "" == payload {
|
||||
payload = "{}"
|
||||
}
|
||||
|
||||
b, err = ioutil.ReadFile(payload)
|
||||
claims := map[string]interface{}{}
|
||||
if nil != err {
|
||||
var err2 error
|
||||
err2 = json.Unmarshal([]byte(payload), &claims)
|
||||
if nil != err2 {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"could not read payload as file (or parse as string) %q: %s\n", payload, err)
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
}
|
||||
if 0 == len(claims) {
|
||||
var err3 error
|
||||
err3 = json.Unmarshal(b, &claims)
|
||||
if nil != err3 {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"could not parse palyoad from file %q: %s\n", payload, err3)
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if 0 != exp {
|
||||
claims["exp"] = exp.Seconds()
|
||||
}
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
claims["exp"] = (15 * time.Minute).Seconds()
|
||||
}
|
||||
|
||||
jws, err := keypairs.SignClaims(key, nil, claims)
|
||||
if nil != err {
|
||||
fmt.Fprintf(os.Stderr, "could not sign claims: %v\n%#v\n", err, claims)
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
// gen subcommand
|
||||
args := os.Args[2:]
|
||||
|
||||
var privname string
|
||||
var pubname string
|
||||
flags := flag.NewFlagSet("gen", flag.ExitOnError)
|
||||
flags.StringVar(&privname, "o", "", "private key file (should have .jwk.json or pkcs8.pem extension)")
|
||||
flags.StringVar(&pubname, "pub", "", "public key file (should have .jwk.json or spki.pem extension)")
|
||||
flags.Parse(args)
|
||||
|
||||
priv := keypairs.NewDefaultPrivateKey()
|
||||
marshalPriv(priv, privname)
|
||||
marshalPub(keypairs.NewPublicKey(priv.Public()), pubname)
|
||||
b, _ = json.Marshal(&jws)
|
||||
fmt.Printf("JWS:\n%s\n\n", indentJSON(b))
|
||||
fmt.Printf("JWT:\n%s\n\n", keypairs.JWSToJWT(jws))
|
||||
}
|
||||
|
||||
func marshalPriv(priv keypairs.PrivateKey, privname string) {
|
||||
if "" == privname {
|
||||
b := indentJSON(keypairs.MarshalJWKPrivateKey(priv))
|
||||
func marshalPriv(key keypairs.PrivateKey, keyname string) {
|
||||
if "" == keyname {
|
||||
b := indentJSON(keypairs.MarshalJWKPrivateKey(key))
|
||||
|
||||
fmt.Fprintf(os.Stdout, string(b)+"\n")
|
||||
return
|
||||
}
|
||||
|
||||
var b []byte
|
||||
if strings.HasSuffix(privname, ".json") {
|
||||
b = indentJSON(keypairs.MarshalJWKPrivateKey(priv))
|
||||
} else if strings.HasSuffix(privname, ".pem") {
|
||||
b, _ = keypairs.MarshalPEMPrivateKey(priv)
|
||||
} else if strings.HasSuffix(privname, ".der") {
|
||||
b, _ = keypairs.MarshalDERPrivateKey(priv)
|
||||
if strings.HasSuffix(keyname, ".json") {
|
||||
b = indentJSON(keypairs.MarshalJWKPrivateKey(key))
|
||||
} else if strings.HasSuffix(keyname, ".pem") {
|
||||
b, _ = keypairs.MarshalPEMPrivateKey(key)
|
||||
} else if strings.HasSuffix(keyname, ".der") {
|
||||
b, _ = keypairs.MarshalDERPrivateKey(key)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "private key extension should be .jwk.json, .pem, or .der")
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
ioutil.WriteFile(privname, b, 0600)
|
||||
ioutil.WriteFile(keyname, b, 0600)
|
||||
}
|
||||
|
||||
func marshalPub(pub keypairs.PublicKey, pubname string) {
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
package keypairs
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
mathrand "math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// randReader may be overwritten for testing
|
||||
//var randReader io.Reader = rand.Reader
|
||||
|
||||
//var randReader = rand.Reader
|
||||
|
||||
// JWS is a parsed JWT, representation as signable/verifiable and human-readable parts
|
||||
type JWS struct {
|
||||
Header Object `json:"header"` // JSON
|
||||
Claims Object `json:"claims"` // JSON
|
||||
Protected string `json:"protected"` // base64
|
||||
Payload string `json:"payload"` // base64
|
||||
Signature string `json:"signature"` // base64
|
||||
}
|
||||
|
||||
// Object is a type alias representing generic JSON data
|
||||
type Object = map[string]interface{}
|
||||
|
||||
// SignClaims adds `typ`, `kid` (or `jwk`), and `alg` in the header and expects claims for `jti`, `exp`, `iss`, and `iat`
|
||||
func SignClaims(privkey PrivateKey, header Object, claims Object) (*JWS, error) {
|
||||
var randsrc io.Reader = randReader
|
||||
seed, _ := header["_seed"].(int64)
|
||||
if 0 != seed {
|
||||
randsrc = mathrand.New(mathrand.NewSource(seed))
|
||||
//delete(header, "_seed")
|
||||
}
|
||||
|
||||
protected, header, err := headerToProtected(NewPublicKey(privkey.Public()), header)
|
||||
if nil != err {
|
||||
return nil, err
|
||||
}
|
||||
protected64 := base64.RawURLEncoding.EncodeToString(protected)
|
||||
|
||||
payload, err := claimsToPayload(claims)
|
||||
if nil != err {
|
||||
return nil, err
|
||||
}
|
||||
payload64 := base64.RawURLEncoding.EncodeToString(payload)
|
||||
|
||||
signable := fmt.Sprintf(`%s.%s`, protected64, payload64)
|
||||
hash := sha256.Sum256([]byte(signable))
|
||||
|
||||
sig := Sign(privkey, hash[:], randsrc)
|
||||
sig64 := base64.RawURLEncoding.EncodeToString(sig)
|
||||
//log.Printf("\n(Sign)\nSignable: %s", signable)
|
||||
//log.Printf("Hash: %s", hash)
|
||||
//log.Printf("Sig: %s", sig64)
|
||||
|
||||
return &JWS{
|
||||
Header: header,
|
||||
Claims: claims,
|
||||
Protected: protected64,
|
||||
Payload: payload64,
|
||||
Signature: sig64,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func headerToProtected(pub PublicKey, header Object) ([]byte, Object, error) {
|
||||
if nil == header {
|
||||
header = Object{}
|
||||
}
|
||||
|
||||
// Only supporting 2048-bit and P256 keys right now
|
||||
// because that's all that's practical and well-supported.
|
||||
// No security theatre here.
|
||||
alg := "ES256"
|
||||
switch pub.Key().(type) {
|
||||
case *rsa.PublicKey:
|
||||
alg = "RS256"
|
||||
}
|
||||
|
||||
if selfSign, _ := header["_jwk"].(bool); selfSign {
|
||||
delete(header, "_jwk")
|
||||
any := Object{}
|
||||
_ = json.Unmarshal(MarshalJWKPublicKey(pub), &any)
|
||||
header["jwk"] = any
|
||||
}
|
||||
|
||||
// TODO what are the acceptable values? JWT. JWS? others?
|
||||
header["typ"] = "JWT"
|
||||
if _, ok := header["jwk"]; !ok {
|
||||
thumbprint := ThumbprintPublicKey(pub)
|
||||
kid, _ := header["kid"].(string)
|
||||
if "" != kid && thumbprint != kid {
|
||||
return nil, nil, errors.New("'kid' should be the key's thumbprint")
|
||||
}
|
||||
header["kid"] = thumbprint
|
||||
}
|
||||
header["alg"] = alg
|
||||
|
||||
protected, err := json.Marshal(header)
|
||||
if nil != err {
|
||||
return nil, nil, err
|
||||
}
|
||||
return protected, header, nil
|
||||
}
|
||||
|
||||
func claimsToPayload(claims Object) ([]byte, error) {
|
||||
if nil == claims {
|
||||
claims = Object{}
|
||||
}
|
||||
|
||||
var dur time.Duration
|
||||
jti, _ := claims["jti"].(string)
|
||||
insecure, _ := claims["insecure"].(bool)
|
||||
|
||||
switch exp := claims["exp"].(type) {
|
||||
case time.Duration:
|
||||
// TODO: MUST this go first?
|
||||
// int64(time.Duration) vs time.Duration(int64)
|
||||
dur = exp
|
||||
case string:
|
||||
var err error
|
||||
dur, err = time.ParseDuration(exp)
|
||||
// TODO s, err := time.ParseDuration(dur)
|
||||
if nil != err {
|
||||
return nil, err
|
||||
}
|
||||
case int:
|
||||
dur = time.Second * time.Duration(exp)
|
||||
case int64:
|
||||
dur = time.Second * time.Duration(exp)
|
||||
case float64:
|
||||
dur = time.Second * time.Duration(exp)
|
||||
default:
|
||||
dur = 0
|
||||
}
|
||||
|
||||
if "" == jti && 0 == dur && !insecure {
|
||||
return nil, errors.New("token must have jti or exp as to be expirable / cancellable")
|
||||
}
|
||||
claims["exp"] = time.Now().Add(dur).Unix()
|
||||
|
||||
return json.Marshal(claims)
|
||||
}
|
||||
|
||||
// JWSToJWT joins JWS parts into a JWT as {ProtectedHeader}.{SerializedPayload}.{Signature}.
|
||||
func JWSToJWT(jwt *JWS) string {
|
||||
return fmt.Sprintf(
|
||||
"%s.%s.%s",
|
||||
jwt.Protected,
|
||||
jwt.Payload,
|
||||
jwt.Signature,
|
||||
)
|
||||
}
|
||||
|
||||
// Sign signs both RSA and ECDSA. Use `nil` or `crypto/rand.Reader` except for debugging.
|
||||
func Sign(privkey PrivateKey, hash []byte, rand io.Reader) []byte {
|
||||
if nil == rand {
|
||||
rand = randReader
|
||||
}
|
||||
var sig []byte
|
||||
|
||||
if len(hash) != 32 {
|
||||
panic("only 256-bit hashes for 2048-bit and 256-bit keys are supported")
|
||||
}
|
||||
|
||||
switch k := privkey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
sig, _ = rsa.SignPKCS1v15(rand, k, crypto.SHA256, hash)
|
||||
case *ecdsa.PrivateKey:
|
||||
r, s, _ := ecdsa.Sign(rand, k, hash[:])
|
||||
rb := r.Bytes()
|
||||
for len(rb) < 32 {
|
||||
rb = append([]byte{0}, rb...)
|
||||
}
|
||||
sb := s.Bytes()
|
||||
for len(rb) < 32 {
|
||||
sb = append([]byte{0}, sb...)
|
||||
}
|
||||
sig = append(rb, sb...)
|
||||
}
|
||||
return sig
|
||||
}
|
Loading…
Reference in New Issue