keypairs/fetch/fetch.go

338 lines
7.6 KiB
Go
Raw Normal View History

2019-02-20 19:59:22 +00:00
package fetch
2019-02-08 01:26:45 +00:00
import (
"encoding/json"
"errors"
2019-02-19 23:50:46 +00:00
"fmt"
2019-02-08 01:26:45 +00:00
"io"
"net"
"net/http"
2019-02-19 23:50:46 +00:00
"strconv"
2019-02-20 19:26:37 +00:00
"strings"
2019-02-19 23:50:46 +00:00
"sync"
2019-02-08 01:26:45 +00:00
"time"
2019-02-20 19:59:22 +00:00
keypairs "github.com/big-squid/go-keypairs"
2019-02-08 01:26:45 +00:00
)
// EInvalidJWKURL is the error value for a URL that does not lead to valid JWKs.
var EInvalidJWKURL = errors.New("url does not lead to valid JWKs")
2019-02-19 23:50:46 +00:00
var KeyCache = map[string]CachableKey{}
var KeyCacheMux = sync.Mutex{}
type CachableKey struct {
2019-02-20 19:59:22 +00:00
Key keypairs.PublicKey
2019-02-19 23:50:46 +00:00
Expiry time.Time
}
2019-02-20 19:26:37 +00:00
// TODO use this poor-man's enum to allow both kids and thumbprints to be accepted by the same method
/*
type KeyID string
func (kid KeyID) ID() string {
return string(kid)
}
func (kid KeyID) isID() {}
type Thumbprint string
func (thumb Thumbprint) ID() string {
return string(thumb)
}
func (thumb Thumbprint) isID() {}
type ID interface {
ID() string
isID()
}
*/
2019-02-19 23:50:46 +00:00
// Cache lifetime settings.
var (
	// StaleTime is the remaining lifetime at or below which GetPublicKey
	// starts a background refresh of a cached key.
	StaleTime = 15 * time.Minute

	// DefaultKeyDuration is the cache lifetime cacheKey applies when the
	// key's "exp" parses as 0.
	DefaultKeyDuration = 48 * time.Hour

	// MinimumKeyDuration is the shortest lifetime cacheKey accepts from an
	// "exp"; it is also the lifetime applied when "exp" is out of range.
	MinimumKeyDuration = time.Hour

	// MaximumKeyDuration is the longest lifetime cacheKey accepts from an
	// "exp".
	MaximumKeyDuration = 72 * time.Hour
)
2019-02-08 01:26:45 +00:00
2019-02-08 23:53:29 +00:00
// FetchOIDCPublicKeys fetches baseURL + ".well-known/openid-configuration" and then returns FetchPublicKeys(jwks_uri).
2019-02-20 19:59:22 +00:00
func FetchOIDCPublicKeys(baseURL string) (map[string]keypairs.PublicKey, error) {
2019-02-19 23:50:46 +00:00
if _, keys, err := fetchAndCacheOIDCPublicKeys(baseURL); nil != err {
return nil, err
} else {
return keys, err
}
}
2019-02-20 19:59:22 +00:00
func fetchAndCacheOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
2019-02-19 23:50:46 +00:00
if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err {
return nil, nil, err
} else {
2019-02-20 19:26:37 +00:00
cacheKeys(maps, keys, baseURL)
2019-02-19 23:50:46 +00:00
return maps, keys, err
}
}
2019-02-20 19:59:22 +00:00
func fetchOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
2019-02-08 01:26:45 +00:00
oidcConf := struct {
JWKSURI string `json:"jwks_uri"`
}{}
2019-02-19 23:50:46 +00:00
2019-02-08 01:26:45 +00:00
// must come in as https://<domain>/
2019-02-08 23:53:29 +00:00
url := baseURL + ".well-known/openid-configuration"
2019-02-08 01:26:45 +00:00
err := safeFetch(url, func(body io.Reader) error {
2019-02-19 23:50:46 +00:00
decoder := json.NewDecoder(body)
decoder.UseNumber()
return decoder.Decode(&oidcConf)
2019-02-08 01:26:45 +00:00
})
2019-02-19 23:50:46 +00:00
if nil != err {
return nil, nil, err
}
return fetchPublicKeys(oidcConf.JWKSURI)
}
2019-02-20 19:59:22 +00:00
func FetchOIDCPublicKey(id, baseURL string) (keypairs.PublicKey, error) {
2019-02-19 23:50:46 +00:00
return fetchOIDCPublicKey(id, baseURL, fetchAndCacheOIDCPublicKeys)
}
2019-02-20 19:59:22 +00:00
func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error)) (keypairs.PublicKey, error) {
2019-02-19 23:50:46 +00:00
_, keys, err := fetcher(baseURL)
2019-02-08 01:26:45 +00:00
if nil != err {
return nil, err
}
2019-02-19 23:50:46 +00:00
for i := range keys {
key := keys[i]
if id == key.Thumbprint() {
2019-02-20 19:26:37 +00:00
return key, nil
2019-02-19 23:50:46 +00:00
}
var kid string
switch k := key.(type) {
2019-02-20 19:59:22 +00:00
case *keypairs.RSAPublicKey:
2019-02-19 23:50:46 +00:00
kid = k.KID
2019-02-20 19:59:22 +00:00
case *keypairs.ECPublicKey:
2019-02-19 23:50:46 +00:00
kid = k.KID
default:
panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled"))
}
if id == kid {
2019-02-20 19:26:37 +00:00
return key, nil
2019-02-19 23:50:46 +00:00
}
}
return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL)
2019-02-08 01:26:45 +00:00
}
2019-02-08 23:53:29 +00:00
// FetchPublicKeys returns a map of keys identified by their kid or thumbprint (if kid is not specified)
2019-02-20 19:59:22 +00:00
func FetchPublicKeys(jwksurl string) (map[string]keypairs.PublicKey, error) {
2019-02-19 23:50:46 +00:00
if maps, keys, err := fetchPublicKeys(jwksurl); nil != err {
return nil, err
} else {
2019-02-20 19:26:37 +00:00
cacheKeys(maps, keys, strings.Replace(jwksurl, ".well-known/jwks.json", "", 1))
2019-02-19 23:50:46 +00:00
return keys, err
}
}
2019-02-20 19:59:22 +00:00
func fetchPublicKeys(jwksurl string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
keys := map[string]keypairs.PublicKey{}
2019-02-19 23:50:46 +00:00
maps := map[string]map[string]string{}
2019-02-08 01:26:45 +00:00
resp := struct {
Keys []map[string]interface{} `json:"keys"`
}{
Keys: make([]map[string]interface{}, 0, 1),
}
if err := safeFetch(jwksurl, func(body io.Reader) error {
2019-02-19 23:50:46 +00:00
decoder := json.NewDecoder(body)
decoder.UseNumber()
return decoder.Decode(&resp)
2019-02-08 01:26:45 +00:00
}); nil != err {
2019-02-19 23:50:46 +00:00
return nil, nil, err
2019-02-08 01:26:45 +00:00
}
for i := range resp.Keys {
k := resp.Keys[i]
2019-02-19 23:50:46 +00:00
m := getStringMap(k)
2019-02-08 01:26:45 +00:00
2019-02-20 19:59:22 +00:00
if key, err := keypairs.NewJWKPublicKey(m); nil != err {
2019-02-19 23:50:46 +00:00
return nil, nil, err
2019-02-08 01:26:45 +00:00
} else {
2019-02-08 23:53:29 +00:00
keys[key.Thumbprint()] = key
2019-02-19 23:50:46 +00:00
maps[key.Thumbprint()] = m
2019-02-08 01:26:45 +00:00
}
}
2019-02-19 23:50:46 +00:00
return maps, keys, nil
2019-02-08 01:26:45 +00:00
}
2019-02-08 23:53:29 +00:00
// FetchPublicKey retrieves a JWK from a URL that specifies only one
2019-02-20 19:59:22 +00:00
func FetchPublicKey(url string) (keypairs.PublicKey, error) {
2019-02-19 23:50:46 +00:00
m, key, err := fetchPublicKey(url)
if nil != err {
return nil, err
}
2019-02-20 19:26:37 +00:00
maps := map[string]map[string]string{}
maps[key.Thumbprint()] = m
2019-02-20 19:59:22 +00:00
keys := map[string]keypairs.PublicKey{}
2019-02-20 19:26:37 +00:00
keys[key.Thumbprint()] = key
cacheKeys(maps, keys, url)
2019-02-19 23:50:46 +00:00
return key, nil
}
2019-02-20 19:59:22 +00:00
func fetchPublicKey(url string) (map[string]string, keypairs.PublicKey, error) {
2019-02-19 23:50:46 +00:00
var m map[string]interface{}
2019-02-08 01:26:45 +00:00
if err := safeFetch(url, func(body io.Reader) error {
2019-02-19 23:50:46 +00:00
decoder := json.NewDecoder(body)
decoder.UseNumber()
return decoder.Decode(&m)
2019-02-08 01:26:45 +00:00
}); nil != err {
2019-02-19 23:50:46 +00:00
return nil, nil, err
}
n := getStringMap(m)
2019-02-20 19:59:22 +00:00
key, err := keypairs.NewJWKPublicKey(n)
2019-02-19 23:50:46 +00:00
if nil != err {
return nil, nil, err
}
return n, key, nil
}
func hasPublicKey(kid, iss string) (*CachableKey, bool) {
id := kid + "@" + iss
KeyCacheMux.Lock()
hit, ok := KeyCache[id]
KeyCacheMux.Unlock()
2019-02-20 19:26:37 +00:00
if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 {
return &hit, true
}
return nil, false
}
// it would be a security risk to pass kid as a thumbprint
func hasPublicKeyByThumbprint(thumb string) (*CachableKey, bool) {
KeyCacheMux.Lock()
hit, ok := KeyCache[thumb]
KeyCacheMux.Unlock()
if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 {
2019-02-19 23:50:46 +00:00
return &hit, true
}
return nil, false
}
2019-02-20 19:59:22 +00:00
func GetPublicKey(kidOrThumb, iss string) (keypairs.PublicKey, error) {
2019-02-19 23:50:46 +00:00
now := time.Now()
2019-02-20 19:26:37 +00:00
key, ok := hasPublicKeyByThumbprint(kidOrThumb)
2019-02-19 23:50:46 +00:00
if !ok {
2019-02-20 19:26:37 +00:00
key, ok = hasPublicKey(kidOrThumb, iss)
if !ok {
return FetchOIDCPublicKey(kidOrThumb, iss)
}
2019-02-19 23:50:46 +00:00
}
// Fetch just a little before the key actually expires
if key.Expiry.Sub(now) <= StaleTime {
2019-02-20 19:26:37 +00:00
go FetchOIDCPublicKey(kidOrThumb, iss)
2019-02-19 23:50:46 +00:00
}
return key.Key, nil
}
2019-02-20 19:59:22 +00:00
var cacheKey = func(kid, iss, expstr string, pub keypairs.PublicKey) error {
2019-02-19 23:50:46 +00:00
var expiry time.Time
exp, _ := strconv.ParseInt(expstr, 10, 64)
if 0 == exp {
// use default
expiry = time.Now().Add(DefaultKeyDuration)
} else if exp < time.Now().Add(MinimumKeyDuration).Unix() || exp > time.Now().Add(MaximumKeyDuration).Unix() {
// use at least one hour
expiry = time.Now().Add(MinimumKeyDuration)
} else {
expiry = time.Unix(exp, 0)
}
// Put the key in the cache by both kid and thumbprint, and set the expiry
KeyCacheMux.Lock()
id := kid + "@" + iss
KeyCache[id] = CachableKey{
Key: pub,
Expiry: expiry,
}
2019-02-20 19:26:37 +00:00
thumb := pub.Thumbprint()
id = thumb + "@" + iss
2019-02-19 23:50:46 +00:00
KeyCache[id] = CachableKey{
Key: pub,
Expiry: expiry,
}
2019-02-20 19:26:37 +00:00
// Since thumbprints are crypto secure, iss is not strictly needed
KeyCache[thumb] = CachableKey{
Key: pub,
Expiry: expiry,
}
2019-02-19 23:50:46 +00:00
KeyCacheMux.Unlock()
return nil
}
2019-02-20 19:59:22 +00:00
func cacheKeys(maps map[string]map[string]string, keys map[string]keypairs.PublicKey, issuer string) {
2019-02-19 23:50:46 +00:00
for i := range keys {
key := keys[i]
m := maps[i]
2019-02-20 19:26:37 +00:00
if "" != m["iss"] {
issuer = m["iss"]
}
cacheKey(m["kid"], strings.TrimRight(issuer, "/"), m["exp"], key)
2019-02-19 23:50:46 +00:00
}
}
// getStringMap flattens a decoded JSON object into a map[string]string.
// String members are copied as-is and json.Number members are copied in
// their literal textual form; members of any other type (booleans, arrays,
// objects, null) are dropped. The result is never nil.
func getStringMap(m map[string]interface{}) map[string]string {
	n := make(map[string]string, len(m))

	// TODO get issuer from x5c, if exists

	for name, value := range m {
		switch v := value.(type) {
		case string:
			n[name] = v
		case json.Number:
			// The decoders in this file call UseNumber, so numeric members
			// such as "exp" arrive as json.Number rather than string.
			// Without this case they never reach cacheKey's ParseInt.
			n[name] = v.String()
		default:
			// safely ignore
		}
	}
	return n
}
// decodeFunc consumes a response body and reports any error it hits.
type decodeFunc func(io.Reader) error

