move to PublicKeyTransitional, toward go1.15 keypairs

This commit is contained in:
AJ ONeal 2020-10-21 01:53:40 -06:00
parent 2e3ead4102
commit 098f92178a
7 changed files with 74 additions and 54 deletions

View File

@ -363,9 +363,9 @@ func marshalPub(pub keypairs.PublicKey, pubname string) {
if strings.HasSuffix(pubname, ".json") { if strings.HasSuffix(pubname, ".json") {
b = indentJSON(keypairs.MarshalJWKPublicKey(pub)) b = indentJSON(keypairs.MarshalJWKPublicKey(pub))
} else if strings.HasSuffix(pubname, ".pem") { } else if strings.HasSuffix(pubname, ".pem") {
b, _ = keypairs.MarshalPEMPublicKey(pub) b, _ = keypairs.MarshalPEMPublicKey(pub.Key().(keypairs.PublicKeyTransitional))
} else if strings.HasSuffix(pubname, ".der") { } else if strings.HasSuffix(pubname, ".der") {
b, _ = keypairs.MarshalDERPublicKey(pub) b, _ = keypairs.MarshalDERPublicKey(pub.Key().(keypairs.PublicKeyTransitional))
} }
ioutil.WriteFile(pubname, b, 0644) ioutil.WriteFile(pubname, b, 0644)

View File

@ -25,8 +25,8 @@ import (
// TODO should be ErrInvalidJWKURL // TODO should be ErrInvalidJWKURL
// EInvalidJWKURL means that the url did not provide JWKs // ErrInvalidJWKURL means that the url did not provide JWKs
var EInvalidJWKURL = errors.New("url does not lead to valid JWKs") var ErrInvalidJWKURL = errors.New("url does not lead to valid JWKs")
// KeyCache is an in-memory key cache // KeyCache is an in-memory key cache
var KeyCache = map[string]CachableKey{} var KeyCache = map[string]CachableKey{}
@ -94,7 +94,7 @@ func OIDCJWKs(baseURL string) (PublicKeysMap, error) {
} }
// OIDCJWK fetches baseURL + ".well-known/openid-configuration" and then returns the key matching kid (or thumbprint) // OIDCJWK fetches baseURL + ".well-known/openid-configuration" and then returns the key matching kid (or thumbprint)
func OIDCJWK(kidOrThumb, iss string) (keypairs.PublicKey, error) { func OIDCJWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) {
return immediateOneOrFetch(kidOrThumb, iss, uncached.OIDCJWKs) return immediateOneOrFetch(kidOrThumb, iss, uncached.OIDCJWKs)
} }
@ -110,7 +110,7 @@ func WellKnownJWKs(kidOrThumb, iss string) (PublicKeysMap, error) {
} }
// WellKnownJWK fetches baseURL + ".well-known/jwks.json" and returns the key matching kid (or thumbprint) // WellKnownJWK fetches baseURL + ".well-known/jwks.json" and returns the key matching kid (or thumbprint)
func WellKnownJWK(kidOrThumb, iss string) (keypairs.PublicKey, error) { func WellKnownJWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) {
return immediateOneOrFetch(kidOrThumb, iss, uncached.WellKnownJWKs) return immediateOneOrFetch(kidOrThumb, iss, uncached.WellKnownJWKs)
} }
@ -128,12 +128,12 @@ func JWKs(jwksurl string) (PublicKeysMap, error) {
} }
// JWK tries to return a key from cache, falling back to the /.well-known/jwks.json of the issuer // JWK tries to return a key from cache, falling back to the /.well-known/jwks.json of the issuer
func JWK(kidOrThumb, iss string) (keypairs.PublicKey, error) { func JWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) {
return immediateOneOrFetch(kidOrThumb, iss, uncached.JWKs) return immediateOneOrFetch(kidOrThumb, iss, uncached.JWKs)
} }
// PEM tries to return a key from cache, falling back to the specified PEM url // PEM tries to return a key from cache, falling back to the specified PEM url
func PEM(url string) (keypairs.PublicKey, error) { func PEM(url string) (keypairs.PublicKeyTransitional, error) {
// url is kid in this case // url is kid in this case
return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) { return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
m, key, err := uncached.PEM(url) m, key, err := uncached.PEM(url)
@ -143,11 +143,11 @@ func PEM(url string) (keypairs.PublicKey, error) {
// put in a map, just for caching // put in a map, just for caching
maps := map[string]map[string]string{} maps := map[string]map[string]string{}
maps[key.Thumbprint()] = m maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m
maps[url] = m maps[url] = m
keys := map[string]keypairs.PublicKey{} keys := map[string]keypairs.PublicKey{}
keys[key.Thumbprint()] = key keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key
keys[url] = key keys[url] = key
return maps, keys, nil return maps, keys, nil
@ -155,7 +155,7 @@ func PEM(url string) (keypairs.PublicKey, error) {
} }
// Fetch returns a key from cache, falling back to an exact url as the "issuer" // Fetch returns a key from cache, falling back to an exact url as the "issuer"
func Fetch(url string) (keypairs.PublicKey, error) { func Fetch(url string) (keypairs.PublicKeyTransitional, error) {
// url is kid in this case // url is kid in this case
return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) { return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
m, key, err := uncached.Fetch(url) m, key, err := uncached.Fetch(url)
@ -165,10 +165,10 @@ func Fetch(url string) (keypairs.PublicKey, error) {
// put in a map, just for caching // put in a map, just for caching
maps := map[string]map[string]string{} maps := map[string]map[string]string{}
maps[key.Thumbprint()] = m maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m
keys := map[string]keypairs.PublicKey{} keys := map[string]keypairs.PublicKey{}
keys[key.Thumbprint()] = key keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key
return maps, keys, nil return maps, keys, nil
}) })
@ -210,7 +210,7 @@ func get(kidOrThumb, iss string) *CachableKey {
return nil return nil
} }
func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.PublicKey, error) { func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.PublicKeyTransitional, error) {
now := time.Now() now := time.Now()
key := get(kidOrThumb, iss) key := get(kidOrThumb, iss)
@ -223,12 +223,12 @@ func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.Pu
go fetchAndSelect(kidOrThumb, iss, fetcher) go fetchAndSelect(kidOrThumb, iss, fetcher)
} }
return key.Key, nil return key.Key.Key().(keypairs.PublicKeyTransitional), nil
} }
type myfetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) type myfetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error)
func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKey, error) { func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKeyTransitional, error) {
maps, keys, err := fetcher(baseURL) maps, keys, err := fetcher(baseURL)
if nil != err { if nil != err {
return nil, err return nil, err
@ -238,12 +238,12 @@ func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKey,
for i := range keys { for i := range keys {
key := keys[i] key := keys[i]
if id == key.Thumbprint() { if id == keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional)) {
return key, nil return key.Key().(keypairs.PublicKeyTransitional), nil
} }
if id == key.KeyID() { if id == key.KeyID() {
return key, nil return key.Key().(keypairs.PublicKeyTransitional), nil
} }
} }
@ -287,7 +287,7 @@ func cacheKey(kid, iss, expstr string, pub keypairs.PublicKey) error {
Expiry: expiry, Expiry: expiry,
} }
// Since thumbprints are crypto secure, iss isn't needed // Since thumbprints are crypto secure, iss isn't needed
thumb := pub.Thumbprint() thumb := keypairs.Thumbprint(pub.Key().(keypairs.PublicKeyTransitional))
KeyCache[thumb] = CachableKey{ KeyCache[thumb] = CachableKey{
Key: pub, Key: pub,
Expiry: expiry, Expiry: expiry,

View File

@ -8,19 +8,21 @@ import (
"git.rootprojects.org/root/keypairs/keyfetch/uncached" "git.rootprojects.org/root/keypairs/keyfetch/uncached"
) )
var pubkey keypairs.PublicKey var pubkey keypairs.PublicKeyTransitional
func TestCachesKey(t *testing.T) { func TestCachesKey(t *testing.T) {
testCachesKey(t, "https://bigsquid.auth0.com/") testCachesKey(t, "https://bigsquid.auth0.com/")
clear() clear()
testCachesKey(t, "https://bigsquid.auth0.com") testCachesKey(t, "https://bigsquid.auth0.com")
// Get PEM // Get PEM
k3, err := PEM("https://bigsquid.auth0.com/pem") pubk3, err := PEM("https://bigsquid.auth0.com/pem")
if nil != err { if nil != err {
t.Fatal("Error fetching and caching key:", err) t.Fatal("Error fetching and caching key:", err)
} }
if k3.Thumbprint() != pubkey.Thumbprint() { thumb3 := keypairs.Thumbprint(pubk3)
t.Fatal("Error got different thumbprint for different versions of the same key:", err) thumb := keypairs.Thumbprint(pubkey)
if thumb3 != thumb {
t.Fatalf("Error got different thumbprint for different versions of the same key %q != %q: %v", thumb3, thumb, err)
} }
clear() clear()
testCachesKey(t, "https://big-squid.github.io/") testCachesKey(t, "https://big-squid.github.io/")
@ -45,12 +47,12 @@ func testCachesKey(t *testing.T, url string) {
t.Fatal("Should discover 1 or more keys via", url) t.Fatal("Should discover 1 or more keys via", url)
} }
var key keypairs.PublicKey var key keypairs.PublicKeyTransitional
for i := range keys { for i := range keys {
key = keys[i] key = keys[i].Key().(keypairs.PublicKeyTransitional)
break break
} }
thumb := key.Thumbprint() thumb := keypairs.Thumbprint(key)
// Look in cache for each (and fail) // Look in cache for each (and fail)
if pub := Get(thumb, ""); nil != pub { if pub := Get(thumb, ""); nil != pub {
@ -67,10 +69,11 @@ func testCachesKey(t *testing.T, url string) {
if pub := Get(thumb, ""); nil == pub { if pub := Get(thumb, ""); nil == pub {
t.Fatal("key was not properly cached by thumbprint", thumb) t.Fatal("key was not properly cached by thumbprint", thumb)
} }
if "" != pubkey.KeyID() {
if pub := Get(pubkey.KeyID(), url); nil == pub { // TODO thumb / id mapping
t.Fatal("key was not properly cached by kid", pubkey.KeyID()) thumb = keypairs.Thumbprint(pubkey)
} if pub := Get(thumb, url); nil == pub {
t.Fatal("key was not properly cached by kid", pubkey)
} else { } else {
t.Log("Key did not have an explicit KeyID") t.Log("Key did not have an explicit KeyID")
} }
@ -86,8 +89,10 @@ func testCachesKey(t *testing.T, url string) {
} }
// Sanity check that the kid and thumb match // Sanity check that the kid and thumb match
if key.KeyID() != pubkey.KeyID() || key.Thumbprint() != pubkey.Thumbprint() { if !key.Equal(pubkey) || keypairs.Thumbprint(key) != keypairs.Thumbprint(pubkey) {
t.Fatal("SANITY: KeyIDs or Thumbprints do not match:", key.KeyID(), pubkey.KeyID(), key.Thumbprint(), pubkey.Thumbprint()) t.Fatalf("SANITY: [todo: KeyIDs or] Thumbprints do not match:\n%q != %q\n%q != %q",
keypairs.Thumbprint(key), keypairs.Thumbprint(pubkey),
keypairs.Thumbprint(key), keypairs.Thumbprint(pubkey))
} }
// Get 404 // Get 404

View File

@ -71,8 +71,8 @@ func JWKs(jwksurl string) (map[string]map[string]string, map[string]keypairs.Pub
if nil != err { if nil != err {
return nil, nil, err return nil, nil, err
} }
keys[key.Thumbprint()] = key keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key
maps[key.Thumbprint()] = m maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m
} }
return maps, keys, nil return maps, keys, nil

View File

@ -3,7 +3,6 @@ package keypairs
import ( import (
"bytes" "bytes"
"crypto" "crypto"
"crypto/dsa"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rsa" "crypto/rsa"
@ -57,12 +56,19 @@ const ErrDevBadKeyType = "[Developer Error] crypto.PublicKey and crypto.PrivateK
// PrivateKey is a zero-cost typesafe substitue for crypto.PrivateKey // PrivateKey is a zero-cost typesafe substitue for crypto.PrivateKey
type PrivateKey interface { type PrivateKey interface {
Public() crypto.PublicKey Public() crypto.PublicKey
Equal(x crypto.PrivateKey) bool
}
// PublicKeyTransitional is so that v0.7.x can use golang v1.15 keys
type PublicKeyTransitional interface {
Equal(x crypto.PublicKey) bool
} }
// PublicKey thinly veils crypto.PublicKey for type safety // PublicKey thinly veils crypto.PublicKey for type safety
type PublicKey interface { type PublicKey interface {
crypto.PublicKey crypto.PublicKey
Thumbprint() string //Equal(x crypto.PublicKey) bool
//Thumbprint() string
KeyID() string KeyID() string
Key() crypto.PublicKey Key() crypto.PublicKey
ExpiresAt() time.Time ExpiresAt() time.Time
@ -87,6 +93,11 @@ func (p *ECPublicKey) Thumbprint() string {
return ThumbprintUntypedPublicKey(p.PublicKey) return ThumbprintUntypedPublicKey(p.PublicKey)
} }
// Equal returns true if the public key is equal.
func (p *ECPublicKey) Equal(x crypto.PublicKey) bool {
return p.PublicKey.Equal(x)
}
// KeyID returns the JWK `kid`, which will be the Thumbprint for keys generated with this library // KeyID returns the JWK `kid`, which will be the Thumbprint for keys generated with this library
func (p *ECPublicKey) KeyID() string { func (p *ECPublicKey) KeyID() string {
return p.KID return p.KID
@ -112,6 +123,11 @@ func (p *RSAPublicKey) Thumbprint() string {
return ThumbprintUntypedPublicKey(p.PublicKey) return ThumbprintUntypedPublicKey(p.PublicKey)
} }
// Equal returns true if the public key is equal.
func (p *RSAPublicKey) Equal(x crypto.PublicKey) bool {
return p.PublicKey.Equal(x)
}
// KeyID returns the JWK `kid`, which will be the Thumbprint for keys generated with this library // KeyID returns the JWK `kid`, which will be the Thumbprint for keys generated with this library
func (p *RSAPublicKey) KeyID() string { func (p *RSAPublicKey) KeyID() string {
return p.KID return p.KID
@ -134,6 +150,11 @@ func (p *RSAPublicKey) ExpiresAt() time.Time {
// NewPublicKey wraps a crypto.PublicKey to make it typesafe. // NewPublicKey wraps a crypto.PublicKey to make it typesafe.
func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKey { func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKey {
_, ok := pub.(PublicKeyTransitional)
if !ok {
panic("Developer Error: not a crypto.PublicKey")
}
var k PublicKey var k PublicKey
switch p := pub.(type) { switch p := pub.(type) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
@ -156,14 +177,6 @@ func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKey {
rsakey.KID = ThumbprintRSAPublicKey(p) rsakey.KID = ThumbprintRSAPublicKey(p)
} }
k = rsakey k = rsakey
case *ecdsa.PrivateKey:
panic(errors.New(ErrDevSwapPrivatePublic))
case *rsa.PrivateKey:
panic(errors.New(ErrDevSwapPrivatePublic))
case *dsa.PublicKey:
panic(ErrInvalidPublicKey)
case *dsa.PrivateKey:
panic(ErrInvalidPrivateKey)
default: default:
panic(fmt.Errorf(ErrDevBadKeyType, pub)) panic(fmt.Errorf(ErrDevBadKeyType, pub))
} }
@ -180,8 +193,6 @@ func MarshalJWKPublicKey(key PublicKey, exp ...time.Time) []byte {
return MarshalRSAPublicKey(k, exp...) return MarshalRSAPublicKey(k, exp...)
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
return MarshalECPublicKey(k, exp...) return MarshalECPublicKey(k, exp...)
case *dsa.PublicKey:
panic(ErrInvalidPublicKey)
default: default:
// this is unreachable because we know the types that we pass in // this is unreachable because we know the types that we pass in
log.Printf("keytype: %t, %+v\n", key, key) log.Printf("keytype: %t, %+v\n", key, key)
@ -189,9 +200,14 @@ func MarshalJWKPublicKey(key PublicKey, exp ...time.Time) []byte {
} }
} }
// Thumbprint returns the SHA256 RFC-spec JWK thumbprint
func Thumbprint(pub PublicKeyTransitional) string {
return ThumbprintUntypedPublicKey(pub)
}
// ThumbprintPublicKey returns the SHA256 RFC-spec JWK thumbprint // ThumbprintPublicKey returns the SHA256 RFC-spec JWK thumbprint
func ThumbprintPublicKey(pub PublicKey) string { func ThumbprintPublicKey(pub PublicKey) string {
return ThumbprintUntypedPublicKey(pub.Key()) return ThumbprintUntypedPublicKey(pub.Key().(PublicKeyTransitional))
} }
// ThumbprintUntypedPublicKey is a non-typesafe version of ThumbprintPublicKey // ThumbprintUntypedPublicKey is a non-typesafe version of ThumbprintPublicKey
@ -199,7 +215,7 @@ func ThumbprintPublicKey(pub PublicKey) string {
func ThumbprintUntypedPublicKey(pub crypto.PublicKey) string { func ThumbprintUntypedPublicKey(pub crypto.PublicKey) string {
switch p := pub.(type) { switch p := pub.(type) {
case PublicKey: case PublicKey:
return ThumbprintUntypedPublicKey(p.Key()) return ThumbprintUntypedPublicKey(p.Key().(PublicKeyTransitional))
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
return ThumbprintECPublicKey(p) return ThumbprintECPublicKey(p)
case *rsa.PublicKey: case *rsa.PublicKey:

View File

@ -1,7 +1,6 @@
package keypairs package keypairs
import ( import (
"crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
@ -14,7 +13,7 @@ import (
) )
// MarshalPEMPublicKey outputs the given public key as JWK // MarshalPEMPublicKey outputs the given public key as JWK
func MarshalPEMPublicKey(pubkey crypto.PublicKey) ([]byte, error) { func MarshalPEMPublicKey(pubkey PublicKeyTransitional) ([]byte, error) {
block, err := marshalDERPublicKey(pubkey) block, err := marshalDERPublicKey(pubkey)
if nil != err { if nil != err {
return nil, err return nil, err
@ -23,7 +22,7 @@ func MarshalPEMPublicKey(pubkey crypto.PublicKey) ([]byte, error) {
} }
// MarshalDERPublicKey outputs the given public key as JWK // MarshalDERPublicKey outputs the given public key as JWK
func MarshalDERPublicKey(pubkey crypto.PublicKey) ([]byte, error) { func MarshalDERPublicKey(pubkey PublicKeyTransitional) ([]byte, error) {
block, err := marshalDERPublicKey(pubkey) block, err := marshalDERPublicKey(pubkey)
if nil != err { if nil != err {
return nil, err return nil, err
@ -32,7 +31,7 @@ func MarshalDERPublicKey(pubkey crypto.PublicKey) ([]byte, error) {
} }
// marshalDERPublicKey outputs the given public key as JWK // marshalDERPublicKey outputs the given public key as JWK
func marshalDERPublicKey(pubkey crypto.PublicKey) (*pem.Block, error) { func marshalDERPublicKey(pubkey PublicKeyTransitional) (*pem.Block, error) {
var der []byte var der []byte
var typ string var typ string

View File

@ -140,7 +140,7 @@ func pubkeyCheck(pubkey PublicKey, kid string, opts *keyOptions, errs []error) (
} }
if nil != pub && "" != kid { if nil != pub && "" != kid {
if 1 != subtle.ConstantTimeCompare([]byte(kid), []byte(pub.Thumbprint())) { if 1 != subtle.ConstantTimeCompare([]byte(kid), []byte(Thumbprint(pub.Key().(PublicKeyTransitional)))) {
err := errors.New("'kid' does not match the public key thumbprint") err := errors.New("'kid' does not match the public key thumbprint")
errs = append(errs, err) errs = append(errs, err)
} }