add key caching
This commit is contained in:
parent
f2468010fa
commit
4ff0e898f1
95
fetch.go
95
fetch.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -21,6 +22,28 @@ type CachableKey struct {
|
||||||
Expiry time.Time
|
Expiry time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO use this poor-man's enum to allow kids thumbs to be accepted by the same method
|
||||||
|
/*
|
||||||
|
type KeyID string
|
||||||
|
|
||||||
|
func (kid KeyID) ID() string {
|
||||||
|
return string(kid)
|
||||||
|
}
|
||||||
|
func (kid KeyID) isID() {}
|
||||||
|
|
||||||
|
type Thumbprint string
|
||||||
|
|
||||||
|
func (thumb Thumbprint) ID() string {
|
||||||
|
return string(thumb)
|
||||||
|
}
|
||||||
|
func (thumb Thumbprint) isID() {}
|
||||||
|
|
||||||
|
type ID interface {
|
||||||
|
ID() string
|
||||||
|
isID()
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
var StaleTime = 15 * time.Minute
|
var StaleTime = 15 * time.Minute
|
||||||
var DefaultKeyDuration = 48 * time.Hour
|
var DefaultKeyDuration = 48 * time.Hour
|
||||||
var MinimumKeyDuration = time.Hour
|
var MinimumKeyDuration = time.Hour
|
||||||
|
@ -39,7 +62,7 @@ func fetchAndCacheOIDCPublicKeys(baseURL string) (map[string]map[string]string,
|
||||||
if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err {
|
if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
} else {
|
} else {
|
||||||
cacheKeys(maps, keys)
|
cacheKeys(maps, keys, baseURL)
|
||||||
return maps, keys, err
|
return maps, keys, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -72,16 +95,11 @@ func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var pub PublicKey
|
|
||||||
var ok bool // because interfaces are never nil
|
|
||||||
|
|
||||||
for i := range keys {
|
for i := range keys {
|
||||||
key := keys[i]
|
key := keys[i]
|
||||||
|
|
||||||
if id == key.Thumbprint() {
|
if id == key.Thumbprint() {
|
||||||
pub = key
|
return key, nil
|
||||||
ok = true
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var kid string
|
var kid string
|
||||||
|
@ -94,16 +112,10 @@ func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map
|
||||||
panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled"))
|
panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled"))
|
||||||
}
|
}
|
||||||
if id == kid {
|
if id == kid {
|
||||||
pub = key
|
return key, nil
|
||||||
ok = true
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ok {
|
|
||||||
return pub, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL)
|
return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,7 +124,7 @@ func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) {
|
||||||
if maps, keys, err := fetchPublicKeys(jwksurl); nil != err {
|
if maps, keys, err := fetchPublicKeys(jwksurl); nil != err {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
cacheKeys(maps, keys)
|
cacheKeys(maps, keys, strings.Replace(jwksurl, ".well-known/jwks.json", "", 1))
|
||||||
return keys, err
|
return keys, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -156,7 +168,14 @@ func FetchPublicKey(url string) (PublicKey, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheKey(m["kid"], m["iss"], m["exp"], key)
|
maps := map[string]map[string]string{}
|
||||||
|
maps[key.Thumbprint()] = m
|
||||||
|
|
||||||
|
keys := map[string]PublicKey{}
|
||||||
|
keys[key.Thumbprint()] = key
|
||||||
|
|
||||||
|
cacheKeys(maps, keys, url)
|
||||||
|
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,31 +199,46 @@ func fetchPublicKey(url string) (map[string]string, PublicKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasPublicKey(kid, iss string) (*CachableKey, bool) {
|
func hasPublicKey(kid, iss string) (*CachableKey, bool) {
|
||||||
now := time.Now()
|
|
||||||
id := kid + "@" + iss
|
id := kid + "@" + iss
|
||||||
|
|
||||||
KeyCacheMux.Lock()
|
KeyCacheMux.Lock()
|
||||||
hit, ok := KeyCache[id]
|
hit, ok := KeyCache[id]
|
||||||
KeyCacheMux.Unlock()
|
KeyCacheMux.Unlock()
|
||||||
|
|
||||||
if ok && hit.Expiry.Sub(now) > 0 {
|
if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 {
|
||||||
return &hit, true
|
return &hit, true
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPublicKey(kid, iss string) (PublicKey, error) {
|
// it would be a security risk to pass kid as a thumbprint
|
||||||
|
func hasPublicKeyByThumbprint(thumb string) (*CachableKey, bool) {
|
||||||
|
KeyCacheMux.Lock()
|
||||||
|
hit, ok := KeyCache[thumb]
|
||||||
|
KeyCacheMux.Unlock()
|
||||||
|
|
||||||
|
if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 {
|
||||||
|
return &hit, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPublicKey(kidOrThumb, iss string) (PublicKey, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
key, ok := hasPublicKey(kid, iss)
|
key, ok := hasPublicKeyByThumbprint(kidOrThumb)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return FetchOIDCPublicKey(kid, iss)
|
key, ok = hasPublicKey(kidOrThumb, iss)
|
||||||
|
if !ok {
|
||||||
|
return FetchOIDCPublicKey(kidOrThumb, iss)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch just a little before the key actually expires
|
// Fetch just a little before the key actually expires
|
||||||
if key.Expiry.Sub(now) <= StaleTime {
|
if key.Expiry.Sub(now) <= StaleTime {
|
||||||
go FetchOIDCPublicKey(kid, iss)
|
go FetchOIDCPublicKey(kidOrThumb, iss)
|
||||||
}
|
}
|
||||||
|
|
||||||
return key.Key, nil
|
return key.Key, nil
|
||||||
|
@ -231,27 +265,38 @@ var cacheKey = func(kid, iss, expstr string, pub PublicKey) error {
|
||||||
Key: pub,
|
Key: pub,
|
||||||
Expiry: expiry,
|
Expiry: expiry,
|
||||||
}
|
}
|
||||||
id = pub.Thumbprint() + "@" + iss
|
thumb := pub.Thumbprint()
|
||||||
|
id = thumb + "@" + iss
|
||||||
KeyCache[id] = CachableKey{
|
KeyCache[id] = CachableKey{
|
||||||
Key: pub,
|
Key: pub,
|
||||||
Expiry: expiry,
|
Expiry: expiry,
|
||||||
}
|
}
|
||||||
|
// Since thumbprints are crypto secure, iss is not strictly needed
|
||||||
|
KeyCache[thumb] = CachableKey{
|
||||||
|
Key: pub,
|
||||||
|
Expiry: expiry,
|
||||||
|
}
|
||||||
KeyCacheMux.Unlock()
|
KeyCacheMux.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func cacheKeys(maps map[string]map[string]string, keys map[string]PublicKey) {
|
func cacheKeys(maps map[string]map[string]string, keys map[string]PublicKey, issuer string) {
|
||||||
for i := range keys {
|
for i := range keys {
|
||||||
key := keys[i]
|
key := keys[i]
|
||||||
m := maps[i]
|
m := maps[i]
|
||||||
cacheKey(m["kid"], m["iss"], m["exp"], key)
|
if "" != m["iss"] {
|
||||||
|
issuer = m["iss"]
|
||||||
|
}
|
||||||
|
cacheKey(m["kid"], strings.TrimRight(issuer, "/"), m["exp"], key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStringMap(m map[string]interface{}) map[string]string {
|
func getStringMap(m map[string]interface{}) map[string]string {
|
||||||
n := make(map[string]string)
|
n := make(map[string]string)
|
||||||
|
|
||||||
|
// TODO get issuer from x5c, if exists
|
||||||
|
|
||||||
// convert map[string]interface{} to map[string]string
|
// convert map[string]interface{} to map[string]string
|
||||||
for j := range m {
|
for j := range m {
|
||||||
switch s := m[j].(type) {
|
switch s := m[j].(type) {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFetchOIDCPublicKeys(t *testing.T) {
|
func TestFetchOIDCPublicKeys(t *testing.T) {
|
||||||
|
@ -34,9 +35,62 @@ func TestFetchOIDCPublicKeys(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCachesKey(t *testing.T) {
|
func TestCachesKey(t *testing.T) {
|
||||||
|
url := "https://bigsquid.auth0.com/"
|
||||||
|
|
||||||
// Raw fetch a key and get KID and Thumbprint
|
// Raw fetch a key and get KID and Thumbprint
|
||||||
// Look in cache for each (and fail)
|
_, keys, err := fetchOIDCPublicKeys(url)
|
||||||
// Get with caching
|
if nil != err {
|
||||||
// Look in cache for each (and succeed)
|
t.Fatal(url, err)
|
||||||
// Get again (should be sub-ms instant)
|
}
|
||||||
|
if 0 == len(keys) {
|
||||||
|
t.Fatal("Should discover 1 or more keys via", url)
|
||||||
|
}
|
||||||
|
|
||||||
|
var key PublicKey
|
||||||
|
for i := range keys {
|
||||||
|
key = keys[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
thumb := key.Thumbprint()
|
||||||
|
|
||||||
|
// Look in cache for each (and fail)
|
||||||
|
if _, ok := hasPublicKeyByThumbprint(thumb); ok {
|
||||||
|
t.Fatal("SANITY: Should not have any key cached by thumbprint")
|
||||||
|
}
|
||||||
|
if _, ok := hasPublicKey(key.KeyID(), url); ok {
|
||||||
|
t.Fatal("SANITY: Should not have any key cached by kid")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get with caching
|
||||||
|
k2, err := GetPublicKey(thumb, url)
|
||||||
|
if nil != err {
|
||||||
|
t.Fatal("Error fetching and caching key:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look in cache for each (and succeed)
|
||||||
|
if _, ok := hasPublicKeyByThumbprint(thumb); !ok {
|
||||||
|
t.Fatal("key was not properly cached by thumbprint")
|
||||||
|
}
|
||||||
|
if "" != k2.KeyID() {
|
||||||
|
if _, ok := hasPublicKeyByThumbprint(thumb); !ok {
|
||||||
|
t.Fatal("key was not properly cached by thumbprint")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Log("Key did not have an explicit KeyID")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get again (should be sub-ms instant)
|
||||||
|
now := time.Now()
|
||||||
|
_, err = GetPublicKey(thumb, url)
|
||||||
|
if nil != err {
|
||||||
|
t.Fatal("SANITY: Failed to get the key we just got...", err)
|
||||||
|
}
|
||||||
|
if time.Now().Sub(now) > time.Millisecond {
|
||||||
|
t.Fatal("Failed to cache key by thumbprint...", time.Now().Sub(now))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanity check that the kid and thumb match
|
||||||
|
if key.KeyID() != k2.KeyID() || key.Thumbprint() != k2.Thumbprint() {
|
||||||
|
t.Fatal("SANITY: KeyIDs or Thumbprints do not match:", key.KeyID(), k2.KeyID(), key.Thumbprint(), k2.Thumbprint())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,7 @@ type PrivateKey interface {
|
||||||
type PublicKey interface {
|
type PublicKey interface {
|
||||||
crypto.PublicKey
|
crypto.PublicKey
|
||||||
Thumbprint() string
|
Thumbprint() string
|
||||||
|
KeyID() string
|
||||||
Key() crypto.PublicKey
|
Key() crypto.PublicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,6 +58,9 @@ type RSAPublicKey struct {
|
||||||
func (p *ECPublicKey) Thumbprint() string {
|
func (p *ECPublicKey) Thumbprint() string {
|
||||||
return ThumbprintUntypedPublicKey(p.PublicKey)
|
return ThumbprintUntypedPublicKey(p.PublicKey)
|
||||||
}
|
}
|
||||||
|
func (p *ECPublicKey) KeyID() string {
|
||||||
|
return p.KID
|
||||||
|
}
|
||||||
func (p *ECPublicKey) Key() crypto.PublicKey {
|
func (p *ECPublicKey) Key() crypto.PublicKey {
|
||||||
return p.PublicKey
|
return p.PublicKey
|
||||||
}
|
}
|
||||||
|
@ -67,6 +71,9 @@ func (p *ECPublicKey) ExpireAt(t time.Time) {
|
||||||
func (p *RSAPublicKey) Thumbprint() string {
|
func (p *RSAPublicKey) Thumbprint() string {
|
||||||
return ThumbprintUntypedPublicKey(p.PublicKey)
|
return ThumbprintUntypedPublicKey(p.PublicKey)
|
||||||
}
|
}
|
||||||
|
func (p *RSAPublicKey) KeyID() string {
|
||||||
|
return p.KID
|
||||||
|
}
|
||||||
func (p *RSAPublicKey) Key() crypto.PublicKey {
|
func (p *RSAPublicKey) Key() crypto.PublicKey {
|
||||||
return p.PublicKey
|
return p.PublicKey
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue