diff --git a/keyfetch/fetch.go b/keyfetch/fetch.go index 3014257..2ac92d7 100644 --- a/keyfetch/fetch.go +++ b/keyfetch/fetch.go @@ -104,6 +104,28 @@ func JWK(kidOrThumb, iss string) (keypairs.PublicKey, error) { return immediateOneOrFetch(kidOrThumb, iss, uncached.JWKs) } +// PEM tries to return a key from cache, falling back to the specified PEM url +func PEM(url string) (keypairs.PublicKey, error) { + // url is kid in this case + return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) { + m, key, err := uncached.PEM(url) + if nil != err { + return nil, nil, err + } + + // put in a map, just for caching + maps := map[string]map[string]string{} + maps[key.Thumbprint()] = m + maps[url] = m + + keys := map[string]keypairs.PublicKey{} + keys[key.Thumbprint()] = key + keys[url] = key + + return maps, keys, nil + }) +} + // Fetch returns a key from cache, falling back to an exact url as the "issuer" func Fetch(url string) (keypairs.PublicKey, error) { // url is kid in this case diff --git a/keyfetch/fetch_test.go b/keyfetch/fetch_test.go index 76d2ffe..4e5b89f 100644 --- a/keyfetch/fetch_test.go +++ b/keyfetch/fetch_test.go @@ -8,10 +8,20 @@ import ( "github.com/big-squid/go-keypairs/keyfetch/uncached" ) +var pubkey keypairs.PublicKey + func TestCachesKey(t *testing.T) { testCachesKey(t, "https://bigsquid.auth0.com/") clear() testCachesKey(t, "https://bigsquid.auth0.com") + // Get PEM + k3, err := PEM("https://bigsquid.auth0.com/pem") + if nil != err { + t.Fatal("Error fetching and caching key:", err) + } + if k3.Thumbprint() != pubkey.Thumbprint() { + t.Fatal("Error got different thumbprint for different versions of the same key:", err) + } clear() testCachesKey(t, "https://big-squid.github.io/") } @@ -39,7 +49,7 @@ func testCachesKey(t *testing.T, url string) { } // Get with caching - k2, err := OIDCJWK(thumb, url) + pubkey, err = OIDCJWK(thumb, url) if nil != err { t.Fatal("Error fetching and caching key:", err) } @@ -48,9 +58,9 @@ func testCachesKey(t *testing.T, url string) { if pub := Get(thumb, ""); nil == pub { t.Fatal("key was not properly cached by thumbprint", thumb) } - if "" != k2.KeyID() { - if pub := Get(k2.KeyID(), url); nil == pub { - t.Fatal("key was not properly cached by kid", k2.KeyID()) + if "" != pubkey.KeyID() { + if pub := Get(pubkey.KeyID(), url); nil == pub { + t.Fatal("key was not properly cached by kid", pubkey.KeyID()) } } else { t.Log("Key did not have an explicit KeyID") @@ -67,7 +77,13 @@ func testCachesKey(t *testing.T, url string) { } // Sanity check that the kid and thumb match - if key.KeyID() != k2.KeyID() || key.Thumbprint() != k2.Thumbprint() { - t.Fatal("SANITY: KeyIDs or Thumbprints do not match:", key.KeyID(), k2.KeyID(), key.Thumbprint(), k2.Thumbprint()) + if key.KeyID() != pubkey.KeyID() || key.Thumbprint() != pubkey.Thumbprint() { + t.Fatal("SANITY: KeyIDs or Thumbprints do not match:", key.KeyID(), pubkey.KeyID(), key.Thumbprint(), pubkey.Thumbprint()) + } + + // Get 404 + _, err = PEM(url + "/will-not-be-found.xyz") + if nil == err { + t.Fatal("Should have an error when retrieving a 404 or index.html:", err) } } diff --git a/keyfetch/uncached/fetch.go b/keyfetch/uncached/fetch.go index e909bab..532170c 100644 --- a/keyfetch/uncached/fetch.go +++ b/keyfetch/uncached/fetch.go @@ -2,8 +2,11 @@ package uncached import ( + "bytes" "encoding/json" + "errors" "io" + "io/ioutil" "net" "net/http" "strings" @@ -74,6 +77,41 @@ func JWKs(jwksurl string) (map[string]map[string]string, map[string]keypairs.Pub return maps, keys, nil } +// PEM fetches and parses a PEM (assuming well-known format) +func PEM(pemurl string) (map[string]string, keypairs.PublicKey, error) { + var pub keypairs.PublicKey + if err := safeFetch(pemurl, func(body io.Reader) error { + pem, err := ioutil.ReadAll(body) + if nil != err { + return err + } + pub, err = keypairs.ParsePublicKey(pem) + return err + }); nil != err { + return nil, nil, err + } + + jwk := map[string]interface{}{} + body := bytes.NewBuffer(keypairs.MarshalJWKPublicKey(pub)) + decoder := json.NewDecoder(body) + decoder.UseNumber() + _ = decoder.Decode(&jwk) + + m := getStringMap(jwk) + m["kid"] = pemurl + + switch p := pub.(type) { + case *keypairs.ECPublicKey: + p.KID = pemurl + case *keypairs.RSAPublicKey: + p.KID = pemurl + default: + return nil, nil, errors.New("impossible key type") + } + + return m, pub, nil +} + // Fetch retrieves a single JWK (plain, bare jwk) from a URL (off-spec) func Fetch(url string) (map[string]string, keypairs.PublicKey, error) { var m map[string]interface{} diff --git a/keypairs.go b/keypairs.go index f91a712..cd39388 100644 --- a/keypairs.go +++ b/keypairs.go @@ -314,7 +314,7 @@ func getPEMBytes(block []byte) ([][]byte, error) { } } -// ParsePrivateKey will try to parse the bytes you give it +// ParsePublicKey will try to parse the bytes you give it // in any of the supported formats: PEM, DER, PKIX/SPKI, PKCS1, x509 Certificate, and JWK func ParsePublicKey(block []byte) (PublicKey, error) { blocks, err := getPEMBytes(block) @@ -324,7 +324,7 @@ func ParsePublicKey(block []byte) (PublicKey, error) { // Parse PEM blocks (openssl generates junk metadata blocks for ECs) // or the original DER, or the JWK - for i, _ := range blocks { + for i := range blocks { block = blocks[i] if key, err := parsePublicKey(block); nil == err { return key, nil @@ -341,8 +341,6 @@ func ParsePublicKeyString(block string) (PublicKey, error) { } func parsePublicKey(der []byte) (PublicKey, error) { - var key PublicKey - cert, err := x509.ParseCertificate(der) if nil == err { switch k := cert.PublicKey.(type) { @@ -351,7 +349,7 @@ func parsePublicKey(der []byte) (PublicKey, error) { case *ecdsa.PublicKey: return NewPublicKey(k), nil default: - err = errors.New("Only RSA and ECDSA (EC) Public Keys are supported") + return nil, errors.New("Only RSA and ECDSA (EC) Public Keys are supported") } } @@ -364,28 +362,27 @@ func parsePublicKey(der []byte) (PublicKey, error) { case *ecdsa.PublicKey: return NewPublicKey(k), nil default: - err = errors.New("Only RSA and ECDSA (EC) Public Keys are supported") + return nil, errors.New("Only RSA and ECDSA (EC) Public Keys are supported") } } - if nil != err { - //fmt.Println("3. ParsePKCS1PrublicKey") - keyx, err := x509.ParsePKCS1PublicKey(der) - key = NewPublicKey(keyx) + //fmt.Println("3. ParsePKCS1PrublicKey") + rkey, err := x509.ParsePKCS1PublicKey(der) + if nil == err { + //fmt.Println("4. ParseJWKPublicKey") + return NewPublicKey(rkey), nil + } + + return ParseJWKPublicKey(der) + + /* + // But did you know? + // You must return nil explicitly for interfaces + // https://golang.org/doc/faq#nil_error if nil != err { - //fmt.Println("4. ParseJWKPublicKey") - key, err = ParseJWKPublicKey(der) + return nil, err } - } - - // But did you know? - // You must return nil explicitly for interfaces - // https://golang.org/doc/faq#nil_error - if nil != err { - return nil, err - } - - return key, nil + */ } // NewJWKPublicKey contstructs a PublicKey from the relevant pieces a map[string]string (generic JSON)