diff --git a/keyfetch/fetch.go b/keyfetch/fetch.go index ef4ebe7..91b66c0 100644 --- a/keyfetch/fetch.go +++ b/keyfetch/fetch.go @@ -123,6 +123,7 @@ func Get(kidOrThumb, iss string) keypairs.PublicKey { } func get(kidOrThumb, iss string) *CachableKey { + iss = normalizeIssuer(iss) KeyCacheMux.Lock() defer KeyCacheMux.Unlock() @@ -136,7 +137,7 @@ func get(kidOrThumb, iss string) *CachableKey { } } - id := kidOrThumb + "@" + strings.TrimRight(iss, "/") + id := kidOrThumb + "@" + normalizeIssuer(iss) hit, ok = KeyCache[id] if ok { if now := time.Now(); hit.Expiry.Sub(now) > 0 { @@ -192,15 +193,17 @@ func cacheKeys(maps map[string]map[string]string, keys map[string]keypairs.Publi for i := range keys { key := keys[i] m := maps[i] + iss := issuer if "" != m["iss"] { - issuer = m["iss"] + iss = m["iss"] } - cacheKey(m["kid"], strings.TrimRight(issuer, "/"), m["exp"], key) + cacheKey(m["kid"], iss, m["exp"], key) } } func cacheKey(kid, iss, expstr string, pub keypairs.PublicKey) error { var expiry time.Time + iss = normalizeIssuer(iss) exp, _ := strconv.ParseInt(expstr, 10, 64) if 0 == exp { @@ -230,3 +233,13 @@ func cacheKey(kid, iss, expstr string, pub keypairs.PublicKey) error { return nil } + +func clear() { + KeyCacheMux.Lock() + defer KeyCacheMux.Unlock() + KeyCache = map[string]CachableKey{} +} + +func normalizeIssuer(iss string) string { + return strings.TrimRight(iss, "/") + "/" +} diff --git a/keyfetch/fetch_test.go b/keyfetch/fetch_test.go index 66dad49..aa8cedc 100644 --- a/keyfetch/fetch_test.go +++ b/keyfetch/fetch_test.go @@ -9,8 +9,12 @@ import ( ) func TestCachesKey(t *testing.T) { - url := "https://bigsquid.auth0.com/" + testCachesKey(t, "https://bigsquid.auth0.com/") + clear() + testCachesKey(t, "https://bigsquid.auth0.com") +} +func testCachesKey(t *testing.T, url string) { // Raw fetch a key and get KID and Thumbprint _, keys, err := uncached.OIDCJWKs(url) if nil != err { diff --git a/keyfetch/uncached/fetch.go b/keyfetch/uncached/fetch.go index 7acc7a9..9914996 100644 --- a/keyfetch/uncached/fetch.go +++ b/keyfetch/uncached/fetch.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/http" + "strings" "time" keypairs "github.com/big-squid/go-keypairs" @@ -13,6 +14,7 @@ import ( // OIDCJWKs gets the OpenID Connect configuration from the baseURL and then calls JWKs with the specified jwks_uri func OIDCJWKs(baseURL string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) { + baseURL = normalizeBaseURL(baseURL) oidcConf := struct { JWKSURI string `json:"jwks_uri"` }{} @@ -33,6 +35,7 @@ func OIDCJWKs(baseURL string) (map[string]map[string]string, map[string]keypairs // WellKnownJWKs calls JWKs with baseURL + /.well-known/jwks.json as constructs the jwks_uri func WellKnownJWKs(baseURL string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) { + baseURL = normalizeBaseURL(baseURL) if '/' == baseURL[len(baseURL)-1] { baseURL = baseURL[:len(baseURL)-1] } @@ -134,3 +137,7 @@ func safeFetch(url string, decoder decodeFunc) error { return decoder(res.Body) } + +func normalizeBaseURL(iss string) string { + return strings.TrimRight(iss, "/") + "/" +}