add caching to fetching
This commit is contained in:
parent
34e2ec1a8e
commit
f2468010fa
241
fetch.go
241
fetch.go
|
@ -1,37 +1,125 @@
|
||||||
package keypairs
|
package keypairs
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var EInvalidJWKURL = errors.New("url does not lead to valid JWKs")
|
var EInvalidJWKURL = errors.New("url does not lead to valid JWKs")
|
||||||
|
var KeyCache = map[string]CachableKey{}
|
||||||
|
var KeyCacheMux = sync.Mutex{}
|
||||||
|
|
||||||
|
type CachableKey struct {
|
||||||
|
Key PublicKey
|
||||||
|
Expiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var StaleTime = 15 * time.Minute
|
||||||
|
var DefaultKeyDuration = 48 * time.Hour
|
||||||
|
var MinimumKeyDuration = time.Hour
|
||||||
|
var MaximumKeyDuration = 72 * time.Hour
|
||||||
|
|
||||||
// FetchOIDCPublicKeys fetches baseURL + ".well-known/openid-configuration" and then returns FetchPublicKeys(jwks_uri).
|
// FetchOIDCPublicKeys fetches baseURL + ".well-known/openid-configuration" and then returns FetchPublicKeys(jwks_uri).
|
||||||
func FetchOIDCPublicKeys(baseURL string) (map[string]PublicKey, error) {
|
func FetchOIDCPublicKeys(baseURL string) (map[string]PublicKey, error) {
|
||||||
|
if _, keys, err := fetchAndCacheOIDCPublicKeys(baseURL); nil != err {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return keys, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchAndCacheOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[string]PublicKey, error) {
|
||||||
|
if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
} else {
|
||||||
|
cacheKeys(maps, keys)
|
||||||
|
return maps, keys, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[string]PublicKey, error) {
|
||||||
oidcConf := struct {
|
oidcConf := struct {
|
||||||
JWKSURI string `json:"jwks_uri"`
|
JWKSURI string `json:"jwks_uri"`
|
||||||
}{}
|
}{}
|
||||||
|
|
||||||
// must come in as https://<domain>/
|
// must come in as https://<domain>/
|
||||||
url := baseURL + ".well-known/openid-configuration"
|
url := baseURL + ".well-known/openid-configuration"
|
||||||
err := safeFetch(url, func(body io.Reader) error {
|
err := safeFetch(url, func(body io.Reader) error {
|
||||||
return json.NewDecoder(body).Decode(&oidcConf)
|
decoder := json.NewDecoder(body)
|
||||||
|
decoder.UseNumber()
|
||||||
|
return decoder.Decode(&oidcConf)
|
||||||
})
|
})
|
||||||
|
if nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fetchPublicKeys(oidcConf.JWKSURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchOIDCPublicKey(id, baseURL string) (PublicKey, error) {
|
||||||
|
return fetchOIDCPublicKey(id, baseURL, fetchAndCacheOIDCPublicKeys)
|
||||||
|
}
|
||||||
|
func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map[string]string, map[string]PublicKey, error)) (PublicKey, error) {
|
||||||
|
_, keys, err := fetcher(baseURL)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return FetchPublicKeys(oidcConf.JWKSURI)
|
var pub PublicKey
|
||||||
|
var ok bool // because interfaces are never nil
|
||||||
|
|
||||||
|
for i := range keys {
|
||||||
|
key := keys[i]
|
||||||
|
|
||||||
|
if id == key.Thumbprint() {
|
||||||
|
pub = key
|
||||||
|
ok = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var kid string
|
||||||
|
switch k := key.(type) {
|
||||||
|
case *RSAPublicKey:
|
||||||
|
kid = k.KID
|
||||||
|
case *ECPublicKey:
|
||||||
|
kid = k.KID
|
||||||
|
default:
|
||||||
|
panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled"))
|
||||||
|
}
|
||||||
|
if id == kid {
|
||||||
|
pub = key
|
||||||
|
ok = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
return pub, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchPublicKeys returns a map of keys identified by their kid or thumbprint (if kid is not specified)
|
// FetchPublicKeys returns a map of keys identified by their kid or thumbprint (if kid is not specified)
|
||||||
func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) {
|
func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) {
|
||||||
|
if maps, keys, err := fetchPublicKeys(jwksurl); nil != err {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
cacheKeys(maps, keys)
|
||||||
|
return keys, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchPublicKeys(jwksurl string) (map[string]map[string]string, map[string]PublicKey, error) {
|
||||||
keys := map[string]PublicKey{}
|
keys := map[string]PublicKey{}
|
||||||
|
maps := map[string]map[string]string{}
|
||||||
resp := struct {
|
resp := struct {
|
||||||
Keys []map[string]interface{} `json:"keys"`
|
Keys []map[string]interface{} `json:"keys"`
|
||||||
}{
|
}{
|
||||||
|
@ -39,18 +127,134 @@ func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := safeFetch(jwksurl, func(body io.Reader) error {
|
if err := safeFetch(jwksurl, func(body io.Reader) error {
|
||||||
return json.NewDecoder(body).Decode(&resp)
|
decoder := json.NewDecoder(body)
|
||||||
|
decoder.UseNumber()
|
||||||
|
return decoder.Decode(&resp)
|
||||||
}); nil != err {
|
}); nil != err {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range resp.Keys {
|
for i := range resp.Keys {
|
||||||
n := map[string]string{}
|
|
||||||
k := resp.Keys[i]
|
k := resp.Keys[i]
|
||||||
|
m := getStringMap(k)
|
||||||
|
|
||||||
|
if key, err := NewJWKPublicKey(m); nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
} else {
|
||||||
|
keys[key.Thumbprint()] = key
|
||||||
|
maps[key.Thumbprint()] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return maps, keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchPublicKey retrieves a JWK from a URL that specifies only one
|
||||||
|
func FetchPublicKey(url string) (PublicKey, error) {
|
||||||
|
m, key, err := fetchPublicKey(url)
|
||||||
|
if nil != err {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey(m["kid"], m["iss"], m["exp"], key)
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchPublicKey(url string) (map[string]string, PublicKey, error) {
|
||||||
|
var m map[string]interface{}
|
||||||
|
if err := safeFetch(url, func(body io.Reader) error {
|
||||||
|
decoder := json.NewDecoder(body)
|
||||||
|
decoder.UseNumber()
|
||||||
|
return decoder.Decode(&m)
|
||||||
|
}); nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n := getStringMap(m)
|
||||||
|
key, err := NewJWKPublicKey(n)
|
||||||
|
if nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasPublicKey(kid, iss string) (*CachableKey, bool) {
|
||||||
|
now := time.Now()
|
||||||
|
id := kid + "@" + iss
|
||||||
|
|
||||||
|
KeyCacheMux.Lock()
|
||||||
|
hit, ok := KeyCache[id]
|
||||||
|
KeyCacheMux.Unlock()
|
||||||
|
|
||||||
|
if ok && hit.Expiry.Sub(now) > 0 {
|
||||||
|
return &hit, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPublicKey(kid, iss string) (PublicKey, error) {
|
||||||
|
now := time.Now()
|
||||||
|
key, ok := hasPublicKey(kid, iss)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return FetchOIDCPublicKey(kid, iss)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch just a little before the key actually expires
|
||||||
|
if key.Expiry.Sub(now) <= StaleTime {
|
||||||
|
go FetchOIDCPublicKey(kid, iss)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key.Key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cacheKey = func(kid, iss, expstr string, pub PublicKey) error {
|
||||||
|
var expiry time.Time
|
||||||
|
|
||||||
|
exp, _ := strconv.ParseInt(expstr, 10, 64)
|
||||||
|
if 0 == exp {
|
||||||
|
// use default
|
||||||
|
expiry = time.Now().Add(DefaultKeyDuration)
|
||||||
|
} else if exp < time.Now().Add(MinimumKeyDuration).Unix() || exp > time.Now().Add(MaximumKeyDuration).Unix() {
|
||||||
|
// use at least one hour
|
||||||
|
expiry = time.Now().Add(MinimumKeyDuration)
|
||||||
|
} else {
|
||||||
|
expiry = time.Unix(exp, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put the key in the cache by both kid and thumbprint, and set the expiry
|
||||||
|
KeyCacheMux.Lock()
|
||||||
|
id := kid + "@" + iss
|
||||||
|
KeyCache[id] = CachableKey{
|
||||||
|
Key: pub,
|
||||||
|
Expiry: expiry,
|
||||||
|
}
|
||||||
|
id = pub.Thumbprint() + "@" + iss
|
||||||
|
KeyCache[id] = CachableKey{
|
||||||
|
Key: pub,
|
||||||
|
Expiry: expiry,
|
||||||
|
}
|
||||||
|
KeyCacheMux.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheKeys(maps map[string]map[string]string, keys map[string]PublicKey) {
|
||||||
|
for i := range keys {
|
||||||
|
key := keys[i]
|
||||||
|
m := maps[i]
|
||||||
|
cacheKey(m["kid"], m["iss"], m["exp"], key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStringMap(m map[string]interface{}) map[string]string {
|
||||||
|
n := make(map[string]string)
|
||||||
|
|
||||||
// convert map[string]interface{} to map[string]string
|
// convert map[string]interface{} to map[string]string
|
||||||
for j := range k {
|
for j := range m {
|
||||||
switch s := k[j].(type) {
|
switch s := m[j].(type) {
|
||||||
case string:
|
case string:
|
||||||
n[j] = s
|
n[j] = s
|
||||||
default:
|
default:
|
||||||
|
@ -58,26 +262,7 @@ func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if key, err := NewJWKPublicKey(n); nil != err {
|
return n
|
||||||
return nil, err
|
|
||||||
} else {
|
|
||||||
keys[key.Thumbprint()] = key
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return keys, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchPublicKey retrieves a JWK from a URL that specifies only one
|
|
||||||
func FetchPublicKey(url string) (crypto.PublicKey, error) {
|
|
||||||
var m map[string]string
|
|
||||||
if err := safeFetch(url, func(body io.Reader) error {
|
|
||||||
return json.NewDecoder(body).Decode(&m)
|
|
||||||
}); nil != err {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewJWKPublicKey(m)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type decodeFunc func(io.Reader) error
|
type decodeFunc func(io.Reader) error
|
||||||
|
|
|
@ -15,7 +15,7 @@ func TestFetchOIDCPublicKeys(t *testing.T) {
|
||||||
}
|
}
|
||||||
for i := range urls {
|
for i := range urls {
|
||||||
url := urls[i]
|
url := urls[i]
|
||||||
keys, err := FetchOIDCPublicKeys(url)
|
_, keys, err := fetchOIDCPublicKeys(url)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
t.Fatal(url, err)
|
t.Fatal(url, err)
|
||||||
}
|
}
|
||||||
|
@ -32,3 +32,11 @@ func TestFetchOIDCPublicKeys(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCachesKey(t *testing.T) {
|
||||||
|
// Raw fetch a key and get KID and Thumbprint
|
||||||
|
// Look in cache for each (and fail)
|
||||||
|
// Get with caching
|
||||||
|
// Look in cache for each (and succeed)
|
||||||
|
// Get again (should be sub-ms instant)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue