move fetch to own package
This commit is contained in:
parent
4ff0e898f1
commit
daea45a09f
|
@ -1,4 +1,4 @@
|
||||||
package keypairs
|
package fetch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -11,6 +11,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
keypairs "github.com/big-squid/go-keypairs"
|
||||||
)
|
)
|
||||||
|
|
||||||
var EInvalidJWKURL = errors.New("url does not lead to valid JWKs")
|
var EInvalidJWKURL = errors.New("url does not lead to valid JWKs")
|
||||||
|
@ -18,7 +20,7 @@ var KeyCache = map[string]CachableKey{}
|
||||||
var KeyCacheMux = sync.Mutex{}
|
var KeyCacheMux = sync.Mutex{}
|
||||||
|
|
||||||
type CachableKey struct {
|
type CachableKey struct {
|
||||||
Key PublicKey
|
Key keypairs.PublicKey
|
||||||
Expiry time.Time
|
Expiry time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +52,7 @@ var MinimumKeyDuration = time.Hour
|
||||||
var MaximumKeyDuration = 72 * time.Hour
|
var MaximumKeyDuration = 72 * time.Hour
|
||||||
|
|
||||||
// FetchOIDCPublicKeys fetches baseURL + ".well-known/openid-configuration" and then returns FetchPublicKeys(jwks_uri).
|
// FetchOIDCPublicKeys fetches baseURL + ".well-known/openid-configuration" and then returns FetchPublicKeys(jwks_uri).
|
||||||
func FetchOIDCPublicKeys(baseURL string) (map[string]PublicKey, error) {
|
func FetchOIDCPublicKeys(baseURL string) (map[string]keypairs.PublicKey, error) {
|
||||||
if _, keys, err := fetchAndCacheOIDCPublicKeys(baseURL); nil != err {
|
if _, keys, err := fetchAndCacheOIDCPublicKeys(baseURL); nil != err {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
|
@ -58,7 +60,7 @@ func FetchOIDCPublicKeys(baseURL string) (map[string]PublicKey, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchAndCacheOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[string]PublicKey, error) {
|
func fetchAndCacheOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
|
||||||
if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err {
|
if maps, keys, err := fetchOIDCPublicKeys(baseURL); nil != err {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
} else {
|
} else {
|
||||||
|
@ -67,7 +69,7 @@ func fetchAndCacheOIDCPublicKeys(baseURL string) (map[string]map[string]string,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[string]PublicKey, error) {
|
func fetchOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
|
||||||
oidcConf := struct {
|
oidcConf := struct {
|
||||||
JWKSURI string `json:"jwks_uri"`
|
JWKSURI string `json:"jwks_uri"`
|
||||||
}{}
|
}{}
|
||||||
|
@ -86,10 +88,10 @@ func fetchOIDCPublicKeys(baseURL string) (map[string]map[string]string, map[stri
|
||||||
return fetchPublicKeys(oidcConf.JWKSURI)
|
return fetchPublicKeys(oidcConf.JWKSURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchOIDCPublicKey(id, baseURL string) (PublicKey, error) {
|
func FetchOIDCPublicKey(id, baseURL string) (keypairs.PublicKey, error) {
|
||||||
return fetchOIDCPublicKey(id, baseURL, fetchAndCacheOIDCPublicKeys)
|
return fetchOIDCPublicKey(id, baseURL, fetchAndCacheOIDCPublicKeys)
|
||||||
}
|
}
|
||||||
func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map[string]string, map[string]PublicKey, error)) (PublicKey, error) {
|
func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error)) (keypairs.PublicKey, error) {
|
||||||
_, keys, err := fetcher(baseURL)
|
_, keys, err := fetcher(baseURL)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -104,9 +106,9 @@ func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map
|
||||||
|
|
||||||
var kid string
|
var kid string
|
||||||
switch k := key.(type) {
|
switch k := key.(type) {
|
||||||
case *RSAPublicKey:
|
case *keypairs.RSAPublicKey:
|
||||||
kid = k.KID
|
kid = k.KID
|
||||||
case *ECPublicKey:
|
case *keypairs.ECPublicKey:
|
||||||
kid = k.KID
|
kid = k.KID
|
||||||
default:
|
default:
|
||||||
panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled"))
|
panic(errors.New("Developer Error: Only ECPublicKey and RSAPublicKey are handled"))
|
||||||
|
@ -120,7 +122,7 @@ func fetchOIDCPublicKey(id, baseURL string, fetcher func(string) (map[string]map
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchPublicKeys returns a map of keys identified by their kid or thumbprint (if kid is not specified)
|
// FetchPublicKeys returns a map of keys identified by their kid or thumbprint (if kid is not specified)
|
||||||
func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) {
|
func FetchPublicKeys(jwksurl string) (map[string]keypairs.PublicKey, error) {
|
||||||
if maps, keys, err := fetchPublicKeys(jwksurl); nil != err {
|
if maps, keys, err := fetchPublicKeys(jwksurl); nil != err {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
|
@ -129,8 +131,8 @@ func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchPublicKeys(jwksurl string) (map[string]map[string]string, map[string]PublicKey, error) {
|
func fetchPublicKeys(jwksurl string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) {
|
||||||
keys := map[string]PublicKey{}
|
keys := map[string]keypairs.PublicKey{}
|
||||||
maps := map[string]map[string]string{}
|
maps := map[string]map[string]string{}
|
||||||
resp := struct {
|
resp := struct {
|
||||||
Keys []map[string]interface{} `json:"keys"`
|
Keys []map[string]interface{} `json:"keys"`
|
||||||
|
@ -150,7 +152,7 @@ func fetchPublicKeys(jwksurl string) (map[string]map[string]string, map[string]P
|
||||||
k := resp.Keys[i]
|
k := resp.Keys[i]
|
||||||
m := getStringMap(k)
|
m := getStringMap(k)
|
||||||
|
|
||||||
if key, err := NewJWKPublicKey(m); nil != err {
|
if key, err := keypairs.NewJWKPublicKey(m); nil != err {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
} else {
|
} else {
|
||||||
keys[key.Thumbprint()] = key
|
keys[key.Thumbprint()] = key
|
||||||
|
@ -162,7 +164,7 @@ func fetchPublicKeys(jwksurl string) (map[string]map[string]string, map[string]P
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchPublicKey retrieves a JWK from a URL that specifies only one
|
// FetchPublicKey retrieves a JWK from a URL that specifies only one
|
||||||
func FetchPublicKey(url string) (PublicKey, error) {
|
func FetchPublicKey(url string) (keypairs.PublicKey, error) {
|
||||||
m, key, err := fetchPublicKey(url)
|
m, key, err := fetchPublicKey(url)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -171,7 +173,7 @@ func FetchPublicKey(url string) (PublicKey, error) {
|
||||||
maps := map[string]map[string]string{}
|
maps := map[string]map[string]string{}
|
||||||
maps[key.Thumbprint()] = m
|
maps[key.Thumbprint()] = m
|
||||||
|
|
||||||
keys := map[string]PublicKey{}
|
keys := map[string]keypairs.PublicKey{}
|
||||||
keys[key.Thumbprint()] = key
|
keys[key.Thumbprint()] = key
|
||||||
|
|
||||||
cacheKeys(maps, keys, url)
|
cacheKeys(maps, keys, url)
|
||||||
|
@ -179,7 +181,7 @@ func FetchPublicKey(url string) (PublicKey, error) {
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchPublicKey(url string) (map[string]string, PublicKey, error) {
|
func fetchPublicKey(url string) (map[string]string, keypairs.PublicKey, error) {
|
||||||
var m map[string]interface{}
|
var m map[string]interface{}
|
||||||
if err := safeFetch(url, func(body io.Reader) error {
|
if err := safeFetch(url, func(body io.Reader) error {
|
||||||
decoder := json.NewDecoder(body)
|
decoder := json.NewDecoder(body)
|
||||||
|
@ -190,7 +192,7 @@ func fetchPublicKey(url string) (map[string]string, PublicKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
n := getStringMap(m)
|
n := getStringMap(m)
|
||||||
key, err := NewJWKPublicKey(n)
|
key, err := keypairs.NewJWKPublicKey(n)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -225,7 +227,7 @@ func hasPublicKeyByThumbprint(thumb string) (*CachableKey, bool) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPublicKey(kidOrThumb, iss string) (PublicKey, error) {
|
func GetPublicKey(kidOrThumb, iss string) (keypairs.PublicKey, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
key, ok := hasPublicKeyByThumbprint(kidOrThumb)
|
key, ok := hasPublicKeyByThumbprint(kidOrThumb)
|
||||||
|
|
||||||
|
@ -244,7 +246,7 @@ func GetPublicKey(kidOrThumb, iss string) (PublicKey, error) {
|
||||||
return key.Key, nil
|
return key.Key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var cacheKey = func(kid, iss, expstr string, pub PublicKey) error {
|
var cacheKey = func(kid, iss, expstr string, pub keypairs.PublicKey) error {
|
||||||
var expiry time.Time
|
var expiry time.Time
|
||||||
|
|
||||||
exp, _ := strconv.ParseInt(expstr, 10, 64)
|
exp, _ := strconv.ParseInt(expstr, 10, 64)
|
||||||
|
@ -281,7 +283,7 @@ var cacheKey = func(kid, iss, expstr string, pub PublicKey) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func cacheKeys(maps map[string]map[string]string, keys map[string]PublicKey, issuer string) {
|
func cacheKeys(maps map[string]map[string]string, keys map[string]keypairs.PublicKey, issuer string) {
|
||||||
for i := range keys {
|
for i := range keys {
|
||||||
key := keys[i]
|
key := keys[i]
|
||||||
m := maps[i]
|
m := maps[i]
|
|
@ -1,4 +1,4 @@
|
||||||
package keypairs
|
package fetch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
|
@ -6,13 +6,14 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
keypairs "github.com/big-squid/go-keypairs"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFetchOIDCPublicKeys(t *testing.T) {
|
func TestFetchOIDCPublicKeys(t *testing.T) {
|
||||||
urls := []string{
|
urls := []string{
|
||||||
//"https://bigsquid.auth0.com/.well-known/jwks.json",
|
//"https://bigsquid.auth0.com/.well-known/jwks.json",
|
||||||
"https://bigsquid.auth0.com/",
|
"https://bigsquid.auth0.com/",
|
||||||
"https://api-dev.bigsquid.com/",
|
|
||||||
}
|
}
|
||||||
for i := range urls {
|
for i := range urls {
|
||||||
url := urls[i]
|
url := urls[i]
|
||||||
|
@ -24,9 +25,9 @@ func TestFetchOIDCPublicKeys(t *testing.T) {
|
||||||
for kid := range keys {
|
for kid := range keys {
|
||||||
switch key := keys[kid].Key().(type) {
|
switch key := keys[kid].Key().(type) {
|
||||||
case *rsa.PublicKey:
|
case *rsa.PublicKey:
|
||||||
_ = ThumbprintRSAPublicKey(key)
|
_ = keypairs.ThumbprintRSAPublicKey(key)
|
||||||
case *ecdsa.PublicKey:
|
case *ecdsa.PublicKey:
|
||||||
_ = ThumbprintECPublicKey(key)
|
_ = keypairs.ThumbprintECPublicKey(key)
|
||||||
default:
|
default:
|
||||||
t.Fatal(errors.New("unsupported interface type"))
|
t.Fatal(errors.New("unsupported interface type"))
|
||||||
}
|
}
|
||||||
|
@ -46,7 +47,7 @@ func TestCachesKey(t *testing.T) {
|
||||||
t.Fatal("Should discover 1 or more keys via", url)
|
t.Fatal("Should discover 1 or more keys via", url)
|
||||||
}
|
}
|
||||||
|
|
||||||
var key PublicKey
|
var key keypairs.PublicKey
|
||||||
for i := range keys {
|
for i := range keys {
|
||||||
key = keys[i]
|
key = keys[i]
|
||||||
break
|
break
|
Loading…
Reference in New Issue