diff --git a/cmd/keypairs/keypairs.go b/cmd/keypairs/keypairs.go index 9ab4113..c8587ac 100644 --- a/cmd/keypairs/keypairs.go +++ b/cmd/keypairs/keypairs.go @@ -363,9 +363,9 @@ func marshalPub(pub keypairs.PublicKey, pubname string) { if strings.HasSuffix(pubname, ".json") { b = indentJSON(keypairs.MarshalJWKPublicKey(pub)) } else if strings.HasSuffix(pubname, ".pem") { - b, _ = keypairs.MarshalPEMPublicKey(pub) + b, _ = keypairs.MarshalPEMPublicKey(pub.Key().(keypairs.PublicKeyTransitional)) } else if strings.HasSuffix(pubname, ".der") { - b, _ = keypairs.MarshalDERPublicKey(pub) + b, _ = keypairs.MarshalDERPublicKey(pub.Key().(keypairs.PublicKeyTransitional)) } ioutil.WriteFile(pubname, b, 0644) diff --git a/keyfetch/fetch.go b/keyfetch/fetch.go index c609531..27d8f96 100644 --- a/keyfetch/fetch.go +++ b/keyfetch/fetch.go @@ -25,8 +25,8 @@ import ( // TODO should be ErrInvalidJWKURL -// EInvalidJWKURL means that the url did not provide JWKs -var EInvalidJWKURL = errors.New("url does not lead to valid JWKs") +// ErrInvalidJWKURL means that the url did not provide JWKs +var ErrInvalidJWKURL = errors.New("url does not lead to valid JWKs") // KeyCache is an in-memory key cache var KeyCache = map[string]CachableKey{} @@ -94,7 +94,7 @@ func OIDCJWKs(baseURL string) (PublicKeysMap, error) { } // OIDCJWK fetches baseURL + ".well-known/openid-configuration" and then returns the key matching kid (or thumbprint) -func OIDCJWK(kidOrThumb, iss string) (keypairs.PublicKey, error) { +func OIDCJWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { return immediateOneOrFetch(kidOrThumb, iss, uncached.OIDCJWKs) } @@ -110,7 +110,7 @@ func WellKnownJWKs(kidOrThumb, iss string) (PublicKeysMap, error) { } // WellKnownJWK fetches baseURL + ".well-known/jwks.json" and returns the key matching kid (or thumbprint) -func WellKnownJWK(kidOrThumb, iss string) (keypairs.PublicKey, error) { +func WellKnownJWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { return immediateOneOrFetch(kidOrThumb, iss, uncached.WellKnownJWKs) } @@ -128,12 +128,12 @@ func JWKs(jwksurl string) (PublicKeysMap, error) { } // JWK tries to return a key from cache, falling back to the /.well-known/jwks.json of the issuer -func JWK(kidOrThumb, iss string) (keypairs.PublicKey, error) { +func JWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { return immediateOneOrFetch(kidOrThumb, iss, uncached.JWKs) } // PEM tries to return a key from cache, falling back to the specified PEM url -func PEM(url string) (keypairs.PublicKey, error) { +func PEM(url string) (keypairs.PublicKeyTransitional, error) { // url is kid in this case return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) { m, key, err := uncached.PEM(url) @@ -143,11 +143,11 @@ func PEM(url string) (keypairs.PublicKey, error) { // put in a map, just for caching maps := map[string]map[string]string{} - maps[key.Thumbprint()] = m + maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m maps[url] = m keys := map[string]keypairs.PublicKey{} - keys[key.Thumbprint()] = key + keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key keys[url] = key return maps, keys, nil @@ -155,7 +155,7 @@ func PEM(url string) (keypairs.PublicKey, error) { } // Fetch returns a key from cache, falling back to an exact url as the "issuer" -func Fetch(url string) (keypairs.PublicKey, error) { +func Fetch(url string) (keypairs.PublicKeyTransitional, error) { // url is kid in this case return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) { m, key, err := uncached.Fetch(url) @@ -165,10 +165,10 @@ func Fetch(url string) (keypairs.PublicKey, error) { // put in a map, just for caching maps := map[string]map[string]string{} - maps[key.Thumbprint()] = m + maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m keys := map[string]keypairs.PublicKey{} - keys[key.Thumbprint()] = key + keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key return maps, keys, nil }) @@ -210,7 +210,7 @@ func get(kidOrThumb, iss string) *CachableKey { return nil } -func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.PublicKey, error) { +func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.PublicKeyTransitional, error) { now := time.Now() key := get(kidOrThumb, iss) @@ -223,12 +223,12 @@ func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.Pu go fetchAndSelect(kidOrThumb, iss, fetcher) } - return key.Key, nil + return key.Key.Key().(keypairs.PublicKeyTransitional), nil } type myfetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) -func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKey, error) { +func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKeyTransitional, error) { maps, keys, err := fetcher(baseURL) if nil != err { return nil, err @@ -238,12 +238,12 @@ func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKey, for i := range keys { key := keys[i] - if id == key.Thumbprint() { - return key, nil + if id == keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional)) { + return key.Key().(keypairs.PublicKeyTransitional), nil } if id == key.KeyID() { - return key, nil + return key.Key().(keypairs.PublicKeyTransitional), nil } } @@ -287,7 +287,7 @@ func cacheKey(kid, iss, expstr string, pub keypairs.PublicKey) error { Expiry: expiry, } // Since thumbprints are crypto secure, iss isn't needed - thumb := pub.Thumbprint() + thumb := keypairs.Thumbprint(pub.Key().(keypairs.PublicKeyTransitional)) KeyCache[thumb] = CachableKey{ Key: pub, Expiry: expiry, diff --git a/keyfetch/fetch_test.go b/keyfetch/fetch_test.go index 5f2dee7..2bb2df7 100644 --- a/keyfetch/fetch_test.go +++ b/keyfetch/fetch_test.go @@ -8,19 +8,21 @@ import ( "git.rootprojects.org/root/keypairs/keyfetch/uncached" ) -var pubkey keypairs.PublicKey +var pubkey keypairs.PublicKeyTransitional func TestCachesKey(t *testing.T) { testCachesKey(t, "https://bigsquid.auth0.com/") clear() testCachesKey(t, "https://bigsquid.auth0.com") // Get PEM - k3, err := PEM("https://bigsquid.auth0.com/pem") + pubk3, err := PEM("https://bigsquid.auth0.com/pem") if nil != err { t.Fatal("Error fetching and caching key:", err) } - if k3.Thumbprint() != pubkey.Thumbprint() { - t.Fatal("Error got different thumbprint for different versions of the same key:", err) + thumb3 := keypairs.Thumbprint(pubk3) + thumb := keypairs.Thumbprint(pubkey) + if thumb3 != thumb { + t.Fatalf("Error got different thumbprint for different versions of the same key %q != %q: %v", thumb3, thumb, err) } clear() testCachesKey(t, "https://big-squid.github.io/") @@ -45,12 +47,12 @@ func testCachesKey(t *testing.T, url string) { t.Fatal("Should discover 1 or more keys via", url) } - var key keypairs.PublicKey + var key keypairs.PublicKeyTransitional for i := range keys { - key = keys[i] + key = keys[i].Key().(keypairs.PublicKeyTransitional) break } - thumb := key.Thumbprint() + thumb := keypairs.Thumbprint(key) // Look in cache for each (and fail) if pub := Get(thumb, ""); nil != pub { @@ -67,10 +69,11 @@ func testCachesKey(t *testing.T, url string) { if pub := Get(thumb, ""); nil == pub { t.Fatal("key was not properly cached by thumbprint", thumb) } - if "" != pubkey.KeyID() { - if pub := Get(pubkey.KeyID(), url); nil == pub { - t.Fatal("key was not properly cached by kid", pubkey.KeyID()) - } + + // TODO thumb / id mapping + thumb = keypairs.Thumbprint(pubkey) + if pub := Get(thumb, url); nil == pub { + t.Fatal("key was not properly cached by kid", pubkey) } else { t.Log("Key did not have an explicit KeyID") } @@ -86,8 +89,10 @@ func testCachesKey(t *testing.T, url string) { } // Sanity check that the kid and thumb match - if key.KeyID() != pubkey.KeyID() || key.Thumbprint() != pubkey.Thumbprint() { - t.Fatal("SANITY: KeyIDs or Thumbprints do not match:", key.KeyID(), pubkey.KeyID(), key.Thumbprint(), pubkey.Thumbprint()) + if !key.Equal(pubkey) || keypairs.Thumbprint(key) != keypairs.Thumbprint(pubkey) { + t.Fatalf("SANITY: [todo: KeyIDs or] Thumbprints do not match:\n%q != %q\n%q != %q", + keypairs.Thumbprint(key), keypairs.Thumbprint(pubkey), + keypairs.Thumbprint(key), keypairs.Thumbprint(pubkey)) } // Get 404 diff --git a/keyfetch/uncached/fetch.go b/keyfetch/uncached/fetch.go index 2e1c265..b0a73bf 100644 --- a/keyfetch/uncached/fetch.go +++ b/keyfetch/uncached/fetch.go @@ -71,8 +71,8 @@ func JWKs(jwksurl string) (map[string]map[string]string, map[string]keypairs.Pub if nil != err { return nil, nil, err } - keys[key.Thumbprint()] = key - maps[key.Thumbprint()] = m + keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key + maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m } return maps, keys, nil diff --git a/keypairs.go b/keypairs.go index ae6012b..92ffce9 100644 --- a/keypairs.go +++ b/keypairs.go @@ -3,7 +3,6 @@ package keypairs import ( "bytes" "crypto" - "crypto/dsa" "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" @@ -57,12 +56,19 @@ const ErrDevBadKeyType = "[Developer Error] crypto.PublicKey and crypto.PrivateK // PrivateKey is a zero-cost typesafe substitue for crypto.PrivateKey type PrivateKey interface { Public() crypto.PublicKey + Equal(x crypto.PrivateKey) bool +} + +// PublicKeyTransitional is so that v0.7.x can use golang v1.15 keys +type PublicKeyTransitional interface { + Equal(x crypto.PublicKey) bool } // PublicKey thinly veils crypto.PublicKey for type safety type PublicKey interface { crypto.PublicKey - Thumbprint() string + //Equal(x crypto.PublicKey) bool + //Thumbprint() string KeyID() string Key() crypto.PublicKey ExpiresAt() time.Time @@ -87,6 +93,11 @@ func (p *ECPublicKey) Thumbprint() string { return ThumbprintUntypedPublicKey(p.PublicKey) } +// Equal returns true if the public key is equal. +func (p *ECPublicKey) Equal(x crypto.PublicKey) bool { + return p.PublicKey.Equal(x) +} + // KeyID returns the JWK `kid`, which will be the Thumbprint for keys generated with this library func (p *ECPublicKey) KeyID() string { return p.KID @@ -112,6 +123,11 @@ func (p *RSAPublicKey) Thumbprint() string { return ThumbprintUntypedPublicKey(p.PublicKey) } +// Equal returns true if the public key is equal. +func (p *RSAPublicKey) Equal(x crypto.PublicKey) bool { + return p.PublicKey.Equal(x) +} + // KeyID returns the JWK `kid`, which will be the Thumbprint for keys generated with this library func (p *RSAPublicKey) KeyID() string { return p.KID @@ -134,6 +150,11 @@ func (p *RSAPublicKey) ExpiresAt() time.Time { // NewPublicKey wraps a crypto.PublicKey to make it typesafe. func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKey { + _, ok := pub.(PublicKeyTransitional) + if !ok { + panic("Developer Error: not a crypto.PublicKey") + } + var k PublicKey switch p := pub.(type) { case *ecdsa.PublicKey: @@ -156,14 +177,6 @@ func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKey { rsakey.KID = ThumbprintRSAPublicKey(p) } k = rsakey - case *ecdsa.PrivateKey: - panic(errors.New(ErrDevSwapPrivatePublic)) - case *rsa.PrivateKey: - panic(errors.New(ErrDevSwapPrivatePublic)) - case *dsa.PublicKey: - panic(ErrInvalidPublicKey) - case *dsa.PrivateKey: - panic(ErrInvalidPrivateKey) default: panic(fmt.Errorf(ErrDevBadKeyType, pub)) } @@ -180,8 +193,6 @@ func MarshalJWKPublicKey(key PublicKey, exp ...time.Time) []byte { return MarshalRSAPublicKey(k, exp...) case *ecdsa.PublicKey: return MarshalECPublicKey(k, exp...) - case *dsa.PublicKey: - panic(ErrInvalidPublicKey) default: // this is unreachable because we know the types that we pass in log.Printf("keytype: %t, %+v\n", key, key) @@ -189,9 +200,14 @@ func MarshalJWKPublicKey(key PublicKey, exp ...time.Time) []byte { } } +// Thumbprint returns the SHA256 RFC-spec JWK thumbprint +func Thumbprint(pub PublicKeyTransitional) string { + return ThumbprintUntypedPublicKey(pub) +} + // ThumbprintPublicKey returns the SHA256 RFC-spec JWK thumbprint func ThumbprintPublicKey(pub PublicKey) string { - return ThumbprintUntypedPublicKey(pub.Key()) + return ThumbprintUntypedPublicKey(pub.Key().(PublicKeyTransitional)) } // ThumbprintUntypedPublicKey is a non-typesafe version of ThumbprintPublicKey @@ -199,7 +215,7 @@ func ThumbprintPublicKey(pub PublicKey) string { func ThumbprintUntypedPublicKey(pub crypto.PublicKey) string { switch p := pub.(type) { case PublicKey: - return ThumbprintUntypedPublicKey(p.Key()) + return ThumbprintUntypedPublicKey(p.Key().(PublicKeyTransitional)) case *ecdsa.PublicKey: return ThumbprintECPublicKey(p) case *rsa.PublicKey: diff --git a/marshal.go b/marshal.go index 2198c5e..17aa373 100644 --- a/marshal.go +++ b/marshal.go @@ -1,7 +1,6 @@ package keypairs import ( - "crypto" "crypto/ecdsa" "crypto/rsa" "crypto/x509" @@ -14,7 +13,7 @@ import ( ) // MarshalPEMPublicKey outputs the given public key as JWK -func MarshalPEMPublicKey(pubkey crypto.PublicKey) ([]byte, error) { +func MarshalPEMPublicKey(pubkey PublicKeyTransitional) ([]byte, error) { block, err := marshalDERPublicKey(pubkey) if nil != err { return nil, err @@ -23,7 +22,7 @@ func MarshalPEMPublicKey(pubkey crypto.PublicKey) ([]byte, error) { } // MarshalDERPublicKey outputs the given public key as JWK -func MarshalDERPublicKey(pubkey crypto.PublicKey) ([]byte, error) { +func MarshalDERPublicKey(pubkey PublicKeyTransitional) ([]byte, error) { block, err := marshalDERPublicKey(pubkey) if nil != err { return nil, err @@ -32,7 +31,7 @@ func MarshalDERPublicKey(pubkey crypto.PublicKey) ([]byte, error) { } // marshalDERPublicKey outputs the given public key as JWK -func marshalDERPublicKey(pubkey crypto.PublicKey) (*pem.Block, error) { +func marshalDERPublicKey(pubkey PublicKeyTransitional) (*pem.Block, error) { var der []byte var typ string diff --git a/verify.go b/verify.go index f6dfae9..80f3216 100644 --- a/verify.go +++ b/verify.go @@ -140,7 +140,7 @@ func pubkeyCheck(pubkey PublicKey, kid string, opts *keyOptions, errs []error) ( } if nil != pub && "" != kid { - if 1 != subtle.ConstantTimeCompare([]byte(kid), []byte(pub.Thumbprint())) { + if 1 != subtle.ConstantTimeCompare([]byte(kid), []byte(Thumbprint(pub.Key().(PublicKeyTransitional)))) { err := errors.New("'kid' does not match the public key thumbprint") errs = append(errs, err) }