// maxFetchBodySize is the most response-body bytes safeFetch will hand to a
// decoder.
const maxFetchBodySize = 1 << 20 // 1 MiB

// fetchClient is the HTTP client shared by all safeFetch calls. Sharing it
// lets connections be pooled and reused; building a new Transport per call
// would strand its idle connections, since nothing ever closes them.
var fetchClient = &http.Client{
	Timeout: 10 * time.Second,
	Transport: &http.Transport{
		DialContext: (&net.Dialer{
			Timeout: 5 * time.Second,
		}).DialContext,
		TLSHandshakeTimeout: 5 * time.Second,
		IdleConnTimeout:     90 * time.Second,
	},
}

// safeFetch GETs url with bounded dial, TLS-handshake and overall timeouts
// and passes the response body to decoder, returning decoder's error.
//
// A response whose status is not 2xx is reported as an error without calling
// decoder. decoder sees at most maxFetchBodySize bytes of the body; a longer
// body is cut off there.
func safeFetch(url string, decoder decodeFunc) error {
	res, err := fetchClient.Get(url)
	if err != nil {
		return err
	}
	defer res.Body.Close()

	// Do not try to decode an error page as if it were key material.
	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
		return fmt.Errorf("unexpected HTTP status %q fetching %s", res.Status, url)
	}

	return decoder(io.LimitReader(res.Body, maxFetchBodySize))
}