refactor
This commit is contained in:
parent
daea45a09f
commit
04259f1e96
|
@ -0,0 +1,36 @@
|
||||||
|
/*
|
||||||
|
Package keypairs complements Go's standard keypair-related packages
|
||||||
|
(encoding/pem, crypto/x509, crypto/rsa, crypto/ecdsa, crypto/elliptic)
|
||||||
|
with JWK encoding support and typesafe PrivateKey and PublicKey interfaces.
|
||||||
|
|
||||||
|
Basics
|
||||||
|
|
||||||
|
key, err := keypairs.ParsePrivateKey(bytesForJWKOrPEMOrDER)
|
||||||
|
|
||||||
|
pub, err := keypairs.ParsePublicKey(bytesForJWKOrPEMOrDER)
|
||||||
|
|
||||||
|
jwk, err := keypairs.MarshalJWKPublicKey(pub, time.Now().Add(2 * time.Day))
|
||||||
|
|
||||||
|
kid, err := keypairs.ThumbprintPublicKey(pub)
|
||||||
|
|
||||||
|
Convenience functions are available which will fetch keys
|
||||||
|
(or retrieve them from cache) via OIDC, .well-known/jwks.json, and direct urls.
|
||||||
|
All keys are cached by Thumbprint, as well as kid(@issuer), if available.
|
||||||
|
|
||||||
|
import "github.com/big-squid/go-keypairs/keyfetch"
|
||||||
|
|
||||||
|
pubs, err := keyfetch.OIDCJWKs("https://example.com/")
|
||||||
|
pubs, err := keyfetch.OIDCJWK(keyIDOrThumb, "https://example.com/")
|
||||||
|
|
||||||
|
pubs, err := keyfetch.WellKnownJWKs("https://example.com/")
|
||||||
|
pubs, err := keyfetch.WellKnownJWK(keyIDOrThumb, "https://example.com/")
|
||||||
|
|
||||||
|
pubs, err := keyfetch.JWKs("https://example.com/path/to/jwks/")
|
||||||
|
pubs, err := keyfetch.JWK(keyIDOrThumb, "https://example.com/path/to/jwks/)
|
||||||
|
|
||||||
|
pubs, err := keyfetch.Get("https://example.com/jwk.json)
|
||||||
|
|
||||||
|
A non-caching version with the same capabilities is also available.
|
||||||
|
|
||||||
|
*/
|
||||||
|
package keypairs
|
337
fetch/fetch.go
337
fetch/fetch.go
|
@ -1,337 +0,0 @@
|
||||||
package fetch
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
keypairs "github.com/big-squid/go-keypairs"
|
|
||||||
)
|
|
||||||
|
|
||||||
var EInvalidJWKURL = errors.New("url does not lead to valid JWKs")
|
|
||||||
var KeyCache = map[string]CachableKey{}
|
|
||||||
var KeyCacheMux = sync.Mutex{}
|
|
||||||
|
|
||||||
type CachableKey struct {
|
|
||||||
Key keypairs.PublicKey
|
|
||||||
Expiry time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO use this poor-man's enum to allow kids thumbs to be accepted by the same method
|
|
||||||
/*
|
|
||||||
type KeyID string
|
|
||||||
|
|
||||||
func (kid KeyID) ID() string {
|
|
||||||
return string(kid)
|
|
||||||
}
|
|
||||||
func (kid KeyID) isID() {}
|
|
||||||
|
|
||||||
type Thumbprint string
|
|
||||||
|
|
||||||
func (thumb Thumbprint) ID() string {
|
|
||||||
return string(thumb)
|
|
||||||
}
|
|
||||||
func (thumb Thumbprint) isID() {}
|
|
||||||
|
|
||||||
type ID interface {
|
|
||||||
ID() string
|
|
||||||
isID()
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
var StaleTime = 15 * time.Minute
|
|
||||||
var DefaultKeyDuration = 48 * time.Hour
|
|
||||||
var MinimumKeyDuration = time.Hour
|
|
||||||
var MaximumKeyDuration = 72 * time.Hour
|
|
||||||
|
|
||||||
// FetchOIDCPublicKeys fetches baseURL + ".well-known/openid-configuration" and then returns FetchPublicKeys(jwks_uri).
|
|
||||||
func FetchOIDCPublicKeys(baseURL string) (map[string]keypairs.PublicKey, error) {
|
|
||||||
if _, keys, err := fetchAndCacheOIDCPublicKeys(baseURL); nil != err {
|
|
||||||
return nil, err
|
|
||||||
} else {
|
|
||||||
return keys, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchAndCacheOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
|
|
||||||
if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err {
|
|
||||||
return nil, nil, err
|
|
||||||
} else {
|
|
||||||
cacheKeys(maps, keys, baseURL)
|
|
||||||
return maps, keys, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
|
|
||||||
oidcConf := struct {
|
|
||||||
JWKSURI string `json:"jwks_uri"`
|
|
||||||
}{}
|
|
||||||
|
|
||||||
// must come in as https://<domain>/
|
|
||||||
url := baseURL + ".well-known/openid-configuration"
|
|
||||||
err := safeFetch(url, func(body io.Reader) error {
|
|
||||||
decoder := json.NewDecoder(body)
|
|
||||||
decoder.UseNumber()
|
|
||||||
return decoder.Decode(&oidcConf)
|
|
||||||
})
|
|
||||||
if nil != err {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return fetchPublicKeys(oidcConf.JWKSURI)
|
|
||||||
}
|
|
||||||
|
|
||||||
func FetchOIDCPublicKey(id, baseURL string) (keypairs.PublicKey, error) {
|
|
||||||
return fetchOIDCPublicKey(id, baseURL, fetchAndCacheOIDCPublicKeys)
|
|
||||||
}
|
|
||||||
func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error)) (keypairs.PublicKey, error) {
|
|
||||||
_, keys, err := fetcher(baseURL)
|
|
||||||
if nil != err {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range keys {
|
|
||||||
key := keys[i]
|
|
||||||
|
|
||||||
if id == key.Thumbprint() {
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var kid string
|
|
||||||
switch k := key.(type) {
|
|
||||||
case *keypairs.RSAPublicKey:
|
|
||||||
kid = k.KID
|
|
||||||
case *keypairs.ECPublicKey:
|
|
||||||
kid = k.KID
|
|
||||||
default:
|
|
||||||
panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled"))
|
|
||||||
}
|
|
||||||
if id == kid {
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchPublicKeys returns a map of keys identified by their kid or thumbprint (if kid is not specified)
|
|
||||||
func FetchPublicKeys(jwksurl string) (map[string]keypairs.PublicKey, error) {
|
|
||||||
if maps, keys, err := fetchPublicKeys(jwksurl); nil != err {
|
|
||||||
return nil, err
|
|
||||||
} else {
|
|
||||||
cacheKeys(maps, keys, strings.Replace(jwksurl, ".well-known/jwks.json", "", 1))
|
|
||||||
return keys, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchPublicKeys(jwksurl string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
|
|
||||||
keys := map[string]keypairs.PublicKey{}
|
|
||||||
maps := map[string]map[string]string{}
|
|
||||||
resp := struct {
|
|
||||||
Keys []map[string]interface{} `json:"keys"`
|
|
||||||
}{
|
|
||||||
Keys: make([]map[string]interface{}, 0, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := safeFetch(jwksurl, func(body io.Reader) error {
|
|
||||||
decoder := json.NewDecoder(body)
|
|
||||||
decoder.UseNumber()
|
|
||||||
return decoder.Decode(&resp)
|
|
||||||
}); nil != err {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range resp.Keys {
|
|
||||||
k := resp.Keys[i]
|
|
||||||
m := getStringMap(k)
|
|
||||||
|
|
||||||
if key, err := keypairs.NewJWKPublicKey(m); nil != err {
|
|
||||||
return nil, nil, err
|
|
||||||
} else {
|
|
||||||
keys[key.Thumbprint()] = key
|
|
||||||
maps[key.Thumbprint()] = m
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return maps, keys, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchPublicKey retrieves a JWK from a URL that specifies only one
|
|
||||||
func FetchPublicKey(url string) (keypairs.PublicKey, error) {
|
|
||||||
m, key, err := fetchPublicKey(url)
|
|
||||||
if nil != err {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
maps := map[string]map[string]string{}
|
|
||||||
maps[key.Thumbprint()] = m
|
|
||||||
|
|
||||||
keys := map[string]keypairs.PublicKey{}
|
|
||||||
keys[key.Thumbprint()] = key
|
|
||||||
|
|
||||||
cacheKeys(maps, keys, url)
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchPublicKey(url string) (map[string]string, keypairs.PublicKey, error) {
|
|
||||||
var m map[string]interface{}
|
|
||||||
if err := safeFetch(url, func(body io.Reader) error {
|
|
||||||
decoder := json.NewDecoder(body)
|
|
||||||
decoder.UseNumber()
|
|
||||||
return decoder.Decode(&m)
|
|
||||||
}); nil != err {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n := getStringMap(m)
|
|
||||||
key, err := keypairs.NewJWKPublicKey(n)
|
|
||||||
if nil != err {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return n, key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasPublicKey(kid, iss string) (*CachableKey, bool) {
|
|
||||||
id := kid + "@" + iss
|
|
||||||
|
|
||||||
KeyCacheMux.Lock()
|
|
||||||
hit, ok := KeyCache[id]
|
|
||||||
KeyCacheMux.Unlock()
|
|
||||||
|
|
||||||
if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 {
|
|
||||||
return &hit, true
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// it would be a security risk to pass kid as a thumbprint
|
|
||||||
func hasPublicKeyByThumbprint(thumb string) (*CachableKey, bool) {
|
|
||||||
KeyCacheMux.Lock()
|
|
||||||
hit, ok := KeyCache[thumb]
|
|
||||||
KeyCacheMux.Unlock()
|
|
||||||
|
|
||||||
if now := time.Now(); ok && hit.Expiry.Sub(now) > 0 {
|
|
||||||
return &hit, true
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPublicKey(kidOrThumb, iss string) (keypairs.PublicKey, error) {
|
|
||||||
now := time.Now()
|
|
||||||
key, ok := hasPublicKeyByThumbprint(kidOrThumb)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
key, ok = hasPublicKey(kidOrThumb, iss)
|
|
||||||
if !ok {
|
|
||||||
return FetchOIDCPublicKey(kidOrThumb, iss)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch just a little before the key actually expires
|
|
||||||
if key.Expiry.Sub(now) <= StaleTime {
|
|
||||||
go FetchOIDCPublicKey(kidOrThumb, iss)
|
|
||||||
}
|
|
||||||
|
|
||||||
return key.Key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var cacheKey = func(kid, iss, expstr string, pub keypairs.PublicKey) error {
|
|
||||||
var expiry time.Time
|
|
||||||
|
|
||||||
exp, _ := strconv.ParseInt(expstr, 10, 64)
|
|
||||||
if 0 == exp {
|
|
||||||
// use default
|
|
||||||
expiry = time.Now().Add(DefaultKeyDuration)
|
|
||||||
} else if exp < time.Now().Add(MinimumKeyDuration).Unix() || exp > time.Now().Add(MaximumKeyDuration).Unix() {
|
|
||||||
// use at least one hour
|
|
||||||
expiry = time.Now().Add(MinimumKeyDuration)
|
|
||||||
} else {
|
|
||||||
expiry = time.Unix(exp, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put the key in the cache by both kid and thumbprint, and set the expiry
|
|
||||||
KeyCacheMux.Lock()
|
|
||||||
id := kid + "@" + iss
|
|
||||||
KeyCache[id] = CachableKey{
|
|
||||||
Key: pub,
|
|
||||||
Expiry: expiry,
|
|
||||||
}
|
|
||||||
thumb := pub.Thumbprint()
|
|
||||||
id = thumb + "@" + iss
|
|
||||||
KeyCache[id] = CachableKey{
|
|
||||||
Key: pub,
|
|
||||||
Expiry: expiry,
|
|
||||||
}
|
|
||||||
// Since thumbprints are crypto secure, iss is not strictly needed
|
|
||||||
KeyCache[thumb] = CachableKey{
|
|
||||||
Key: pub,
|
|
||||||
Expiry: expiry,
|
|
||||||
}
|
|
||||||
KeyCacheMux.Unlock()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func cacheKeys(maps map[string]map[string]string, keys map[string]keypairs.PublicKey, issuer string) {
|
|
||||||
for i := range keys {
|
|
||||||
key := keys[i]
|
|
||||||
m := maps[i]
|
|
||||||
if "" != m["iss"] {
|
|
||||||
issuer = m["iss"]
|
|
||||||
}
|
|
||||||
cacheKey(m["kid"], strings.TrimRight(issuer, "/"), m["exp"], key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getStringMap(m map[string]interface{}) map[string]string {
|
|
||||||
n := make(map[string]string)
|
|
||||||
|
|
||||||
// TODO get issuer from x5c, if exists
|
|
||||||
|
|
||||||
// convert map[string]interface{} to map[string]string
|
|
||||||
for j := range m {
|
|
||||||
switch s := m[j].(type) {
|
|
||||||
case string:
|
|
||||||
n[j] = s
|
|
||||||
default:
|
|
||||||
// safely ignore
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
type decodeFunc func(io.Reader) error
|
|
||||||
|
|
||||||
// TODO: also limit the body size
|
|
||||||
func safeFetch(url string, decoder decodeFunc) error {
|
|
||||||
var netTransport = &http.Transport{
|
|
||||||
Dial: (&net.Dialer{
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 5 * time.Second,
|
|
||||||
}
|
|
||||||
var netClient = &http.Client{
|
|
||||||
Timeout: time.Second * 10,
|
|
||||||
Transport: netTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := netClient.Get(url)
|
|
||||||
if nil != err {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
|
|
||||||
return decoder(res.Body)
|
|
||||||
}
|
|
|
@ -0,0 +1,232 @@
|
||||||
|
package keyfetch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
keypairs "github.com/big-squid/go-keypairs"
|
||||||
|
"github.com/big-squid/go-keypairs/keyfetch/uncached"
|
||||||
|
)
|
||||||
|
|
||||||
|
var EInvalidJWKURL = errors.New("url does not lead to valid JWKs")
|
||||||
|
var KeyCache = map[string]CachableKey{}
|
||||||
|
var KeyCacheMux = sync.Mutex{}
|
||||||
|
|
||||||
|
type CachableKey struct {
|
||||||
|
Key keypairs.PublicKey
|
||||||
|
Expiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybe TODO use this poor-man's enum to allow kids thumbs to be accepted by the same method?
|
||||||
|
/*
|
||||||
|
type KeyID string
|
||||||
|
|
||||||
|
func (kid KeyID) ID() string {
|
||||||
|
return string(kid)
|
||||||
|
}
|
||||||
|
func (kid KeyID) isID() {}
|
||||||
|
|
||||||
|
type Thumbprint string
|
||||||
|
|
||||||
|
func (thumb Thumbprint) ID() string {
|
||||||
|
return string(thumb)
|
||||||
|
}
|
||||||
|
func (thumb Thumbprint) isID() {}
|
||||||
|
|
||||||
|
type ID interface {
|
||||||
|
ID() string
|
||||||
|
isID()
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
var StaleTime = 15 * time.Minute
|
||||||
|
var DefaultKeyDuration = 48 * time.Hour
|
||||||
|
var MinimumKeyDuration = time.Hour
|
||||||
|
var MaximumKeyDuration = 72 * time.Hour
|
||||||
|
|
||||||
|
type publicKeysMap map[string]keypairs.PublicKey
|
||||||
|
|
||||||
|
// FetchOIDCPublicKeys fetches baseURL + ".well-known/openid-configuration" and then returns FetchPublicKeys(jwks_uri).
|
||||||
|
func OIDCJWKs(baseURL string) (publicKeysMap, error) {
|
||||||
|
if maps, keys, err := uncached.OIDCJWKs(baseURL); nil != err {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
cacheKeys(maps, keys, baseURL)
|
||||||
|
return keys, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func OIDCJWK(kidOrThumb, iss string) (keypairs.PublicKey, error) {
|
||||||
|
return immediateOneOrFetch(kidOrThumb, iss, uncached.OIDCJWKs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WellKnownJWKs(kidOrThumb, iss string) (publicKeysMap, error) {
|
||||||
|
if maps, keys, err := uncached.WellKnownJWKs(iss); nil != err {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
cacheKeys(maps, keys, iss)
|
||||||
|
return keys, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WellKnownJWK(kidOrThumb, iss string) (keypairs.PublicKey, error) {
|
||||||
|
return immediateOneOrFetch(kidOrThumb, iss, uncached.WellKnownJWKs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWKs returns a map of keys identified by their thumbprint
|
||||||
|
// (since kid may or may not be present)
|
||||||
|
func JWKs(jwksurl string) (publicKeysMap, error) {
|
||||||
|
if maps, keys, err := uncached.JWKs(jwksurl); nil != err {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
cacheKeys(maps, keys, strings.Replace(jwksurl, ".well-known/jwks.json", "", 1))
|
||||||
|
return keys, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWK tries to return a key from cache, falling back to the /.well-known/jwks.json of the issuer
|
||||||
|
func JWK(kidOrThumb, iss string) (keypairs.PublicKey, error) {
|
||||||
|
return immediateOneOrFetch(kidOrThumb, iss, uncached.JWKs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch returns a key from cache, falling back to an exact url as the "issuer"
|
||||||
|
func Fetch(url string) (keypairs.PublicKey, error) {
|
||||||
|
// url is kid in this case
|
||||||
|
return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
|
||||||
|
m, key, err := uncached.Fetch(url)
|
||||||
|
if nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// put in a map, just for caching
|
||||||
|
maps := map[string]map[string]string{}
|
||||||
|
maps[key.Thumbprint()] = m
|
||||||
|
|
||||||
|
keys := map[string]keypairs.PublicKey{}
|
||||||
|
keys[key.Thumbprint()] = key
|
||||||
|
|
||||||
|
return maps, keys, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a key from cache, or returns an error.
|
||||||
|
// The issuer string may be empty if using a thumbprint rather than a kid.
|
||||||
|
func Get(kidOrThumb, iss string) keypairs.PublicKey {
|
||||||
|
if pub := get(kidOrThumb, iss); nil != pub {
|
||||||
|
return pub.Key
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func get(kidOrThumb, iss string) *CachableKey {
|
||||||
|
KeyCacheMux.Lock()
|
||||||
|
defer KeyCacheMux.Unlock()
|
||||||
|
|
||||||
|
// we're safe to check the cache by kid alone
|
||||||
|
// by virtue that we never set it by kid alone
|
||||||
|
hit, ok := KeyCache[kidOrThumb]
|
||||||
|
if ok {
|
||||||
|
if now := time.Now(); hit.Expiry.Sub(now) > 0 {
|
||||||
|
// only return non-expired keys
|
||||||
|
return &hit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
id := kidOrThumb + "@" + strings.TrimRight(iss, "/")
|
||||||
|
hit, ok = KeyCache[id]
|
||||||
|
if ok {
|
||||||
|
if now := time.Now(); hit.Expiry.Sub(now) > 0 {
|
||||||
|
// only return non-expired keys
|
||||||
|
return &hit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.PublicKey, error) {
|
||||||
|
now := time.Now()
|
||||||
|
key := get(kidOrThumb, iss)
|
||||||
|
|
||||||
|
if nil == key {
|
||||||
|
return fetchAndSelect(kidOrThumb, iss, fetcher)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch just a little before the key actually expires
|
||||||
|
if key.Expiry.Sub(now) <= StaleTime {
|
||||||
|
go fetchAndSelect(kidOrThumb, iss, fetcher)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key.Key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type myfetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error)
|
||||||
|
|
||||||
|
func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKey, error) {
|
||||||
|
maps, keys, err := fetcher(baseURL)
|
||||||
|
if nil != err {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cacheKeys(maps, keys, baseURL)
|
||||||
|
|
||||||
|
for i := range keys {
|
||||||
|
key := keys[i]
|
||||||
|
|
||||||
|
if id == key.Thumbprint() {
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == key.KeyID() {
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheKeys(maps map[string]map[string]string, keys map[string]keypairs.PublicKey, issuer string) {
|
||||||
|
for i := range keys {
|
||||||
|
key := keys[i]
|
||||||
|
m := maps[i]
|
||||||
|
if "" != m["iss"] {
|
||||||
|
issuer = m["iss"]
|
||||||
|
}
|
||||||
|
cacheKey(m["kid"], strings.TrimRight(issuer, "/"), m["exp"], key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheKey(kid, iss, expstr string, pub keypairs.PublicKey) error {
|
||||||
|
var expiry time.Time
|
||||||
|
|
||||||
|
exp, _ := strconv.ParseInt(expstr, 10, 64)
|
||||||
|
if 0 == exp {
|
||||||
|
// use default
|
||||||
|
expiry = time.Now().Add(DefaultKeyDuration)
|
||||||
|
} else if exp < time.Now().Add(MinimumKeyDuration).Unix() || exp > time.Now().Add(MaximumKeyDuration).Unix() {
|
||||||
|
// use at least one hour
|
||||||
|
expiry = time.Now().Add(MinimumKeyDuration)
|
||||||
|
} else {
|
||||||
|
expiry = time.Unix(exp, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
KeyCacheMux.Lock()
|
||||||
|
defer KeyCacheMux.Unlock()
|
||||||
|
// Put the key in the cache by both kid and thumbprint, and set the expiry
|
||||||
|
id := kid + "@" + iss
|
||||||
|
KeyCache[id] = CachableKey{
|
||||||
|
Key: pub,
|
||||||
|
Expiry: expiry,
|
||||||
|
}
|
||||||
|
// Since thumbprints are crypto secure, iss isn't needed
|
||||||
|
thumb := pub.Thumbprint()
|
||||||
|
KeyCache[thumb] = CachableKey{
|
||||||
|
Key: pub,
|
||||||
|
Expiry: expiry,
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,45 +1,18 @@
|
||||||
package fetch
|
package keyfetch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/rsa"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
keypairs "github.com/big-squid/go-keypairs"
|
keypairs "github.com/big-squid/go-keypairs"
|
||||||
|
"github.com/big-squid/go-keypairs/keyfetch/uncached"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFetchOIDCPublicKeys(t *testing.T) {
|
|
||||||
urls := []string{
|
|
||||||
//"https://bigsquid.auth0.com/.well-known/jwks.json",
|
|
||||||
"https://bigsquid.auth0.com/",
|
|
||||||
}
|
|
||||||
for i := range urls {
|
|
||||||
url := urls[i]
|
|
||||||
_, keys, err := fetchOIDCPublicKeys(url)
|
|
||||||
if nil != err {
|
|
||||||
t.Fatal(url, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for kid := range keys {
|
|
||||||
switch key := keys[kid].Key().(type) {
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
_ = keypairs.ThumbprintRSAPublicKey(key)
|
|
||||||
case *ecdsa.PublicKey:
|
|
||||||
_ = keypairs.ThumbprintECPublicKey(key)
|
|
||||||
default:
|
|
||||||
t.Fatal(errors.New("unsupported interface type"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCachesKey(t *testing.T) {
|
func TestCachesKey(t *testing.T) {
|
||||||
url := "https://bigsquid.auth0.com/"
|
url := "https://bigsquid.auth0.com/"
|
||||||
|
|
||||||
// Raw fetch a key and get KID and Thumbprint
|
// Raw fetch a key and get KID and Thumbprint
|
||||||
_, keys, err := fetchOIDCPublicKeys(url)
|
_, keys, err := uncached.OIDCJWKs(url)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
t.Fatal(url, err)
|
t.Fatal(url, err)
|
||||||
}
|
}
|
||||||
|
@ -55,26 +28,23 @@ func TestCachesKey(t *testing.T) {
|
||||||
thumb := key.Thumbprint()
|
thumb := key.Thumbprint()
|
||||||
|
|
||||||
// Look in cache for each (and fail)
|
// Look in cache for each (and fail)
|
||||||
if _, ok := hasPublicKeyByThumbprint(thumb); ok {
|
if pub := Get(thumb, ""); nil != pub {
|
||||||
t.Fatal("SANITY: Should not have any key cached by thumbprint")
|
t.Fatal("SANITY: Should not have any key cached by thumbprint")
|
||||||
}
|
}
|
||||||
if _, ok := hasPublicKey(key.KeyID(), url); ok {
|
|
||||||
t.Fatal("SANITY: Should not have any key cached by kid")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get with caching
|
// Get with caching
|
||||||
k2, err := GetPublicKey(thumb, url)
|
k2, err := OIDCJWK(thumb, url)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
t.Fatal("Error fetching and caching key:", err)
|
t.Fatal("Error fetching and caching key:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look in cache for each (and succeed)
|
// Look in cache for each (and succeed)
|
||||||
if _, ok := hasPublicKeyByThumbprint(thumb); !ok {
|
if pub := Get(thumb, ""); nil == pub {
|
||||||
t.Fatal("key was not properly cached by thumbprint")
|
t.Fatal("key was not properly cached by thumbprint", thumb)
|
||||||
}
|
}
|
||||||
if "" != k2.KeyID() {
|
if "" != k2.KeyID() {
|
||||||
if _, ok := hasPublicKeyByThumbprint(thumb); !ok {
|
if pub := Get(k2.KeyID(), url); nil == pub {
|
||||||
t.Fatal("key was not properly cached by thumbprint")
|
t.Fatal("key was not properly cached by kid", k2.KeyID())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.Log("Key did not have an explicit KeyID")
|
t.Log("Key did not have an explicit KeyID")
|
||||||
|
@ -82,7 +52,7 @@ func TestCachesKey(t *testing.T) {
|
||||||
|
|
||||||
// Get again (should be sub-ms instant)
|
// Get again (should be sub-ms instant)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
_, err = GetPublicKey(thumb, url)
|
_, err = OIDCJWK(thumb, url)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
t.Fatal("SANITY: Failed to get the key we just got...", err)
|
t.Fatal("SANITY: Failed to get the key we just got...", err)
|
||||||
}
|
}
|
|
@ -0,0 +1,136 @@
|
||||||
|
// Package uncached provides uncached versions of go-keypairs/keyfetch
|
||||||
|
package uncached
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
keypairs "github.com/big-squid/go-keypairs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OIDCJWKs gets the OpenID Connect configuration from the baseURL and then calls JWKs with the specified jwks_uri
|
||||||
|
func OIDCJWKs(baseURL string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
|
||||||
|
oidcConf := struct {
|
||||||
|
JWKSURI string `json:"jwks_uri"`
|
||||||
|
}{}
|
||||||
|
|
||||||
|
// must come in as https://<domain>/
|
||||||
|
url := baseURL + ".well-known/openid-configuration"
|
||||||
|
err := safeFetch(url, func(body io.Reader) error {
|
||||||
|
decoder := json.NewDecoder(body)
|
||||||
|
decoder.UseNumber()
|
||||||
|
return decoder.Decode(&oidcConf)
|
||||||
|
})
|
||||||
|
if nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return JWKs(oidcConf.JWKSURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WellKnownJWKs calls JWKs with baseURL + /.well-known/jwks.json as constructs the jwks_uri
|
||||||
|
func WellKnownJWKs(baseURL string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
|
||||||
|
if '/' == baseURL[len(baseURL)-1] {
|
||||||
|
baseURL = baseURL[:len(baseURL)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return JWKs(baseURL + "/.well-known/jwks.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWKs fetches and parses a jwks.json (assuming well-known format)
|
||||||
|
func JWKs(jwksurl string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
|
||||||
|
keys := map[string]keypairs.PublicKey{}
|
||||||
|
maps := map[string]map[string]string{}
|
||||||
|
resp := struct {
|
||||||
|
Keys []map[string]interface{} `json:"keys"`
|
||||||
|
}{
|
||||||
|
Keys: make([]map[string]interface{}, 0, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := safeFetch(jwksurl, func(body io.Reader) error {
|
||||||
|
decoder := json.NewDecoder(body)
|
||||||
|
decoder.UseNumber()
|
||||||
|
return decoder.Decode(&resp)
|
||||||
|
}); nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range resp.Keys {
|
||||||
|
k := resp.Keys[i]
|
||||||
|
m := getStringMap(k)
|
||||||
|
|
||||||
|
if key, err := keypairs.NewJWKPublicKey(m); nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
} else {
|
||||||
|
keys[key.Thumbprint()] = key
|
||||||
|
maps[key.Thumbprint()] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return maps, keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch retrieves a single JWK (plain, bare jwk) from a URL (off-spec)
|
||||||
|
func Fetch(url string) (map[string]string, keypairs.PublicKey, error) {
|
||||||
|
var m map[string]interface{}
|
||||||
|
if err := safeFetch(url, func(body io.Reader) error {
|
||||||
|
decoder := json.NewDecoder(body)
|
||||||
|
decoder.UseNumber()
|
||||||
|
return decoder.Decode(&m)
|
||||||
|
}); nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n := getStringMap(m)
|
||||||
|
key, err := keypairs.NewJWKPublicKey(n)
|
||||||
|
if nil != err {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStringMap(m map[string]interface{}) map[string]string {
|
||||||
|
n := make(map[string]string)
|
||||||
|
|
||||||
|
// TODO get issuer from x5c, if exists
|
||||||
|
|
||||||
|
// convert map[string]interface{} to map[string]string
|
||||||
|
for j := range m {
|
||||||
|
switch s := m[j].(type) {
|
||||||
|
case string:
|
||||||
|
n[j] = s
|
||||||
|
default:
|
||||||
|
// safely ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
type decodeFunc func(io.Reader) error
|
||||||
|
|
||||||
|
// TODO: also limit the body size
|
||||||
|
func safeFetch(url string, decoder decodeFunc) error {
|
||||||
|
var netTransport = &http.Transport{
|
||||||
|
Dial: (&net.Dialer{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}).Dial,
|
||||||
|
TLSHandshakeTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
var netClient = &http.Client{
|
||||||
|
Timeout: time.Second * 10,
|
||||||
|
Transport: netTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := netClient.Get(url)
|
||||||
|
if nil != err {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
return decoder(res.Body)
|
||||||
|
}
|
|
@ -0,0 +1,85 @@
|
||||||
|
package uncached
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/rsa"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
keypairs "github.com/big-squid/go-keypairs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJWKs(t *testing.T) {
|
||||||
|
urls := []string{
|
||||||
|
"https://bigsquid.auth0.com/.well-known/jwks.json",
|
||||||
|
}
|
||||||
|
for i := range urls {
|
||||||
|
url := urls[i]
|
||||||
|
_, keys, err := JWKs(url)
|
||||||
|
if nil != err {
|
||||||
|
t.Fatal(url, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for kid := range keys {
|
||||||
|
switch key := keys[kid].Key().(type) {
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
_ = keypairs.ThumbprintRSAPublicKey(key)
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
_ = keypairs.ThumbprintECPublicKey(key)
|
||||||
|
default:
|
||||||
|
t.Fatal(errors.New("unsupported interface type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWellKnownJWKs(t *testing.T) {
|
||||||
|
urls := []string{
|
||||||
|
//"https://bigsquid.auth0.com/.well-known/jwks.json"
|
||||||
|
"https://bigsquid.auth0.com/",
|
||||||
|
}
|
||||||
|
for i := range urls {
|
||||||
|
url := urls[i]
|
||||||
|
_, keys, err := WellKnownJWKs(url)
|
||||||
|
if nil != err {
|
||||||
|
t.Fatal(url, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for kid := range keys {
|
||||||
|
switch key := keys[kid].Key().(type) {
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
_ = keypairs.ThumbprintRSAPublicKey(key)
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
_ = keypairs.ThumbprintECPublicKey(key)
|
||||||
|
default:
|
||||||
|
t.Fatal(errors.New("unsupported interface type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOIDCJWKs(t *testing.T) {
|
||||||
|
urls := []string{
|
||||||
|
//"https://bigsquid.auth0.com/.well-known/openid-configuration"
|
||||||
|
//"https://bigsquid.auth0.com/.well-known/jwks.json"
|
||||||
|
"https://bigsquid.auth0.com/",
|
||||||
|
}
|
||||||
|
for i := range urls {
|
||||||
|
url := urls[i]
|
||||||
|
_, keys, err := OIDCJWKs(url)
|
||||||
|
if nil != err {
|
||||||
|
t.Fatal(url, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for kid := range keys {
|
||||||
|
switch key := keys[kid].Key().(type) {
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
_ = keypairs.ThumbprintRSAPublicKey(key)
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
_ = keypairs.ThumbprintECPublicKey(key)
|
||||||
|
default:
|
||||||
|
t.Fatal(errors.New("unsupported interface type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -81,7 +81,7 @@ func (p *RSAPublicKey) ExpireAt(t time.Time) {
|
||||||
p.Expiry = t
|
p.Expiry = t
|
||||||
}
|
}
|
||||||
|
|
||||||
// TypesafePublicKey wraps a crypto.PublicKey to make it typesafe.
|
// NewPublicKey wraps a crypto.PublicKey to make it typesafe.
|
||||||
func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKey {
|
func NewPublicKey(pub crypto.PublicKey, kid ...string) PublicKey {
|
||||||
var k PublicKey
|
var k PublicKey
|
||||||
switch p := pub.(type) {
|
switch p := pub.(type) {
|
||||||
|
|
Loading…
Reference in New Issue