From 286cf110701362fd299f46a5a9cbd7f2750ae568 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Wed, 21 Oct 2020 03:38:05 -0600 Subject: [PATCH] transitioning to go1.15 PublicKey interface --- cmd/keypairs/keypairs.go | 12 ++++++------ keyfetch/fetch.go | 33 +++++++++++++++++---------------- keyfetch/fetch_test.go | 6 +++--- keyfetch/uncached/fetch.go | 8 ++++---- keypairs.go | 20 ++++++++++---------- keyserve/keyserve.go | 3 ++- marshal.go | 6 +++--- sign.go | 4 ++-- verify.go | 16 ++++++++-------- 9 files changed, 55 insertions(+), 53 deletions(-) diff --git a/cmd/keypairs/keypairs.go b/cmd/keypairs/keypairs.go index 92cb915..77ef327 100644 --- a/cmd/keypairs/keypairs.go +++ b/cmd/keypairs/keypairs.go @@ -105,7 +105,7 @@ func gen(args []string) { key := keypairs.NewDefaultPrivateKey() marshalPriv(key, keyname) - pub := key.Public().(keypairs.PublicKeyTransitional) + pub := key.Public().(keypairs.PublicKey) marshalPub(pub, pubname) } @@ -255,8 +255,8 @@ func readKey(keyname string) (keypairs.PrivateKey, error) { return key, nil } -func readPub(pubname string) (keypairs.PublicKeyTransitional, error) { - var pub keypairs.PublicKeyTransitional = nil +func readPub(pubname string) (keypairs.PublicKey, error) { + var pub keypairs.PublicKey = nil // Read as file b, err := ioutil.ReadFile(pubname) @@ -269,7 +269,7 @@ func readPub(pubname string) (keypairs.PublicKeyTransitional, error) { pubname, err, ) } - pub = pub2.Key().(keypairs.PublicKeyTransitional) + pub = pub2.Key() } // Oh, it was a file. @@ -281,7 +281,7 @@ func readPub(pubname string) (keypairs.PublicKeyTransitional, error) { pubname, err3, ) } - pub = pub3.Key().(keypairs.PublicKeyTransitional) + pub = pub3.Key() } return pub, nil @@ -351,7 +351,7 @@ func marshalPriv(key keypairs.PrivateKey, keyname string) { ioutil.WriteFile(keyname, b, 0600) } -func marshalPub(pub keypairs.PublicKeyTransitional, pubname string) { +func marshalPub(pub keypairs.PublicKey, pubname string) { var b []byte if "" == pubname { b = indentJSON(keypairs.MarshalJWKPublicKey(pub)) diff --git a/keyfetch/fetch.go b/keyfetch/fetch.go index 9f16df8..43ef370 100644 --- a/keyfetch/fetch.go +++ b/keyfetch/fetch.go @@ -94,7 +94,7 @@ func OIDCJWKs(baseURL string) (PublicKeysMap, error) { } // OIDCJWK fetches baseURL + ".well-known/openid-configuration" and then returns the key matching kid (or thumbprint) -func OIDCJWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { +func OIDCJWK(kidOrThumb, iss string) (keypairs.PublicKey, error) { return immediateOneOrFetch(kidOrThumb, iss, uncached.OIDCJWKs) } @@ -110,7 +110,7 @@ func WellKnownJWKs(kidOrThumb, iss string) (PublicKeysMap, error) { } // WellKnownJWK fetches baseURL + ".well-known/jwks.json" and returns the key matching kid (or thumbprint) -func WellKnownJWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { +func WellKnownJWK(kidOrThumb, iss string) (keypairs.PublicKey, error) { return immediateOneOrFetch(kidOrThumb, iss, uncached.WellKnownJWKs) } @@ -128,12 +128,12 @@ func JWKs(jwksurl string) (PublicKeysMap, error) { } // JWK tries to return a key from cache, falling back to the /.well-known/jwks.json of the issuer -func JWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { +func JWK(kidOrThumb, iss string) (keypairs.PublicKey, error) { return immediateOneOrFetch(kidOrThumb, iss, uncached.JWKs) } // PEM tries to return a key from cache, falling back to the specified PEM url -func PEM(url string) (keypairs.PublicKeyTransitional, error) { +func PEM(url string) (keypairs.PublicKey, error) { // url is kid in this case return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKeyDeprecated, error) { m, key, err := uncached.PEM(url) @@ -166,7 +166,7 @@ func PEM(url string) (keypairs.PublicKeyTransitional, error) { } // Fetch returns a key from cache, falling back to an exact url as the "issuer" -func Fetch(url string) (keypairs.PublicKeyTransitional, error) { +func Fetch(url string) (keypairs.PublicKey, error) { // url is kid in this case return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKeyDeprecated, error) { @@ -177,10 +177,10 @@ func Fetch(url string) (keypairs.PublicKeyTransitional, error) { // put in a map, just for caching maps := map[string]map[string]string{} - maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m + maps[keypairs.Thumbprint(key.Key())] = m keys := map[string]keypairs.PublicKeyDeprecated{} - keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key + keys[keypairs.Thumbprint(key.Key())] = key return maps, keys, nil }) @@ -188,9 +188,9 @@ func Fetch(url string) (keypairs.PublicKeyTransitional, error) { // Get retrieves a key from cache, or returns an error. // The issuer string may be empty if using a thumbprint rather than a kid. -func Get(kidOrThumb, iss string) keypairs.PublicKeyTransitional { +func Get(kidOrThumb, iss string) keypairs.PublicKey { if pub := get(kidOrThumb, iss); nil != pub { - return pub.Key.Key().(keypairs.PublicKeyTransitional) + return pub.Key.Key() } return nil } @@ -222,7 +222,7 @@ func get(kidOrThumb, iss string) *CachableKey { return nil } -func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.PublicKeyTransitional, error) { +func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.PublicKey, error) { now := time.Now() hit := get(kidOrThumb, iss) @@ -235,12 +235,12 @@ func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.Pu go fetchAndSelect(kidOrThumb, iss, fetcher) } - return hit.Key.Key().(keypairs.PublicKeyTransitional), nil + return hit.Key.Key(), nil } type myfetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKeyDeprecated, error) -func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKeyTransitional, error) { +func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKey, error) { maps, keys, err := fetcher(baseURL) if nil != err { return nil, err @@ -249,13 +249,14 @@ func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKeyTr for i := range keys { key := keys[i] + pub := key.Key() - if id == keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional)) { - return key.Key().(keypairs.PublicKeyTransitional), nil + if id == keypairs.Thumbprint(pub) { + return pub, nil } if id == key.KeyID() { - return key.Key().(keypairs.PublicKeyTransitional), nil + return pub, nil } } @@ -302,7 +303,7 @@ func cacheKey(kid, iss, expstr string, pub keypairs.PublicKeyDeprecated) error { Expiry: expiry, } // Since thumbprints are crypto secure, iss isn't needed - thumb := keypairs.Thumbprint(pub.Key().(keypairs.PublicKeyTransitional)) + thumb := keypairs.Thumbprint(pub.Key()) KeyCache[thumb] = CachableKey{ Key: pub, Expiry: expiry, diff --git a/keyfetch/fetch_test.go b/keyfetch/fetch_test.go index 78df9fb..3d7f330 100644 --- a/keyfetch/fetch_test.go +++ b/keyfetch/fetch_test.go @@ -8,7 +8,7 @@ import ( "git.rootprojects.org/root/keypairs/keyfetch/uncached" ) -var pubkey keypairs.PublicKeyTransitional +var pubkey keypairs.PublicKey func TestCachesKey(t *testing.T) { // TODO set KeyID() in cache @@ -48,9 +48,9 @@ func testCachesKey(t *testing.T, url string) { t.Fatal("Should discover 1 or more keys via", url) } - var key keypairs.PublicKeyTransitional + var key keypairs.PublicKey for i := range keys { - key = keys[i].Key().(keypairs.PublicKeyTransitional) + key = keys[i].Key() break } thumb := keypairs.Thumbprint(key) diff --git a/keyfetch/uncached/fetch.go b/keyfetch/uncached/fetch.go index e37f9f9..d9db54c 100644 --- a/keyfetch/uncached/fetch.go +++ b/keyfetch/uncached/fetch.go @@ -82,15 +82,15 @@ func JWKs(jwksurl string) (JWKMapByID, PublicKeysMap, error) { if nil != err { return nil, nil, err } - keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key - maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m + keys[keypairs.Thumbprint(key.Key())] = key + maps[keypairs.Thumbprint(key.Key())] = m } return maps, keys, nil } // PEM fetches and parses a PEM (assuming well-known format) -func PEM(pemurl string) (map[string]string, keypairs.PublicKeyTransitional, error) { +func PEM(pemurl string) (map[string]string, keypairs.PublicKey, error) { var pubd keypairs.PublicKeyDeprecated if err := safeFetch(pemurl, func(body io.Reader) error { pem, err := ioutil.ReadAll(body) @@ -107,7 +107,7 @@ func PEM(pemurl string) (map[string]string, keypairs.PublicKeyTransitional, erro } jwk := map[string]interface{}{} - pub := pubd.Key().(keypairs.PublicKeyTransitional) + pub := pubd.Key() body := bytes.NewBuffer(keypairs.MarshalJWKPublicKey(pub)) decoder := json.NewDecoder(body) decoder.UseNumber() diff --git a/keypairs.go b/keypairs.go index e8b8879..d313d1a 100644 --- a/keypairs.go +++ b/keypairs.go @@ -59,8 +59,8 @@ type PrivateKey interface { Equal(x crypto.PrivateKey) bool } -// PublicKeyTransitional is so that v0.7.x can use golang v1.15 keys -type PublicKeyTransitional interface { +// PublicKey is so that v0.7.x can use golang v1.15 keys +type PublicKey interface { Equal(x crypto.PublicKey) bool } @@ -70,7 +70,7 @@ type PublicKeyDeprecated interface { //Equal(x crypto.PublicKey) bool //Thumbprint() string KeyID() string - Key() crypto.PublicKey + Key() PublicKey ExpiresAt() time.Time } @@ -104,7 +104,7 @@ func (p *ECPublicKey) KeyID() string { } // Key returns the PublicKey -func (p *ECPublicKey) Key() crypto.PublicKey { +func (p *ECPublicKey) Key() PublicKey { return p.PublicKey } @@ -134,7 +134,7 @@ func (p *RSAPublicKey) KeyID() string { } // Key returns the PublicKey -func (p *RSAPublicKey) Key() crypto.PublicKey { +func (p *RSAPublicKey) Key() PublicKey { return p.PublicKey } @@ -150,7 +150,7 @@ func (p *RSAPublicKey) ExpiresAt() time.Time { // NewPublicKey wraps a crypto.PublicKey to make it typesafe. func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKeyDeprecated { - _, ok := pub.(PublicKeyTransitional) + _, ok := pub.(PublicKey) if !ok { panic("Developer Error: not a crypto.PublicKey") } @@ -186,7 +186,7 @@ func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKeyDeprecated { // MarshalJWKPublicKey outputs a JWK with its key id (kid) and an optional expiration, // making it suitable for use as an OIDC public key. -func MarshalJWKPublicKey(key PublicKeyTransitional, exp ...time.Time) []byte { +func MarshalJWKPublicKey(key PublicKey, exp ...time.Time) []byte { // thumbprint keys are alphabetically sorted and only include the necessary public parts switch k := key.(type) { case *rsa.PublicKey: @@ -201,13 +201,13 @@ func MarshalJWKPublicKey(key PublicKeyTransitional, exp ...time.Time) []byte { } // Thumbprint returns the SHA256 RFC-spec JWK thumbprint -func Thumbprint(pub PublicKeyTransitional) string { +func Thumbprint(pub PublicKey) string { return ThumbprintUntypedPublicKey(pub) } // ThumbprintPublicKey returns the SHA256 RFC-spec JWK thumbprint func ThumbprintPublicKey(pub PublicKeyDeprecated) string { - return ThumbprintUntypedPublicKey(pub.Key().(PublicKeyTransitional)) + return ThumbprintUntypedPublicKey(pub.Key()) } // ThumbprintUntypedPublicKey is a non-typesafe version of ThumbprintPublicKey @@ -215,7 +215,7 @@ func ThumbprintPublicKey(pub PublicKeyDeprecated) string { func ThumbprintUntypedPublicKey(pub crypto.PublicKey) string { switch p := pub.(type) { case PublicKeyDeprecated: - return ThumbprintUntypedPublicKey(p.Key().(PublicKeyTransitional)) + return ThumbprintUntypedPublicKey(p.Key()) case *ecdsa.PublicKey: return ThumbprintECPublicKey(p) case *rsa.PublicKey: diff --git a/keyserve/keyserve.go b/keyserve/keyserve.go index 5059c7f..040954b 100644 --- a/keyserve/keyserve.go +++ b/keyserve/keyserve.go @@ -163,7 +163,8 @@ func marshalJWKs(keys []keypairs.PublicKeyDeprecated, exp2 time.Time) []string { // Note that you don't have to embed `iss` in the JWK because the client // already has that info by virtue of getting to it in the first place. - jwk := string(keypairs.MarshalJWKPublicKey(key.Key().(keypairs.PublicKeyTransitional), exp)) + pub := key.Key() + jwk := string(keypairs.MarshalJWKPublicKey(pub, exp)) jwks = append(jwks, jwk) } diff --git a/marshal.go b/marshal.go index 17aa373..bf8cbf2 100644 --- a/marshal.go +++ b/marshal.go @@ -13,7 +13,7 @@ import ( ) // MarshalPEMPublicKey outputs the given public key as JWK -func MarshalPEMPublicKey(pubkey PublicKeyTransitional) ([]byte, error) { +func MarshalPEMPublicKey(pubkey PublicKey) ([]byte, error) { block, err := marshalDERPublicKey(pubkey) if nil != err { return nil, err @@ -22,7 +22,7 @@ func MarshalPEMPublicKey(pubkey PublicKeyTransitional) ([]byte, error) { } // MarshalDERPublicKey outputs the given public key as JWK -func MarshalDERPublicKey(pubkey PublicKeyTransitional) ([]byte, error) { +func MarshalDERPublicKey(pubkey PublicKey) ([]byte, error) { block, err := marshalDERPublicKey(pubkey) if nil != err { return nil, err @@ -31,7 +31,7 @@ func MarshalDERPublicKey(pubkey PublicKeyTransitional) ([]byte, error) { } // marshalDERPublicKey outputs the given public key as JWK -func marshalDERPublicKey(pubkey PublicKeyTransitional) (*pem.Block, error) { +func marshalDERPublicKey(pubkey PublicKey) (*pem.Block, error) { var der []byte var typ string diff --git a/sign.go b/sign.go index 5758b2c..81cfacc 100644 --- a/sign.go +++ b/sign.go @@ -26,7 +26,7 @@ func SignClaims(privkey PrivateKey, header Object, claims Object) (*JWS, error) //delete(header, "_seed") } - protected, header, err := headerToProtected(privkey.Public().(PublicKeyTransitional), header) + protected, header, err := headerToProtected(privkey.Public().(PublicKey), header) if nil != err { return nil, err } @@ -56,7 +56,7 @@ func SignClaims(privkey PrivateKey, header Object, claims Object) (*JWS, error) }, nil } -func headerToProtected(pub PublicKeyTransitional, header Object) ([]byte, Object, error) { +func headerToProtected(pub PublicKey, header Object) ([]byte, Object, error) { if nil == header { header = Object{} } diff --git a/verify.go b/verify.go index 81c5ec6..c6c6bc9 100644 --- a/verify.go +++ b/verify.go @@ -15,7 +15,7 @@ import ( ) // VerifyClaims will check the signature of a parsed JWT -func VerifyClaims(pubkey PublicKeyTransitional, jws *JWS) (errs []error) { +func VerifyClaims(pubkey PublicKey, jws *JWS) (errs []error) { kid, _ := jws.Header["kid"].(string) jwkmap, hasJWK := jws.Header["jwk"].(Object) //var jwk JWK = nil @@ -27,7 +27,7 @@ func VerifyClaims(pubkey PublicKeyTransitional, jws *JWS) (errs []error) { seed = int64(seedf64) } - var pub PublicKeyTransitional = nil + var pub PublicKey = nil if hasJWK { pub, errs = selfsignCheck(jwkmap, errs) } else { @@ -72,7 +72,7 @@ func VerifyClaims(pubkey PublicKeyTransitional, jws *JWS) (errs []error) { return errs } -func selfsignCheck(jwkmap Object, errs []error) (PublicKeyTransitional, []error) { +func selfsignCheck(jwkmap Object, errs []error) (PublicKey, []error) { var pub PublicKeyDeprecated = nil log.Println("Security TODO: did not check jws.Claims[\"sub\"] against 'jwk'") log.Println("Security TODO: did not check jws.Claims[\"iss\"]") @@ -104,11 +104,11 @@ func selfsignCheck(jwkmap Object, errs []error) (PublicKeyTransitional, []error) } } - return pub.Key().(PublicKeyTransitional), errs + return pub.Key(), errs } -func pubkeyCheck(pubkey PublicKeyTransitional, kid string, opts *keyOptions, errs []error) (PublicKeyTransitional, []error) { - var pub PublicKeyTransitional = nil +func pubkeyCheck(pubkey PublicKey, kid string, opts *keyOptions, errs []error) (PublicKey, []error) { + var pub PublicKey = nil if "" == kid { err := errors.New("token should have 'kid' or 'jwk' in header to identify the public key") @@ -130,7 +130,7 @@ func pubkeyCheck(pubkey PublicKeyTransitional, kid string, opts *keyOptions, err return nil, errs } privkey := newPrivateKey(opts) - pub = privkey.Public().(PublicKeyTransitional) + pub = privkey.Public().(PublicKey) return pub, errs } err := errors.New("no matching public key") @@ -149,7 +149,7 @@ func pubkeyCheck(pubkey PublicKeyTransitional, kid string, opts *keyOptions, err } // Verify will check the signature of a hash -func Verify(pubkey PublicKeyTransitional, hash []byte, sig []byte) bool { +func Verify(pubkey PublicKey, hash []byte, sig []byte) bool { switch pub := pubkey.(type) { case *rsa.PublicKey: