diff --git a/cmd/keypairs/keypairs.go b/cmd/keypairs/keypairs.go index 1201f79..20e38ce 100644 --- a/cmd/keypairs/keypairs.go +++ b/cmd/keypairs/keypairs.go @@ -7,54 +7,192 @@ import ( "io/ioutil" "os" "strings" + "time" "git.rootprojects.org/root/keypairs" ) +var ( + name = "keypairs" + version = "0.0.0" + date = "0001-01-01T00:00:00Z" + commit = "0000000" +) + +func usage() { + ver() + fmt.Println("Usage") + fmt.Printf(" %s [flags] args...\n", name) + fmt.Println("") + fmt.Printf("See usage: %s help \n", name) + fmt.Println("") + fmt.Println("Commands:") + fmt.Println(" version") + fmt.Println(" gen") + fmt.Println(" sign") + fmt.Println("") + fmt.Println("Examples:") + fmt.Println(" keypairs gen -o key.jwk.json [--pub ]") + fmt.Println(" keypairs sign --exp 15m key.jwk.json payload.json") + fmt.Println(" keypairs sign --exp 15m key.jwk.json '{ \"sub\": \"xxxx\" }'") + fmt.Println("") + //fmt.Println(" verify") +} + +func ver() { + fmt.Printf("%s v%s %s (%s)\n", name, version, commit[:7], date) +} + func main() { - if 1 == len(os.Args) || "gen" != os.Args[1] { - fmt.Fprintln(os.Stderr, "Usage: keypairs gen -o [--pub ]") + args := os.Args[:] + + if "help" == args[1] { + // top-level help + if 2 == len(args) { + usage() + os.Exit(0) + return + } + // move help to subcommand argument + self := args[0] + args = append([]string{self}, args[2:]...) + args = append(args, "--help") + } + + switch args[1] { + case "version": + ver() + os.Exit(0) + return + case "gen": + gen(args) + case "sign": + sign(args) + default: + usage() + os.Exit(1) + return + } +} + +func gen(args []string) { + var keyname string + var pubname string + flags := flag.NewFlagSet("gen", flag.ExitOnError) + flags.StringVar(&keyname, "o", "", "private key file (ex: key.jwk.json or key.pem)") + flags.StringVar(&pubname, "pub", "", "public key file (ex: pub.jwk.json or pub.pem)") + flags.Parse(args) + + key := keypairs.NewDefaultPrivateKey() + marshalPriv(key, keyname) + pub := keypairs.NewPublicKey(key.Public()) + marshalPub(pub, pubname) +} + +func sign(args []string) { + var exp time.Duration + flags := flag.NewFlagSet("sign", flag.ExitOnError) + flags.DurationVar(&exp, "exp", 0, "duration until token expires (Default 15m)") + flags.Parse(args) + if len(flags.Args()) <= 3 { + fmt.Fprintf(os.Stderr, "Usage: keypairs sign --exp 1h ./payload.json\n") + os.Exit(1) + } + + keyname := flags.Args()[2] + payload := flags.Args()[3] + + var key keypairs.PrivateKey = nil + b, err := ioutil.ReadFile(keyname) + if nil != err { + var err2 error + key, err2 = keypairs.ParsePrivateKey([]byte(keyname)) + if nil != err2 { + fmt.Fprintf(os.Stderr, + "could not read private key as file (or parse as string) %q: %s\n", keyname, err) + } + os.Exit(1) + return + } + if nil == key { + var err3 error + key, err3 = keypairs.ParsePrivateKey(b) + if nil != err3 { + fmt.Fprintf(os.Stderr, + "could not parse private key from file %q: %s\n", keyname, err3) + os.Exit(1) + return + } + } + + if "" == payload { + payload = "{}" + } + + b, err = ioutil.ReadFile(payload) + claims := map[string]interface{}{} + if nil != err { + var err2 error + err2 = json.Unmarshal([]byte(payload), &claims) + if nil != err2 { + fmt.Fprintf(os.Stderr, + "could not read payload as file (or parse as string) %q: %s\n", payload, err) + os.Exit(1) + return + } + } + if 0 == len(claims) { + var err3 error + err3 = json.Unmarshal(b, &claims) + if nil != err3 { + fmt.Fprintf(os.Stderr, + "could not parse palyoad from file %q: %s\n", payload, err3) + os.Exit(1) + return + } + } + + if 0 != exp { + claims["exp"] = exp.Seconds() + } + if _, ok := claims["exp"]; !ok { + claims["exp"] = (15 * time.Minute).Seconds() + } + + jws, err := keypairs.SignClaims(key, nil, claims) + if nil != err { + fmt.Fprintf(os.Stderr, "could not sign claims: %v\n%#v\n", err, claims) os.Exit(1) return } - // gen subcommand - args := os.Args[2:] - - var privname string - var pubname string - flags := flag.NewFlagSet("gen", flag.ExitOnError) - flags.StringVar(&privname, "o", "", "private key file (should have .jwk.json or pkcs8.pem extension)") - flags.StringVar(&pubname, "pub", "", "public key file (should have .jwk.json or spki.pem extension)") - flags.Parse(args) - - priv := keypairs.NewDefaultPrivateKey() - marshalPriv(priv, privname) - marshalPub(keypairs.NewPublicKey(priv.Public()), pubname) + b, _ = json.Marshal(&jws) + fmt.Printf("JWS:\n%s\n\n", indentJSON(b)) + fmt.Printf("JWT:\n%s\n\n", keypairs.JWSToJWT(jws)) } -func marshalPriv(priv keypairs.PrivateKey, privname string) { - if "" == privname { - b := indentJSON(keypairs.MarshalJWKPrivateKey(priv)) +func marshalPriv(key keypairs.PrivateKey, keyname string) { + if "" == keyname { + b := indentJSON(keypairs.MarshalJWKPrivateKey(key)) fmt.Fprintf(os.Stdout, string(b)+"\n") return } var b []byte - if strings.HasSuffix(privname, ".json") { - b = indentJSON(keypairs.MarshalJWKPrivateKey(priv)) - } else if strings.HasSuffix(privname, ".pem") { - b, _ = keypairs.MarshalPEMPrivateKey(priv) - } else if strings.HasSuffix(privname, ".der") { - b, _ = keypairs.MarshalDERPrivateKey(priv) + if strings.HasSuffix(keyname, ".json") { + b = indentJSON(keypairs.MarshalJWKPrivateKey(key)) + } else if strings.HasSuffix(keyname, ".pem") { + b, _ = keypairs.MarshalPEMPrivateKey(key) + } else if strings.HasSuffix(keyname, ".der") { + b, _ = keypairs.MarshalDERPrivateKey(key) } else { fmt.Fprintf(os.Stderr, "private key extension should be .jwk.json, .pem, or .der") os.Exit(1) return } - ioutil.WriteFile(privname, b, 0600) + ioutil.WriteFile(keyname, b, 0600) } func marshalPub(pub keypairs.PublicKey, pubname string) { diff --git a/sign.go b/sign.go new file mode 100644 index 0000000..7b7aae0 --- /dev/null +++ b/sign.go @@ -0,0 +1,189 @@ +package keypairs + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + mathrand "math/rand" + "time" +) + +// randReader may be overwritten for testing +//var randReader io.Reader = rand.Reader + +//var randReader = rand.Reader + +// JWS is a parsed JWT, representation as signable/verifiable and human-readable parts +type JWS struct { + Header Object `json:"header"` // JSON + Claims Object `json:"claims"` // JSON + Protected string `json:"protected"` // base64 + Payload string `json:"payload"` // base64 + Signature string `json:"signature"` // base64 +} + +// Object is a type alias representing generic JSON data +type Object = map[string]interface{} + +// SignClaims adds `typ`, `kid` (or `jwk`), and `alg` in the header and expects claims for `jti`, `exp`, `iss`, and `iat` +func SignClaims(privkey PrivateKey, header Object, claims Object) (*JWS, error) { + var randsrc io.Reader = randReader + seed, _ := header["_seed"].(int64) + if 0 != seed { + randsrc = mathrand.New(mathrand.NewSource(seed)) + //delete(header, "_seed") + } + + protected, header, err := headerToProtected(NewPublicKey(privkey.Public()), header) + if nil != err { + return nil, err + } + protected64 := base64.RawURLEncoding.EncodeToString(protected) + + payload, err := claimsToPayload(claims) + if nil != err { + return nil, err + } + payload64 := base64.RawURLEncoding.EncodeToString(payload) + + signable := fmt.Sprintf(`%s.%s`, protected64, payload64) + hash := sha256.Sum256([]byte(signable)) + + sig := Sign(privkey, hash[:], randsrc) + sig64 := base64.RawURLEncoding.EncodeToString(sig) + //log.Printf("\n(Sign)\nSignable: %s", signable) + //log.Printf("Hash: %s", hash) + //log.Printf("Sig: %s", sig64) + + return &JWS{ + Header: header, + Claims: claims, + Protected: protected64, + Payload: payload64, + Signature: sig64, + }, nil +} + +func headerToProtected(pub PublicKey, header Object) ([]byte, Object, error) { + if nil == header { + header = Object{} + } + + // Only supporting 2048-bit and P256 keys right now + // because that's all that's practical and well-supported. + // No security theatre here. + alg := "ES256" + switch pub.Key().(type) { + case *rsa.PublicKey: + alg = "RS256" + } + + if selfSign, _ := header["_jwk"].(bool); selfSign { + delete(header, "_jwk") + any := Object{} + _ = json.Unmarshal(MarshalJWKPublicKey(pub), &any) + header["jwk"] = any + } + + // TODO what are the acceptable values? JWT. JWS? others? + header["typ"] = "JWT" + if _, ok := header["jwk"]; !ok { + thumbprint := ThumbprintPublicKey(pub) + kid, _ := header["kid"].(string) + if "" != kid && thumbprint != kid { + return nil, nil, errors.New("'kid' should be the key's thumbprint") + } + header["kid"] = thumbprint + } + header["alg"] = alg + + protected, err := json.Marshal(header) + if nil != err { + return nil, nil, err + } + return protected, header, nil +} + +func claimsToPayload(claims Object) ([]byte, error) { + if nil == claims { + claims = Object{} + } + + var dur time.Duration + jti, _ := claims["jti"].(string) + insecure, _ := claims["insecure"].(bool) + + switch exp := claims["exp"].(type) { + case time.Duration: + // TODO: MUST this go first? + // int64(time.Duration) vs time.Duration(int64) + dur = exp + case string: + var err error + dur, err = time.ParseDuration(exp) + // TODO s, err := time.ParseDuration(dur) + if nil != err { + return nil, err + } + case int: + dur = time.Second * time.Duration(exp) + case int64: + dur = time.Second * time.Duration(exp) + case float64: + dur = time.Second * time.Duration(exp) + default: + dur = 0 + } + + if "" == jti && 0 == dur && !insecure { + return nil, errors.New("token must have jti or exp as to be expirable / cancellable") + } + claims["exp"] = time.Now().Add(dur).Unix() + + return json.Marshal(claims) +} + +// JWSToJWT joins JWS parts into a JWT as {ProtectedHeader}.{SerializedPayload}.{Signature}. +func JWSToJWT(jwt *JWS) string { + return fmt.Sprintf( + "%s.%s.%s", + jwt.Protected, + jwt.Payload, + jwt.Signature, + ) +} + +// Sign signs both RSA and ECDSA. Use `nil` or `crypto/rand.Reader` except for debugging. +func Sign(privkey PrivateKey, hash []byte, rand io.Reader) []byte { + if nil == rand { + rand = randReader + } + var sig []byte + + if len(hash) != 32 { + panic("only 256-bit hashes for 2048-bit and 256-bit keys are supported") + } + + switch k := privkey.(type) { + case *rsa.PrivateKey: + sig, _ = rsa.SignPKCS1v15(rand, k, crypto.SHA256, hash) + case *ecdsa.PrivateKey: + r, s, _ := ecdsa.Sign(rand, k, hash[:]) + rb := r.Bytes() + for len(rb) < 32 { + rb = append([]byte{0}, rb...) + } + sb := s.Bytes() + for len(rb) < 32 { + sb = append([]byte{0}, sb...) + } + sig = append(rb, sb...) + } + return sig +}