diff --git a/fetch.go b/fetch.go index 79bf7cd..2e2fb25 100644 --- a/fetch.go +++ b/fetch.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "strconv" + "strings" "sync" "time" ) @@ -21,6 +22,28 @@ type CachableKey struct { Expiry time.Time } +// TODO use this poor-man's enum to allow kids thumbs to be accepted by the same method +/* +type KeyID string + +func (kid KeyID) ID() string { + return string(kid) +} +func (kid KeyID) isID() {} + +type Thumbprint string + +func (thumb Thumbprint) ID() string { + return string(thumb) +} +func (thumb Thumbprint) isID() {} + +type ID interface { + ID() string + isID() +} +*/ + var StaleTime = 15 * time.Minute var DefaultKeyDuration = 48 * time.Hour var MinimumKeyDuration = time.Hour @@ -39,7 +62,7 @@ func fetchAndCacheOIDCPublicKeys(baseURL string) (map[string]map[string]string, if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err { return nil, nil, err } else { - cacheKeys(maps, keys) + cacheKeys(maps, keys, baseURL) return maps, keys, err } } @@ -72,16 +95,11 @@ func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map return nil, err } - var pub PublicKey - var ok bool // because interfaces are never nil - for i := range keys { key := keys[i] if id == key.Thumbprint() { - pub = key - ok = true - break + return key, nil } var kid string @@ -94,16 +112,10 @@ func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled")) } if id == kid { - pub = key - ok = true - break + return key, nil } } - if ok { - return pub, nil - } - return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL) } @@ -112,7 +124,7 @@ func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) { if maps, keys, err := fetchPublicKeys(jwksurl); nil != err { return nil, err } else { - cacheKeys(maps, keys) + cacheKeys(maps, keys, strings.Replace(jwksurl, ".well-known/jwks.json", "", 1)) return keys, err } } @@ -156,7 +168,14 @@ func FetchPublicKey(url string) (PublicKey, error) { return nil, err } - cacheKey(m["kid"], m["iss"], m["exp"], key) + maps := map[string]map[string]string{} + maps[key.Thumbprint()] = m + + keys := map[string]PublicKey{} + keys[key.Thumbprint()] = key + + cacheKeys(maps, keys, url) + return key, nil } @@ -180,31 +199,46 @@ func fetchPublicKey(url string) (map[string]string, PublicKey, error) { } func hasPublicKey(kid, iss string) (*CachableKey, bool) { - now := time.Now() id := kid + "@" + iss KeyCacheMux.Lock() hit, ok := KeyCache[id] KeyCacheMux.Unlock() - if ok && hit.Expiry.Sub(now) > 0 { + if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 { return &hit, true } return nil, false } -func GetPublicKey(kid, iss string) (PublicKey, error) { +// it would be a security risk to pass kid as a thumbprint +func hasPublicKeyByThumbprint(thumb string) (*CachableKey, bool) { + KeyCacheMux.Lock() + hit, ok := KeyCache[thumb] + KeyCacheMux.Unlock() + + if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 { + return &hit, true + } + + return nil, false +} + +func GetPublicKey(kidOrThumb, iss string) (PublicKey, error) { now := time.Now() - key, ok := hasPublicKey(kid, iss) + key, ok := hasPublicKeyByThumbprint(kidOrThumb) if !ok { - return FetchOIDCPublicKey(kid, iss) + key, ok = hasPublicKey(kidOrThumb, iss) + if !ok { + return FetchOIDCPublicKey(kidOrThumb, iss) + } } // Fetch just a little before the key actually expires if key.Expiry.Sub(now) <= StaleTime { - go FetchOIDCPublicKey(kid, iss) + go FetchOIDCPublicKey(kidOrThumb, iss) } return key.Key, nil @@ -231,27 +265,38 @@ var cacheKey = func(kid, iss, expstr string, pub PublicKey) error { Key: pub, Expiry: expiry, } - id = pub.Thumbprint() + "@" + iss + thumb := pub.Thumbprint() + id = thumb + "@" + iss KeyCache[id] = CachableKey{ Key: pub, Expiry: expiry, } + // Since thumbprints are crypto secure, iss is not strictly needed + KeyCache[thumb] = CachableKey{ + Key: pub, + Expiry: expiry, + } KeyCacheMux.Unlock() return nil } -func cacheKeys(maps map[string]map[string]string, keys map[string]PublicKey) { +func cacheKeys(maps map[string]map[string]string, keys map[string]PublicKey, issuer string) { for i := range keys { key := keys[i] m := maps[i] - cacheKey(m["kid"], m["iss"], m["exp"], key) + if "" != m["iss"] { + issuer = m["iss"] + } + cacheKey(m["kid"], strings.TrimRight(issuer, "/"), m["exp"], key) } } func getStringMap(m map[string]interface{}) map[string]string { n := make(map[string]string) + // TODO get issuer from x5c, if exists + // convert map[string]interface{} to map[string]string for j := range m { switch s := m[j].(type) { diff --git a/fetch_test.go b/fetch_test.go index b4bb2bf..2e4f8f7 100644 --- a/fetch_test.go +++ b/fetch_test.go @@ -5,6 +5,7 @@ import ( "crypto/rsa" "errors" "testing" + "time" ) func TestFetchOIDCPublicKeys(t *testing.T) { @@ -34,9 +35,62 @@ func TestFetchOIDCPublicKeys(t *testing.T) { } func TestCachesKey(t *testing.T) { + url := "https://bigsquid.auth0.com/" + // Raw fetch a key and get KID and Thumbprint + _, keys, err := fetchOIDCPublicKeys(url) + if nil != err { + t.Fatal(url, err) + } + if 0 == len(keys) { + t.Fatal("Should discover 1 or more keys via", url) + } + + var key PublicKey + for i := range keys { + key = keys[i] + break + } + thumb := key.Thumbprint() + // Look in cache for each (and fail) + if _, ok := hasPublicKeyByThumbprint(thumb); ok { + t.Fatal("SANITY: Should not have any key cached by thumbprint") + } + if _, ok := hasPublicKey(key.KeyID(), url); ok { + t.Fatal("SANITY: Should not have any key cached by kid") + } + // Get with caching + k2, err := GetPublicKey(thumb, url) + if nil != err { + t.Fatal("Error fetching and caching key:", err) + } + // Look in cache for each (and succeed) + if _, ok := hasPublicKeyByThumbprint(thumb); !ok { + t.Fatal("key was not properly cached by thumbprint") + } + if "" != k2.KeyID() { + if _, ok := hasPublicKeyByThumbprint(thumb); !ok { + t.Fatal("key was not properly cached by thumbprint") + } + } else { + t.Log("Key did not have an explicit KeyID") + } + // Get again (should be sub-ms instant) + now := time.Now() + _, err = GetPublicKey(thumb, url) + if nil != err { + t.Fatal("SANITY: Failed to get the key we just got...", err) + } + if time.Now().Sub(now) > time.Millisecond { + t.Fatal("Failed to cache key by thumbprint...", time.Now().Sub(now)) + } + + // Sanity check that the kid and thumb match + if key.KeyID() != k2.KeyID() || key.Thumbprint() != k2.Thumbprint() { + t.Fatal("SANITY: KeyIDs or Thumbprints do not match:", key.KeyID(), k2.KeyID(), key.Thumbprint(), k2.Thumbprint()) + } } diff --git a/keypairs.go b/keypairs.go index 3a75d2a..aeb5ac8 100644 --- a/keypairs.go +++ b/keypairs.go @@ -39,6 +39,7 @@ type PrivateKey interface { type PublicKey interface { crypto.PublicKey Thumbprint() string + KeyID() string Key() crypto.PublicKey } @@ -57,6 +58,9 @@ type RSAPublicKey struct { func (p *ECPublicKey) Thumbprint() string { return ThumbprintUntypedPublicKey(p.PublicKey) } +func (p *ECPublicKey) KeyID() string { + return p.KID +} func (p *ECPublicKey) Key() crypto.PublicKey { return p.PublicKey } @@ -67,6 +71,9 @@ func (p *ECPublicKey) ExpireAt(t time.Time) { func (p *RSAPublicKey) Thumbprint() string { return ThumbprintUntypedPublicKey(p.PublicKey) } +func (p *RSAPublicKey) KeyID() string { + return p.KID +} func (p *RSAPublicKey) Key() crypto.PublicKey { return p.PublicKey }