add key caching

This commit is contained in:
AJ ONeal 2019-02-20 19:26:37 +00:00
parent f2468010fa
commit 4ff0e898f1
3 changed files with 131 additions and 25 deletions

View File

@ -8,6 +8,7 @@ import (
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
) )
@ -21,6 +22,28 @@ type CachableKey struct {
Expiry time.Time Expiry time.Time
} }
// TODO use this poor-man's enum to allow kids thumbs to be accepted by the same method
/*
type KeyID string
func (kid KeyID) ID() string {
return string(kid)
}
func (kid KeyID) isID() {}
type Thumbprint string
func (thumb Thumbprint) ID() string {
return string(thumb)
}
func (thumb Thumbprint) isID() {}
type ID interface {
ID() string
isID()
}
*/
var StaleTime = 15 * time.Minute var StaleTime = 15 * time.Minute
var DefaultKeyDuration = 48 * time.Hour var DefaultKeyDuration = 48 * time.Hour
var MinimumKeyDuration = time.Hour var MinimumKeyDuration = time.Hour
@ -39,7 +62,7 @@ func fetchAndCacheOIDCPublicKeys(baseURL string) (map[string]map[string]string,
if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err { if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err {
return nil, nil, err return nil, nil, err
} else { } else {
cacheKeys(maps, keys) cacheKeys(maps, keys, baseURL)
return maps, keys, err return maps, keys, err
} }
} }
@ -72,16 +95,11 @@ func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map
return nil, err return nil, err
} }
var pub PublicKey
var ok bool // because interfaces are never nil
for i := range keys { for i := range keys {
key := keys[i] key := keys[i]
if id == key.Thumbprint() { if id == key.Thumbprint() {
pub = key return key, nil
ok = true
break
} }
var kid string var kid string
@ -94,16 +112,10 @@ func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map
panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled")) panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled"))
} }
if id == kid { if id == kid {
pub = key return key, nil
ok = true
break
} }
} }
if ok {
return pub, nil
}
return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL) return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL)
} }
@ -112,7 +124,7 @@ func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) {
if maps, keys, err := fetchPublicKeys(jwksurl); nil != err { if maps, keys, err := fetchPublicKeys(jwksurl); nil != err {
return nil, err return nil, err
} else { } else {
cacheKeys(maps, keys) cacheKeys(maps, keys, strings.Replace(jwksurl, ".well-known/jwks.json", "", 1))
return keys, err return keys, err
} }
} }
@ -156,7 +168,14 @@ func FetchPublicKey(url string) (PublicKey, error) {
return nil, err return nil, err
} }
cacheKey(m["kid"], m["iss"], m["exp"], key) maps := map[string]map[string]string{}
maps[key.Thumbprint()] = m
keys := map[string]PublicKey{}
keys[key.Thumbprint()] = key
cacheKeys(maps, keys, url)
return key, nil return key, nil
} }
@ -180,31 +199,46 @@ func fetchPublicKey(url string) (map[string]string, PublicKey, error) {
} }
func hasPublicKey(kid, iss string) (*CachableKey, bool) { func hasPublicKey(kid, iss string) (*CachableKey, bool) {
now := time.Now()
id := kid + "@" + iss id := kid + "@" + iss
KeyCacheMux.Lock() KeyCacheMux.Lock()
hit, ok := KeyCache[id] hit, ok := KeyCache[id]
KeyCacheMux.Unlock() KeyCacheMux.Unlock()
if ok && hit.Expiry.Sub(now) > 0 { if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 {
return &hit, true return &hit, true
} }
return nil, false return nil, false
} }
func GetPublicKey(kid, iss string) (PublicKey, error) { // it would be a security risk to pass kid as a thumbprint
func hasPublicKeyByThumbprint(thumb string) (*CachableKey, bool) {
KeyCacheMux.Lock()
hit, ok := KeyCache[thumb]
KeyCacheMux.Unlock()
if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 {
return &hit, true
}
return nil, false
}
func GetPublicKey(kidOrThumb, iss string) (PublicKey, error) {
now := time.Now() now := time.Now()
key, ok := hasPublicKey(kid, iss) key, ok := hasPublicKeyByThumbprint(kidOrThumb)
if !ok { if !ok {
return FetchOIDCPublicKey(kid, iss) key, ok = hasPublicKey(kidOrThumb, iss)
if !ok {
return FetchOIDCPublicKey(kidOrThumb, iss)
}
} }
// Fetch just a little before the key actually expires // Fetch just a little before the key actually expires
if key.Expiry.Sub(now) <= StaleTime { if key.Expiry.Sub(now) <= StaleTime {
go FetchOIDCPublicKey(kid, iss) go FetchOIDCPublicKey(kidOrThumb, iss)
} }
return key.Key, nil return key.Key, nil
@ -231,27 +265,38 @@ var cacheKey = func(kid, iss, expstr string, pub PublicKey) error {
Key: pub, Key: pub,
Expiry: expiry, Expiry: expiry,
} }
id = pub.Thumbprint() + "@" + iss thumb := pub.Thumbprint()
id = thumb + "@" + iss
KeyCache[id] = CachableKey{ KeyCache[id] = CachableKey{
Key: pub, Key: pub,
Expiry: expiry, Expiry: expiry,
} }
// Since thumbprints are crypto secure, iss is not strictly needed
KeyCache[thumb] = CachableKey{
Key: pub,
Expiry: expiry,
}
KeyCacheMux.Unlock() KeyCacheMux.Unlock()
return nil return nil
} }
func cacheKeys(maps map[string]map[string]string, keys map[string]PublicKey) { func cacheKeys(maps map[string]map[string]string, keys map[string]PublicKey, issuer string) {
for i := range keys { for i := range keys {
key := keys[i] key := keys[i]
m := maps[i] m := maps[i]
cacheKey(m["kid"], m["iss"], m["exp"], key) if "" != m["iss"] {
issuer = m["iss"]
}
cacheKey(m["kid"], strings.TrimRight(issuer, "/"), m["exp"], key)
} }
} }
func getStringMap(m map[string]interface{}) map[string]string { func getStringMap(m map[string]interface{}) map[string]string {
n := make(map[string]string) n := make(map[string]string)
// TODO get issuer from x5c, if exists
// convert map[string]interface{} to map[string]string // convert map[string]interface{} to map[string]string
for j := range m { for j := range m {
switch s := m[j].(type) { switch s := m[j].(type) {

View File

@ -5,6 +5,7 @@ import (
"crypto/rsa" "crypto/rsa"
"errors" "errors"
"testing" "testing"
"time"
) )
func TestFetchOIDCPublicKeys(t *testing.T) { func TestFetchOIDCPublicKeys(t *testing.T) {
@ -34,9 +35,62 @@ func TestFetchOIDCPublicKeys(t *testing.T) {
} }
func TestCachesKey(t *testing.T) { func TestCachesKey(t *testing.T) {
url := "https://bigsquid.auth0.com/"
// Raw fetch a key and get KID and Thumbprint // Raw fetch a key and get KID and Thumbprint
// Look in cache for each (and fail) _, keys, err := fetchOIDCPublicKeys(url)
// Get with caching if nil != err {
// Look in cache for each (and succeed) t.Fatal(url, err)
// Get again (should be sub-ms instant) }
if 0 == len(keys) {
t.Fatal("Should discover 1 or more keys via", url)
}
var key PublicKey
for i := range keys {
key = keys[i]
break
}
thumb := key.Thumbprint()
// Look in cache for each (and fail)
if _, ok := hasPublicKeyByThumbprint(thumb); ok {
t.Fatal("SANITY: Should not have any key cached by thumbprint")
}
if _, ok := hasPublicKey(key.KeyID(), url); ok {
t.Fatal("SANITY: Should not have any key cached by kid")
}
// Get with caching
k2, err := GetPublicKey(thumb, url)
if nil != err {
t.Fatal("Error fetching and caching key:", err)
}
// Look in cache for each (and succeed)
if _, ok := hasPublicKeyByThumbprint(thumb); !ok {
t.Fatal("key was not properly cached by thumbprint")
}
if "" != k2.KeyID() {
if _, ok := hasPublicKeyByThumbprint(thumb); !ok {
t.Fatal("key was not properly cached by thumbprint")
}
} else {
t.Log("Key did not have an explicit KeyID")
}
// Get again (should be sub-ms instant)
now := time.Now()
_, err = GetPublicKey(thumb, url)
if nil != err {
t.Fatal("SANITY: Failed to get the key we just got...", err)
}
if time.Now().Sub(now) > time.Millisecond {
t.Fatal("Failed to cache key by thumbprint...", time.Now().Sub(now))
}
// Sanity check that the kid and thumb match
if key.KeyID() != k2.KeyID() || key.Thumbprint() != k2.Thumbprint() {
t.Fatal("SANITY: KeyIDs or Thumbprints do not match:", key.KeyID(), k2.KeyID(), key.Thumbprint(), k2.Thumbprint())
}
} }

View File

@ -39,6 +39,7 @@ type PrivateKey interface {
type PublicKey interface { type PublicKey interface {
crypto.PublicKey crypto.PublicKey
Thumbprint() string Thumbprint() string
KeyID() string
Key() crypto.PublicKey Key() crypto.PublicKey
} }
@ -57,6 +58,9 @@ type RSAPublicKey struct {
func (p *ECPublicKey) Thumbprint() string { func (p *ECPublicKey) Thumbprint() string {
return ThumbprintUntypedPublicKey(p.PublicKey) return ThumbprintUntypedPublicKey(p.PublicKey)
} }
func (p *ECPublicKey) KeyID() string {
return p.KID
}
func (p *ECPublicKey) Key() crypto.PublicKey { func (p *ECPublicKey) Key() crypto.PublicKey {
return p.PublicKey return p.PublicKey
} }
@ -67,6 +71,9 @@ func (p *ECPublicKey) ExpireAt(t time.Time) {
func (p *RSAPublicKey) Thumbprint() string { func (p *RSAPublicKey) Thumbprint() string {
return ThumbprintUntypedPublicKey(p.PublicKey) return ThumbprintUntypedPublicKey(p.PublicKey)
} }
func (p *RSAPublicKey) KeyID() string {
return p.KID
}
func (p *RSAPublicKey) Key() crypto.PublicKey { func (p *RSAPublicKey) Key() crypto.PublicKey {
return p.PublicKey return p.PublicKey
} }