add key caching
This commit is contained in:
parent
f2468010fa
commit
4ff0e898f1
95
fetch.go
95
fetch.go
|
@ -8,6 +8,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
@ -21,6 +22,28 @@ type CachableKey struct {
|
|||
Expiry time.Time
|
||||
}
|
||||
|
||||
// TODO use this poor-man's enum to allow kids thumbs to be accepted by the same method
|
||||
/*
|
||||
type KeyID string
|
||||
|
||||
func (kid KeyID) ID() string {
|
||||
return string(kid)
|
||||
}
|
||||
func (kid KeyID) isID() {}
|
||||
|
||||
type Thumbprint string
|
||||
|
||||
func (thumb Thumbprint) ID() string {
|
||||
return string(thumb)
|
||||
}
|
||||
func (thumb Thumbprint) isID() {}
|
||||
|
||||
type ID interface {
|
||||
ID() string
|
||||
isID()
|
||||
}
|
||||
*/
|
||||
|
||||
var StaleTime = 15 * time.Minute
|
||||
var DefaultKeyDuration = 48 * time.Hour
|
||||
var MinimumKeyDuration = time.Hour
|
||||
|
@ -39,7 +62,7 @@ func fetchAndCacheOIDCPublicKeys(baseURL string) (map[string]map[string]string,
|
|||
if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err {
|
||||
return nil, nil, err
|
||||
} else {
|
||||
cacheKeys(maps, keys)
|
||||
cacheKeys(maps, keys, baseURL)
|
||||
return maps, keys, err
|
||||
}
|
||||
}
|
||||
|
@ -72,16 +95,11 @@ func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var pub PublicKey
|
||||
var ok bool // because interfaces are never nil
|
||||
|
||||
for i := range keys {
|
||||
key := keys[i]
|
||||
|
||||
if id == key.Thumbprint() {
|
||||
pub = key
|
||||
ok = true
|
||||
break
|
||||
return key, nil
|
||||
}
|
||||
|
||||
var kid string
|
||||
|
@ -94,16 +112,10 @@ func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map
|
|||
panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled"))
|
||||
}
|
||||
if id == kid {
|
||||
pub = key
|
||||
ok = true
|
||||
break
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
if ok {
|
||||
return pub, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL)
|
||||
}
|
||||
|
||||
|
@ -112,7 +124,7 @@ func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) {
|
|||
if maps, keys, err := fetchPublicKeys(jwksurl); nil != err {
|
||||
return nil, err
|
||||
} else {
|
||||
cacheKeys(maps, keys)
|
||||
cacheKeys(maps, keys, strings.Replace(jwksurl, ".well-known/jwks.json", "", 1))
|
||||
return keys, err
|
||||
}
|
||||
}
|
||||
|
@ -156,7 +168,14 @@ func FetchPublicKey(url string) (PublicKey, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
cacheKey(m["kid"], m["iss"], m["exp"], key)
|
||||
maps := map[string]map[string]string{}
|
||||
maps[key.Thumbprint()] = m
|
||||
|
||||
keys := map[string]PublicKey{}
|
||||
keys[key.Thumbprint()] = key
|
||||
|
||||
cacheKeys(maps, keys, url)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
|
@ -180,31 +199,46 @@ func fetchPublicKey(url string) (map[string]string, PublicKey, error) {
|
|||
}
|
||||
|
||||
func hasPublicKey(kid, iss string) (*CachableKey, bool) {
|
||||
now := time.Now()
|
||||
id := kid + "@" + iss
|
||||
|
||||
KeyCacheMux.Lock()
|
||||
hit, ok := KeyCache[id]
|
||||
KeyCacheMux.Unlock()
|
||||
|
||||
if ok && hit.Expiry.Sub(now) > 0 {
|
||||
if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 {
|
||||
return &hit, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func GetPublicKey(kid, iss string) (PublicKey, error) {
|
||||
// it would be a security risk to pass kid as a thumbprint
|
||||
func hasPublicKeyByThumbprint(thumb string) (*CachableKey, bool) {
|
||||
KeyCacheMux.Lock()
|
||||
hit, ok := KeyCache[thumb]
|
||||
KeyCacheMux.Unlock()
|
||||
|
||||
if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 {
|
||||
return &hit, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func GetPublicKey(kidOrThumb, iss string) (PublicKey, error) {
|
||||
now := time.Now()
|
||||
key, ok := hasPublicKey(kid, iss)
|
||||
key, ok := hasPublicKeyByThumbprint(kidOrThumb)
|
||||
|
||||
if !ok {
|
||||
return FetchOIDCPublicKey(kid, iss)
|
||||
key, ok = hasPublicKey(kidOrThumb, iss)
|
||||
if !ok {
|
||||
return FetchOIDCPublicKey(kidOrThumb, iss)
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch just a little before the key actually expires
|
||||
if key.Expiry.Sub(now) <= StaleTime {
|
||||
go FetchOIDCPublicKey(kid, iss)
|
||||
go FetchOIDCPublicKey(kidOrThumb, iss)
|
||||
}
|
||||
|
||||
return key.Key, nil
|
||||
|
@ -231,27 +265,38 @@ var cacheKey = func(kid, iss, expstr string, pub PublicKey) error {
|
|||
Key: pub,
|
||||
Expiry: expiry,
|
||||
}
|
||||
id = pub.Thumbprint() + "@" + iss
|
||||
thumb := pub.Thumbprint()
|
||||
id = thumb + "@" + iss
|
||||
KeyCache[id] = CachableKey{
|
||||
Key: pub,
|
||||
Expiry: expiry,
|
||||
}
|
||||
// Since thumbprints are crypto secure, iss is not strictly needed
|
||||
KeyCache[thumb] = CachableKey{
|
||||
Key: pub,
|
||||
Expiry: expiry,
|
||||
}
|
||||
KeyCacheMux.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cacheKeys(maps map[string]map[string]string, keys map[string]PublicKey) {
|
||||
func cacheKeys(maps map[string]map[string]string, keys map[string]PublicKey, issuer string) {
|
||||
for i := range keys {
|
||||
key := keys[i]
|
||||
m := maps[i]
|
||||
cacheKey(m["kid"], m["iss"], m["exp"], key)
|
||||
if "" != m["iss"] {
|
||||
issuer = m["iss"]
|
||||
}
|
||||
cacheKey(m["kid"], strings.TrimRight(issuer, "/"), m["exp"], key)
|
||||
}
|
||||
}
|
||||
|
||||
func getStringMap(m map[string]interface{}) map[string]string {
|
||||
n := make(map[string]string)
|
||||
|
||||
// TODO get issuer from x5c, if exists
|
||||
|
||||
// convert map[string]interface{} to map[string]string
|
||||
for j := range m {
|
||||
switch s := m[j].(type) {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/rsa"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFetchOIDCPublicKeys(t *testing.T) {
|
||||
|
@ -34,9 +35,62 @@ func TestFetchOIDCPublicKeys(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCachesKey(t *testing.T) {
|
||||
url := "https://bigsquid.auth0.com/"
|
||||
|
||||
// Raw fetch a key and get KID and Thumbprint
|
||||
// Look in cache for each (and fail)
|
||||
// Get with caching
|
||||
// Look in cache for each (and succeed)
|
||||
// Get again (should be sub-ms instant)
|
||||
_, keys, err := fetchOIDCPublicKeys(url)
|
||||
if nil != err {
|
||||
t.Fatal(url, err)
|
||||
}
|
||||
if 0 == len(keys) {
|
||||
t.Fatal("Should discover 1 or more keys via", url)
|
||||
}
|
||||
|
||||
var key PublicKey
|
||||
for i := range keys {
|
||||
key = keys[i]
|
||||
break
|
||||
}
|
||||
thumb := key.Thumbprint()
|
||||
|
||||
// Look in cache for each (and fail)
|
||||
if _, ok := hasPublicKeyByThumbprint(thumb); ok {
|
||||
t.Fatal("SANITY: Should not have any key cached by thumbprint")
|
||||
}
|
||||
if _, ok := hasPublicKey(key.KeyID(), url); ok {
|
||||
t.Fatal("SANITY: Should not have any key cached by kid")
|
||||
}
|
||||
|
||||
// Get with caching
|
||||
k2, err := GetPublicKey(thumb, url)
|
||||
if nil != err {
|
||||
t.Fatal("Error fetching and caching key:", err)
|
||||
}
|
||||
|
||||
// Look in cache for each (and succeed)
|
||||
if _, ok := hasPublicKeyByThumbprint(thumb); !ok {
|
||||
t.Fatal("key was not properly cached by thumbprint")
|
||||
}
|
||||
if "" != k2.KeyID() {
|
||||
if _, ok := hasPublicKeyByThumbprint(thumb); !ok {
|
||||
t.Fatal("key was not properly cached by thumbprint")
|
||||
}
|
||||
} else {
|
||||
t.Log("Key did not have an explicit KeyID")
|
||||
}
|
||||
|
||||
// Get again (should be sub-ms instant)
|
||||
now := time.Now()
|
||||
_, err = GetPublicKey(thumb, url)
|
||||
if nil != err {
|
||||
t.Fatal("SANITY: Failed to get the key we just got...", err)
|
||||
}
|
||||
if time.Now().Sub(now) > time.Millisecond {
|
||||
t.Fatal("Failed to cache key by thumbprint...", time.Now().Sub(now))
|
||||
}
|
||||
|
||||
// Sanity check that the kid and thumb match
|
||||
if key.KeyID() != k2.KeyID() || key.Thumbprint() != k2.Thumbprint() {
|
||||
t.Fatal("SANITY: KeyIDs or Thumbprints do not match:", key.KeyID(), k2.KeyID(), key.Thumbprint(), k2.Thumbprint())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@ type PrivateKey interface {
|
|||
type PublicKey interface {
|
||||
crypto.PublicKey
|
||||
Thumbprint() string
|
||||
KeyID() string
|
||||
Key() crypto.PublicKey
|
||||
}
|
||||
|
||||
|
@ -57,6 +58,9 @@ type RSAPublicKey struct {
|
|||
func (p *ECPublicKey) Thumbprint() string {
|
||||
return ThumbprintUntypedPublicKey(p.PublicKey)
|
||||
}
|
||||
func (p *ECPublicKey) KeyID() string {
|
||||
return p.KID
|
||||
}
|
||||
func (p *ECPublicKey) Key() crypto.PublicKey {
|
||||
return p.PublicKey
|
||||
}
|
||||
|
@ -67,6 +71,9 @@ func (p *ECPublicKey) ExpireAt(t time.Time) {
|
|||
func (p *RSAPublicKey) Thumbprint() string {
|
||||
return ThumbprintUntypedPublicKey(p.PublicKey)
|
||||
}
|
||||
func (p *RSAPublicKey) KeyID() string {
|
||||
return p.KID
|
||||
}
|
||||
func (p *RSAPublicKey) Key() crypto.PublicKey {
|
||||
return p.PublicKey
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue