transitioning to go1.15 PublicKey interface

This commit is contained in:
AJ ONeal 2020-10-21 03:38:05 -06:00
parent 743731c804
commit 286cf11070
9 changed files with 55 additions and 53 deletions

View File

@ -105,7 +105,7 @@ func gen(args []string) {
key := keypairs.NewDefaultPrivateKey() key := keypairs.NewDefaultPrivateKey()
marshalPriv(key, keyname) marshalPriv(key, keyname)
pub := key.Public().(keypairs.PublicKeyTransitional) pub := key.Public().(keypairs.PublicKey)
marshalPub(pub, pubname) marshalPub(pub, pubname)
} }
@ -255,8 +255,8 @@ func readKey(keyname string) (keypairs.PrivateKey, error) {
return key, nil return key, nil
} }
func readPub(pubname string) (keypairs.PublicKeyTransitional, error) { func readPub(pubname string) (keypairs.PublicKey, error) {
var pub keypairs.PublicKeyTransitional = nil var pub keypairs.PublicKey = nil
// Read as file // Read as file
b, err := ioutil.ReadFile(pubname) b, err := ioutil.ReadFile(pubname)
@ -269,7 +269,7 @@ func readPub(pubname string) (keypairs.PublicKeyTransitional, error) {
pubname, err, pubname, err,
) )
} }
pub = pub2.Key().(keypairs.PublicKeyTransitional) pub = pub2.Key()
} }
// Oh, it was a file. // Oh, it was a file.
@ -281,7 +281,7 @@ func readPub(pubname string) (keypairs.PublicKeyTransitional, error) {
pubname, err3, pubname, err3,
) )
} }
pub = pub3.Key().(keypairs.PublicKeyTransitional) pub = pub3.Key()
} }
return pub, nil return pub, nil
@ -351,7 +351,7 @@ func marshalPriv(key keypairs.PrivateKey, keyname string) {
ioutil.WriteFile(keyname, b, 0600) ioutil.WriteFile(keyname, b, 0600)
} }
func marshalPub(pub keypairs.PublicKeyTransitional, pubname string) { func marshalPub(pub keypairs.PublicKey, pubname string) {
var b []byte var b []byte
if "" == pubname { if "" == pubname {
b = indentJSON(keypairs.MarshalJWKPublicKey(pub)) b = indentJSON(keypairs.MarshalJWKPublicKey(pub))

View File

@ -94,7 +94,7 @@ func OIDCJWKs(baseURL string) (PublicKeysMap, error) {
} }
// OIDCJWK fetches baseURL + ".well-known/openid-configuration" and then returns the key matching kid (or thumbprint) // OIDCJWK fetches baseURL + ".well-known/openid-configuration" and then returns the key matching kid (or thumbprint)
func OIDCJWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { func OIDCJWK(kidOrThumb, iss string) (keypairs.PublicKey, error) {
return immediateOneOrFetch(kidOrThumb, iss, uncached.OIDCJWKs) return immediateOneOrFetch(kidOrThumb, iss, uncached.OIDCJWKs)
} }
@ -110,7 +110,7 @@ func WellKnownJWKs(kidOrThumb, iss string) (PublicKeysMap, error) {
} }
// WellKnownJWK fetches baseURL + ".well-known/jwks.json" and returns the key matching kid (or thumbprint) // WellKnownJWK fetches baseURL + ".well-known/jwks.json" and returns the key matching kid (or thumbprint)
func WellKnownJWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { func WellKnownJWK(kidOrThumb, iss string) (keypairs.PublicKey, error) {
return immediateOneOrFetch(kidOrThumb, iss, uncached.WellKnownJWKs) return immediateOneOrFetch(kidOrThumb, iss, uncached.WellKnownJWKs)
} }
@ -128,12 +128,12 @@ func JWKs(jwksurl string) (PublicKeysMap, error) {
} }
// JWK tries to return a key from cache, falling back to the /.well-known/jwks.json of the issuer // JWK tries to return a key from cache, falling back to the /.well-known/jwks.json of the issuer
func JWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { func JWK(kidOrThumb, iss string) (keypairs.PublicKey, error) {
return immediateOneOrFetch(kidOrThumb, iss, uncached.JWKs) return immediateOneOrFetch(kidOrThumb, iss, uncached.JWKs)
} }
// PEM tries to return a key from cache, falling back to the specified PEM url // PEM tries to return a key from cache, falling back to the specified PEM url
func PEM(url string) (keypairs.PublicKeyTransitional, error) { func PEM(url string) (keypairs.PublicKey, error) {
// url is kid in this case // url is kid in this case
return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKeyDeprecated, error) { return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKeyDeprecated, error) {
m, key, err := uncached.PEM(url) m, key, err := uncached.PEM(url)
@ -166,7 +166,7 @@ func PEM(url string) (keypairs.PublicKeyTransitional, error) {
} }
// Fetch returns a key from cache, falling back to an exact url as the "issuer" // Fetch returns a key from cache, falling back to an exact url as the "issuer"
func Fetch(url string) (keypairs.PublicKeyTransitional, error) { func Fetch(url string) (keypairs.PublicKey, error) {
// url is kid in this case // url is kid in this case
return immediateOneOrFetch(url, url, return immediateOneOrFetch(url, url,
func(string) (map[string]map[string]string, map[string]keypairs.PublicKeyDeprecated, error) { func(string) (map[string]map[string]string, map[string]keypairs.PublicKeyDeprecated, error) {
@ -177,10 +177,10 @@ func Fetch(url string) (keypairs.PublicKeyTransitional, error) {
// put in a map, just for caching // put in a map, just for caching
maps := map[string]map[string]string{} maps := map[string]map[string]string{}
maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m maps[keypairs.Thumbprint(key.Key())] = m
keys := map[string]keypairs.PublicKeyDeprecated{} keys := map[string]keypairs.PublicKeyDeprecated{}
keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key keys[keypairs.Thumbprint(key.Key())] = key
return maps, keys, nil return maps, keys, nil
}) })
@ -188,9 +188,9 @@ func Fetch(url string) (keypairs.PublicKeyTransitional, error) {
// Get retrieves a key from cache, or returns an error. // Get retrieves a key from cache, or returns an error.
// The issuer string may be empty if using a thumbprint rather than a kid. // The issuer string may be empty if using a thumbprint rather than a kid.
func Get(kidOrThumb, iss string) keypairs.PublicKeyTransitional { func Get(kidOrThumb, iss string) keypairs.PublicKey {
if pub := get(kidOrThumb, iss); nil != pub { if pub := get(kidOrThumb, iss); nil != pub {
return pub.Key.Key().(keypairs.PublicKeyTransitional) return pub.Key.Key()
} }
return nil return nil
} }
@ -222,7 +222,7 @@ func get(kidOrThumb, iss string) *CachableKey {
return nil return nil
} }
func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.PublicKeyTransitional, error) { func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.PublicKey, error) {
now := time.Now() now := time.Now()
hit := get(kidOrThumb, iss) hit := get(kidOrThumb, iss)
@ -235,12 +235,12 @@ func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.Pu
go fetchAndSelect(kidOrThumb, iss, fetcher) go fetchAndSelect(kidOrThumb, iss, fetcher)
} }
return hit.Key.Key().(keypairs.PublicKeyTransitional), nil return hit.Key.Key(), nil
} }
type myfetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKeyDeprecated, error) type myfetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKeyDeprecated, error)
func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKeyTransitional, error) { func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKey, error) {
maps, keys, err := fetcher(baseURL) maps, keys, err := fetcher(baseURL)
if nil != err { if nil != err {
return nil, err return nil, err
@ -249,13 +249,14 @@ func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKeyTr
for i := range keys { for i := range keys {
key := keys[i] key := keys[i]
pub := key.Key()
if id == keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional)) { if id == keypairs.Thumbprint(pub) {
return key.Key().(keypairs.PublicKeyTransitional), nil return pub, nil
} }
if id == key.KeyID() { if id == key.KeyID() {
return key.Key().(keypairs.PublicKeyTransitional), nil return pub, nil
} }
} }
@ -302,7 +303,7 @@ func cacheKey(kid, iss, expstr string, pub keypairs.PublicKeyDeprecated) error {
Expiry: expiry, Expiry: expiry,
} }
// Since thumbprints are crypto secure, iss isn't needed // Since thumbprints are crypto secure, iss isn't needed
thumb := keypairs.Thumbprint(pub.Key().(keypairs.PublicKeyTransitional)) thumb := keypairs.Thumbprint(pub.Key())
KeyCache[thumb] = CachableKey{ KeyCache[thumb] = CachableKey{
Key: pub, Key: pub,
Expiry: expiry, Expiry: expiry,

View File

@ -8,7 +8,7 @@ import (
"git.rootprojects.org/root/keypairs/keyfetch/uncached" "git.rootprojects.org/root/keypairs/keyfetch/uncached"
) )
var pubkey keypairs.PublicKeyTransitional var pubkey keypairs.PublicKey
func TestCachesKey(t *testing.T) { func TestCachesKey(t *testing.T) {
// TODO set KeyID() in cache // TODO set KeyID() in cache
@ -48,9 +48,9 @@ func testCachesKey(t *testing.T, url string) {
t.Fatal("Should discover 1 or more keys via", url) t.Fatal("Should discover 1 or more keys via", url)
} }
var key keypairs.PublicKeyTransitional var key keypairs.PublicKey
for i := range keys { for i := range keys {
key = keys[i].Key().(keypairs.PublicKeyTransitional) key = keys[i].Key()
break break
} }
thumb := keypairs.Thumbprint(key) thumb := keypairs.Thumbprint(key)

View File

@ -82,15 +82,15 @@ func JWKs(jwksurl string) (JWKMapByID, PublicKeysMap, error) {
if nil != err { if nil != err {
return nil, nil, err return nil, nil, err
} }
keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key keys[keypairs.Thumbprint(key.Key())] = key
maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m maps[keypairs.Thumbprint(key.Key())] = m
} }
return maps, keys, nil return maps, keys, nil
} }
// PEM fetches and parses a PEM (assuming well-known format) // PEM fetches and parses a PEM (assuming well-known format)
func PEM(pemurl string) (map[string]string, keypairs.PublicKeyTransitional, error) { func PEM(pemurl string) (map[string]string, keypairs.PublicKey, error) {
var pubd keypairs.PublicKeyDeprecated var pubd keypairs.PublicKeyDeprecated
if err := safeFetch(pemurl, func(body io.Reader) error { if err := safeFetch(pemurl, func(body io.Reader) error {
pem, err := ioutil.ReadAll(body) pem, err := ioutil.ReadAll(body)
@ -107,7 +107,7 @@ func PEM(pemurl string) (map[string]string, keypairs.PublicKeyTransitional, erro
} }
jwk := map[string]interface{}{} jwk := map[string]interface{}{}
pub := pubd.Key().(keypairs.PublicKeyTransitional) pub := pubd.Key()
body := bytes.NewBuffer(keypairs.MarshalJWKPublicKey(pub)) body := bytes.NewBuffer(keypairs.MarshalJWKPublicKey(pub))
decoder := json.NewDecoder(body) decoder := json.NewDecoder(body)
decoder.UseNumber() decoder.UseNumber()

View File

@ -59,8 +59,8 @@ type PrivateKey interface {
Equal(x crypto.PrivateKey) bool Equal(x crypto.PrivateKey) bool
} }
// PublicKeyTransitional is so that v0.7.x can use golang v1.15 keys // PublicKey is so that v0.7.x can use golang v1.15 keys
type PublicKeyTransitional interface { type PublicKey interface {
Equal(x crypto.PublicKey) bool Equal(x crypto.PublicKey) bool
} }
@ -70,7 +70,7 @@ type PublicKeyDeprecated interface {
//Equal(x crypto.PublicKey) bool //Equal(x crypto.PublicKey) bool
//Thumbprint() string //Thumbprint() string
KeyID() string KeyID() string
Key() crypto.PublicKey Key() PublicKey
ExpiresAt() time.Time ExpiresAt() time.Time
} }
@ -104,7 +104,7 @@ func (p *ECPublicKey) KeyID() string {
} }
// Key returns the PublicKey // Key returns the PublicKey
func (p *ECPublicKey) Key() crypto.PublicKey { func (p *ECPublicKey) Key() PublicKey {
return p.PublicKey return p.PublicKey
} }
@ -134,7 +134,7 @@ func (p *RSAPublicKey) KeyID() string {
} }
// Key returns the PublicKey // Key returns the PublicKey
func (p *RSAPublicKey) Key() crypto.PublicKey { func (p *RSAPublicKey) Key() PublicKey {
return p.PublicKey return p.PublicKey
} }
@ -150,7 +150,7 @@ func (p *RSAPublicKey) ExpiresAt() time.Time {
// NewPublicKey wraps a crypto.PublicKey to make it typesafe. // NewPublicKey wraps a crypto.PublicKey to make it typesafe.
func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKeyDeprecated { func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKeyDeprecated {
_, ok := pub.(PublicKeyTransitional) _, ok := pub.(PublicKey)
if !ok { if !ok {
panic("Developer Error: not a crypto.PublicKey") panic("Developer Error: not a crypto.PublicKey")
} }
@ -186,7 +186,7 @@ func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKeyDeprecated {
// MarshalJWKPublicKey outputs a JWK with its key id (kid) and an optional expiration, // MarshalJWKPublicKey outputs a JWK with its key id (kid) and an optional expiration,
// making it suitable for use as an OIDC public key. // making it suitable for use as an OIDC public key.
func MarshalJWKPublicKey(key PublicKeyTransitional, exp ...time.Time) []byte { func MarshalJWKPublicKey(key PublicKey, exp ...time.Time) []byte {
// thumbprint keys are alphabetically sorted and only include the necessary public parts // thumbprint keys are alphabetically sorted and only include the necessary public parts
switch k := key.(type) { switch k := key.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
@ -201,13 +201,13 @@ func MarshalJWKPublicKey(key PublicKeyTransitional, exp ...time.Time) []byte {
} }
// Thumbprint returns the SHA256 RFC-spec JWK thumbprint // Thumbprint returns the SHA256 RFC-spec JWK thumbprint
func Thumbprint(pub PublicKeyTransitional) string { func Thumbprint(pub PublicKey) string {
return ThumbprintUntypedPublicKey(pub) return ThumbprintUntypedPublicKey(pub)
} }
// ThumbprintPublicKey returns the SHA256 RFC-spec JWK thumbprint // ThumbprintPublicKey returns the SHA256 RFC-spec JWK thumbprint
func ThumbprintPublicKey(pub PublicKeyDeprecated) string { func ThumbprintPublicKey(pub PublicKeyDeprecated) string {
return ThumbprintUntypedPublicKey(pub.Key().(PublicKeyTransitional)) return ThumbprintUntypedPublicKey(pub.Key())
} }
// ThumbprintUntypedPublicKey is a non-typesafe version of ThumbprintPublicKey // ThumbprintUntypedPublicKey is a non-typesafe version of ThumbprintPublicKey
@ -215,7 +215,7 @@ func ThumbprintPublicKey(pub PublicKeyDeprecated) string {
func ThumbprintUntypedPublicKey(pub crypto.PublicKey) string { func ThumbprintUntypedPublicKey(pub crypto.PublicKey) string {
switch p := pub.(type) { switch p := pub.(type) {
case PublicKeyDeprecated: case PublicKeyDeprecated:
return ThumbprintUntypedPublicKey(p.Key().(PublicKeyTransitional)) return ThumbprintUntypedPublicKey(p.Key())
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
return ThumbprintECPublicKey(p) return ThumbprintECPublicKey(p)
case *rsa.PublicKey: case *rsa.PublicKey:

View File

@ -163,7 +163,8 @@ func marshalJWKs(keys []keypairs.PublicKeyDeprecated, exp2 time.Time) []string {
// Note that you don't have to embed `iss` in the JWK because the client // Note that you don't have to embed `iss` in the JWK because the client
// already has that info by virtue of getting to it in the first place. // already has that info by virtue of getting to it in the first place.
jwk := string(keypairs.MarshalJWKPublicKey(key.Key().(keypairs.PublicKeyTransitional), exp)) pub := key.Key()
jwk := string(keypairs.MarshalJWKPublicKey(pub, exp))
jwks = append(jwks, jwk) jwks = append(jwks, jwk)
} }

View File

@ -13,7 +13,7 @@ import (
) )
// MarshalPEMPublicKey outputs the given public key as JWK // MarshalPEMPublicKey outputs the given public key as JWK
func MarshalPEMPublicKey(pubkey PublicKeyTransitional) ([]byte, error) { func MarshalPEMPublicKey(pubkey PublicKey) ([]byte, error) {
block, err := marshalDERPublicKey(pubkey) block, err := marshalDERPublicKey(pubkey)
if nil != err { if nil != err {
return nil, err return nil, err
@ -22,7 +22,7 @@ func MarshalPEMPublicKey(pubkey PublicKeyTransitional) ([]byte, error) {
} }
// MarshalDERPublicKey outputs the given public key as JWK // MarshalDERPublicKey outputs the given public key as JWK
func MarshalDERPublicKey(pubkey PublicKeyTransitional) ([]byte, error) { func MarshalDERPublicKey(pubkey PublicKey) ([]byte, error) {
block, err := marshalDERPublicKey(pubkey) block, err := marshalDERPublicKey(pubkey)
if nil != err { if nil != err {
return nil, err return nil, err
@ -31,7 +31,7 @@ func MarshalDERPublicKey(pubkey PublicKeyTransitional) ([]byte, error) {
} }
// marshalDERPublicKey outputs the given public key as JWK // marshalDERPublicKey outputs the given public key as JWK
func marshalDERPublicKey(pubkey PublicKeyTransitional) (*pem.Block, error) { func marshalDERPublicKey(pubkey PublicKey) (*pem.Block, error) {
var der []byte var der []byte
var typ string var typ string

View File

@ -26,7 +26,7 @@ func SignClaims(privkey PrivateKey, header Object, claims Object) (*JWS, error)
//delete(header, "_seed") //delete(header, "_seed")
} }
protected, header, err := headerToProtected(privkey.Public().(PublicKeyTransitional), header) protected, header, err := headerToProtected(privkey.Public().(PublicKey), header)
if nil != err { if nil != err {
return nil, err return nil, err
} }
@ -56,7 +56,7 @@ func SignClaims(privkey PrivateKey, header Object, claims Object) (*JWS, error)
}, nil }, nil
} }
func headerToProtected(pub PublicKeyTransitional, header Object) ([]byte, Object, error) { func headerToProtected(pub PublicKey, header Object) ([]byte, Object, error) {
if nil == header { if nil == header {
header = Object{} header = Object{}
} }

View File

@ -15,7 +15,7 @@ import (
) )
// VerifyClaims will check the signature of a parsed JWT // VerifyClaims will check the signature of a parsed JWT
func VerifyClaims(pubkey PublicKeyTransitional, jws *JWS) (errs []error) { func VerifyClaims(pubkey PublicKey, jws *JWS) (errs []error) {
kid, _ := jws.Header["kid"].(string) kid, _ := jws.Header["kid"].(string)
jwkmap, hasJWK := jws.Header["jwk"].(Object) jwkmap, hasJWK := jws.Header["jwk"].(Object)
//var jwk JWK = nil //var jwk JWK = nil
@ -27,7 +27,7 @@ func VerifyClaims(pubkey PublicKeyTransitional, jws *JWS) (errs []error) {
seed = int64(seedf64) seed = int64(seedf64)
} }
var pub PublicKeyTransitional = nil var pub PublicKey = nil
if hasJWK { if hasJWK {
pub, errs = selfsignCheck(jwkmap, errs) pub, errs = selfsignCheck(jwkmap, errs)
} else { } else {
@ -72,7 +72,7 @@ func VerifyClaims(pubkey PublicKeyTransitional, jws *JWS) (errs []error) {
return errs return errs
} }
func selfsignCheck(jwkmap Object, errs []error) (PublicKeyTransitional, []error) { func selfsignCheck(jwkmap Object, errs []error) (PublicKey, []error) {
var pub PublicKeyDeprecated = nil var pub PublicKeyDeprecated = nil
log.Println("Security TODO: did not check jws.Claims[\"sub\"] against 'jwk'") log.Println("Security TODO: did not check jws.Claims[\"sub\"] against 'jwk'")
log.Println("Security TODO: did not check jws.Claims[\"iss\"]") log.Println("Security TODO: did not check jws.Claims[\"iss\"]")
@ -104,11 +104,11 @@ func selfsignCheck(jwkmap Object, errs []error) (PublicKeyTransitional, []error)
} }
} }
return pub.Key().(PublicKeyTransitional), errs return pub.Key(), errs
} }
func pubkeyCheck(pubkey PublicKeyTransitional, kid string, opts *keyOptions, errs []error) (PublicKeyTransitional, []error) { func pubkeyCheck(pubkey PublicKey, kid string, opts *keyOptions, errs []error) (PublicKey, []error) {
var pub PublicKeyTransitional = nil var pub PublicKey = nil
if "" == kid { if "" == kid {
err := errors.New("token should have 'kid' or 'jwk' in header to identify the public key") err := errors.New("token should have 'kid' or 'jwk' in header to identify the public key")
@ -130,7 +130,7 @@ func pubkeyCheck(pubkey PublicKeyTransitional, kid string, opts *keyOptions, err
return nil, errs return nil, errs
} }
privkey := newPrivateKey(opts) privkey := newPrivateKey(opts)
pub = privkey.Public().(PublicKeyTransitional) pub = privkey.Public().(PublicKey)
return pub, errs return pub, errs
} }
err := errors.New("no matching public key") err := errors.New("no matching public key")
@ -149,7 +149,7 @@ func pubkeyCheck(pubkey PublicKeyTransitional, kid string, opts *keyOptions, err
} }
// Verify will check the signature of a hash // Verify will check the signature of a hash
func Verify(pubkey PublicKeyTransitional, hash []byte, sig []byte) bool { func Verify(pubkey PublicKey, hash []byte, sig []byte) bool {
switch pub := pubkey.(type) { switch pub := pubkey.(type) {
case *rsa.PublicKey: case *rsa.PublicKey: