From a62ae2ea87a203a061ed6400fb49b4823265cd47 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Wed, 25 Nov 2020 03:10:33 -0700 Subject: [PATCH] add inspect with OIDC key fetch --- cmd/keypairs/keypairs.go | 102 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 100 insertions(+), 2 deletions(-) diff --git a/cmd/keypairs/keypairs.go b/cmd/keypairs/keypairs.go index dcb0efa..52e270c 100644 --- a/cmd/keypairs/keypairs.go +++ b/cmd/keypairs/keypairs.go @@ -10,6 +10,7 @@ import ( "time" "git.rootprojects.org/root/keypairs" + "git.rootprojects.org/root/keypairs/keyfetch" ) var ( @@ -31,6 +32,7 @@ func usage() { fmt.Println(" version") fmt.Println(" gen") fmt.Println(" sign") + fmt.Println(" inspect (decode)") fmt.Println(" verify") fmt.Println("") fmt.Println("Examples:") @@ -39,6 +41,8 @@ func usage() { fmt.Println(" keypairs sign --exp 15m key.jwk.json payload.json") fmt.Println(" keypairs sign --exp 15m key.jwk.json '{ \"sub\": \"xxxx\" }'") fmt.Println("") + fmt.Println(" keypairs inspect --verbose 'xxxx.yyyy.zzzz'") + fmt.Println("") fmt.Println(" keypairs verify ./pub.jwk.json 'xxxx.yyyy.zzzz'") // TODO fmt.Println(" keypairs verify --issuer https://example.com '{ \"sub\": \"xxxx\" }'") fmt.Println("") @@ -73,6 +77,10 @@ func main() { gen(args[2:]) case "sign": sign(args[2:]) + case "decode": + fallthrough + case "inspect": + inspect(args[2:]) case "verify": verify(args[2:]) default: @@ -176,19 +184,109 @@ func sign(args []string) { fmt.Fprintf(os.Stdout, "%s\n", keypairs.JWSToJWT(jws)) } +func inspect(args []string) { + var verbose bool + flags := flag.NewFlagSet("inspect", flag.ExitOnError) + flags.BoolVar(&verbose, "verbose", true, "print extra info") + flags.Usage = func() { + fmt.Println("Usage: keypairs inspect ") + fmt.Println("") + fmt.Println(" : a JWT or JWS File or String, if JWS the payload must be Base64") + fmt.Println("") + } + flags.Parse(args) + if len(flags.Args()) < 1 { + flags.Usage() + os.Exit(1) + return + } + + payload := flags.Args()[0] + jws, err := readJWS(payload) + if nil != err { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + return + } + + var pub keypairs.PublicKey = nil + // because interfaces are never truly nil + hasPub := false + jwk, _ := jws.Header["jwk"].(map[string]interface{}) + jwkE, _ := jwk["e"].(string) + jwkX, _ := jwk["x"].(string) + kid, _ := jws.Header["kid"].(string) + if len(jwkE) > 0 || len(jwkX) > 0 { + // TODO verify self-signed certificate + //b, _ := json.MarshalIndent(&jwk, "", " ") + if len(kid) > 0 { + fmt.Fprintf(os.Stderr, "[warn] jws header has both 'kid' (Key ID) and 'jwk' (for self-signed only)\n") + } else { + fmt.Fprintf(os.Stderr, "[debug] token is self-signed (jwk)\n") + //pub = pubx + //hasPub = true + } + } else if len(kid) > 0 { + iss, _ := jws.Claims["iss"].(string) + if strings.HasPrefix(iss, "http:") || strings.HasPrefix(iss, "https:") { + //fmt.Printf("iss: %s\n", iss) + //fmt.Printf("kid: %s\n", kid) + fmt.Fprintf(os.Stderr, "Checking for OIDC key... ") + pubx, err := keyfetch.OIDCJWK(kid, iss) + if nil != err { + fmt.Fprintf(os.Stderr, "not found.\n") + // ignore + } else { + fmt.Fprintf(os.Stderr, "found:\n") + if verbose { + b := keypairs.MarshalJWKPublicKey(pubx) + fmt.Fprintf(os.Stderr, "%s\n", indentJSON(b)) + } + pub = pubx + hasPub = true + } + } + } + + validSig := false + if hasPub { + errs := keypairs.VerifyClaims(pub, jws) + if len(errs) > 0 { + fmt.Fprintf(os.Stderr, "error:\n") + for _, err := range errs { + fmt.Fprintf(os.Stderr, "\t%v\n", err) + } + } else { + validSig = true + } + } + + b, _ := json.MarshalIndent(&jws, "", " ") + fmt.Fprintf(os.Stdout, "%s\n", b) + + if validSig { + fmt.Fprintf(os.Stderr, "Signature is Valid\n") + } +} + func verify(args []string) { flags := flag.NewFlagSet("verify", flag.ExitOnError) flags.Usage = func() { - fmt.Println("Usage: keypairs verify ") + fmt.Println("Usage: keypairs verify [public key] ") fmt.Println("") fmt.Println(" : a File or String of an EC or RSA key in JWK or PEM format") fmt.Println(" : a JWT or JWS File or String, if JWS the payload must be Base64") fmt.Println("") } flags.Parse(args) - if len(flags.Args()) <= 1 { + if len(flags.Args()) < 1 { flags.Usage() os.Exit(1) + return + } + if 1 == len(flags.Args()) { + inspect(args) + return } pubname := flags.Args()[0]