// Package keyfetch retrieves and caches PublicKeys
// from OIDC (https://example.com/.well-known/openid-configuration)
// and Auth0 (https://example.com/.well-known/jwks.json)
// JWKs URLs and expires them when `exp` is reached
// (or a default expiry if the key does not provide one).
// It uses the keypairs package to Unmarshal the JWKs into their
// native types (with a very thin shim to provide the type safety
// that Go's crypto.PublicKey and crypto.PrivateKey interfaces lack).
package keyfetch

import (
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"git.rootprojects.org/root/keypairs"
	"git.rootprojects.org/root/keypairs/keyfetch/uncached"
)

// TODO should be ErrInvalidJWKURL

// ErrInvalidJWKURL means that the url did not provide JWKs.
var ErrInvalidJWKURL = errors.New("url does not lead to valid JWKs")

// KeyCache is an in-memory key cache. It maps a string id to the cached key
// and its expiry.
var KeyCache = map[string]CachableKey{}

// KeyCacheMux is used to guard the in-memory cache.
var KeyCacheMux = sync.Mutex{}

// ErrInsecureDomain means that plain http was used where https was expected.
var ErrInsecureDomain = errors.New("Whitelists should only allow secure URLs (i.e. https://). To allow unsecured private networking (i.e. Docker) pass PrivateWhitelist as a list of private URLs")

// TODO Cacheable key (shouldn't this be private)?

// CachableKey represents one entry of KeyCache: a public key paired with the
// time at which that entry expires.
type CachableKey struct {
	// Key is the cached public key.
	Key keypairs.PublicKey
	// Expiry is the instant at which the entry is considered expired.
	Expiry time.Time
}

// maybe TODO use this poor-man's enum to allow kids thumbs to be accepted by the same method?
/* type KeyID string func (kid KeyID) ID() string { return string(kid) } func (kid KeyID) isID() {} type Thumbprint string func (thumb Thumbprint) ID() string { return string(thumb) } func (thumb Thumbprint) isID() {} type ID interface { ID() string isID() } */ // StaleTime defines when public keys should be renewed (15 minutes by default) var StaleTime = 15 * time.Minute // DefaultKeyDuration defines how long a key should be considered fresh (48 hours by default) var DefaultKeyDuration = 48 * time.Hour // MinimumKeyDuration defines the minimum time that a key will be cached (1 hour by default) var MinimumKeyDuration = time.Hour // MaximumKeyDuration defines the maximum time that a key will be cached (72 hours by default) var MaximumKeyDuration = 72 * time.Hour // PublicKeysMap is a newtype for a map of keypairs.PublicKey type PublicKeysMap map[string]keypairs.PublicKey // OIDCJWKs fetches baseURL + ".well-known/openid-configuration" and then fetches and returns the Public Keys. func OIDCJWKs(baseURL string) (PublicKeysMap, error) { maps, keys, err := uncached.OIDCJWKs(baseURL) if nil != err { return nil, err } cacheKeys(maps, keys, baseURL) return keys, err } // OIDCJWK fetches baseURL + ".well-known/openid-configuration" and then returns the key matching kid (or thumbprint) func OIDCJWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { return immediateOneOrFetch(kidOrThumb, iss, uncached.OIDCJWKs) } // WellKnownJWKs fetches baseURL + ".well-known/jwks.json" and caches and returns the keys func WellKnownJWKs(kidOrThumb, iss string) (PublicKeysMap, error) { maps, keys, err := uncached.WellKnownJWKs(iss) if nil != err { return nil, err } cacheKeys(maps, keys, iss) return keys, err } // WellKnownJWK fetches baseURL + ".well-known/jwks.json" and returns the key matching kid (or thumbprint) func WellKnownJWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { return immediateOneOrFetch(kidOrThumb, iss, uncached.WellKnownJWKs) } // JWKs 
returns a map of keys identified by their thumbprint // (since kid may or may not be present) func JWKs(jwksurl string) (PublicKeysMap, error) { maps, keys, err := uncached.JWKs(jwksurl) if nil != err { return nil, err } iss := strings.Replace(jwksurl, ".well-known/jwks.json", "", 1) cacheKeys(maps, keys, iss) return keys, err } // JWK tries to return a key from cache, falling back to the /.well-known/jwks.json of the issuer func JWK(kidOrThumb, iss string) (keypairs.PublicKeyTransitional, error) { return immediateOneOrFetch(kidOrThumb, iss, uncached.JWKs) } // PEM tries to return a key from cache, falling back to the specified PEM url func PEM(url string) (keypairs.PublicKeyTransitional, error) { // url is kid in this case return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) { m, key, err := uncached.PEM(url) if nil != err { return nil, nil, err } // put in a map, just for caching maps := map[string]map[string]string{} maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m maps[url] = m keys := map[string]keypairs.PublicKey{} keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key keys[url] = key return maps, keys, nil }) } // Fetch returns a key from cache, falling back to an exact url as the "issuer" func Fetch(url string) (keypairs.PublicKeyTransitional, error) { // url is kid in this case return immediateOneOrFetch(url, url, func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) { m, key, err := uncached.Fetch(url) if nil != err { return nil, nil, err } // put in a map, just for caching maps := map[string]map[string]string{} maps[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = m keys := map[string]keypairs.PublicKey{} keys[keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional))] = key return maps, keys, nil }) } // Get retrieves a key from cache, or returns an error. 
// The issuer string may be empty if using a thumbprint rather than a kid. func Get(kidOrThumb, iss string) keypairs.PublicKey { if pub := get(kidOrThumb, iss); nil != pub { return pub.Key } return nil } func get(kidOrThumb, iss string) *CachableKey { iss = normalizeIssuer(iss) KeyCacheMux.Lock() defer KeyCacheMux.Unlock() // we're safe to check the cache by kid alone // by virtue that we never set it by kid alone hit, ok := KeyCache[kidOrThumb] if ok { if now := time.Now(); hit.Expiry.Sub(now) > 0 { // only return non-expired keys return &hit } } id := kidOrThumb + "@" + iss hit, ok = KeyCache[id] if ok { if now := time.Now(); hit.Expiry.Sub(now) > 0 { // only return non-expired keys return &hit } } return nil } func immediateOneOrFetch(kidOrThumb, iss string, fetcher myfetcher) (keypairs.PublicKeyTransitional, error) { now := time.Now() key := get(kidOrThumb, iss) if nil == key { return fetchAndSelect(kidOrThumb, iss, fetcher) } // Fetch just a little before the key actually expires if key.Expiry.Sub(now) <= StaleTime { go fetchAndSelect(kidOrThumb, iss, fetcher) } return key.Key.Key().(keypairs.PublicKeyTransitional), nil } type myfetcher func(string) (map[string]map[string]string, map[string]keypairs.PublicKey, error) func fetchAndSelect(id, baseURL string, fetcher myfetcher) (keypairs.PublicKeyTransitional, error) { maps, keys, err := fetcher(baseURL) if nil != err { return nil, err } cacheKeys(maps, keys, baseURL) for i := range keys { key := keys[i] if id == keypairs.Thumbprint(key.Key().(keypairs.PublicKeyTransitional)) { return key.Key().(keypairs.PublicKeyTransitional), nil } if id == key.KeyID() { return key.Key().(keypairs.PublicKeyTransitional), nil } } return nil, fmt.Errorf("Key identified by '%s' was not found at %s", id, baseURL) } func cacheKeys(maps map[string]map[string]string, keys map[string]keypairs.PublicKey, issuer string) { for i := range keys { key := keys[i] m := maps[i] iss := issuer if "" != m["iss"] { iss = m["iss"] } iss = 
normalizeIssuer(iss) cacheKey(m["kid"], iss, m["exp"], key) } } func cacheKey(kid, iss, expstr string, pub keypairs.PublicKey) error { var expiry time.Time iss = normalizeIssuer(iss) exp, _ := strconv.ParseInt(expstr, 10, 64) if 0 == exp { // use default expiry = time.Now().Add(DefaultKeyDuration) } else if exp < time.Now().Add(MinimumKeyDuration).Unix() || exp > time.Now().Add(MaximumKeyDuration).Unix() { // use at least one hour expiry = time.Now().Add(MinimumKeyDuration) } else { expiry = time.Unix(exp, 0) } KeyCacheMux.Lock() defer KeyCacheMux.Unlock() // Put the key in the cache by both kid and thumbprint, and set the expiry id := kid + "@" + iss KeyCache[id] = CachableKey{ Key: pub, Expiry: expiry, } // Since thumbprints are crypto secure, iss isn't needed thumb := keypairs.Thumbprint(pub.Key().(keypairs.PublicKeyTransitional)) KeyCache[thumb] = CachableKey{ Key: pub, Expiry: expiry, } return nil } func clear() { KeyCacheMux.Lock() defer KeyCacheMux.Unlock() KeyCache = map[string]CachableKey{} } func normalizeIssuer(iss string) string { return strings.TrimRight(iss, "/") } func isTrustedIssuer(iss string, whitelist Whitelist, rs ...*http.Request) bool { if "" == iss { return false } // Normalize the http:// and https:// and parse iss = strings.TrimRight(iss, "/") + "/" if strings.HasPrefix(iss, "http://") { // ignore } else if strings.HasPrefix(iss, "//") { return false // TODO } else if !strings.HasPrefix(iss, "https://") { iss = "https://" + iss } issURL, err := url.Parse(iss) if nil != err { return false } // Check that // * schemes match (https: == https:) // * paths match (/foo/ == /foo/, always with trailing slash added) // * hostnames are compatible (a == b or "sub.foo.com".HasSufix(".foo.com")) for i := range []*url.URL(whitelist) { u := whitelist[i] if issURL.Scheme != u.Scheme { continue } else if u.Path != strings.TrimRight(issURL.Path, "/")+"/" { continue } else if issURL.Host != u.Host { if '.' 
== u.Host[0] && strings.HasSuffix(issURL.Host, u.Host) { return true } continue } // All failures have been handled return true } // Check if implicit issuer is available if 0 == len(rs) { return false } return hasImplicitTrust(issURL, rs[0]) } // hasImplicitTrust relies on the security of DNS and TLS to determine if the // headers of the request can be trusted as identifying the server itself as // a valid issuer, without additional configuration. // // Helpful for testing, but in the wrong hands could easily lead to a zero-day. func hasImplicitTrust(issURL *url.URL, r *http.Request) bool { if nil == r { return false } // Sanity check that, if a load balancer exists, it isn't misconfigured proto := r.Header.Get("X-Forwarded-Proto") if "" != proto && proto != "https" { return false } // Get the host // * If TLS, block Domain Fronting // * Otherwise assume trusted proxy // * Otherwise assume test environment var host string if nil != r.TLS { // Note that if this were to be implemented for HTTP/2 it would need to // check all names on the certificate, not just the one with which the // original connection was established. However, not our problem here. // See https://serverfault.com/a/908087/93930 if r.TLS.ServerName != r.Host { return false } host = r.Host } else { host = r.Header.Get("X-Forwarded-Host") if "" == host { host = r.Host } } // Same tests as above, adjusted since it can't handle wildcards and, since // the path is variable, we make the assumption that a child can trust a // parent, but that a parent cannot trust a child. 
if r.Host != issURL.Host { return false } if !strings.HasPrefix(strings.TrimRight(r.URL.Path, "/")+"/", issURL.Path) { // Ex: Request URL Token Issuer // !"https:example.com/johndoe/api/dothing".HasPrefix("https:example.com/") return false } return true } // Whitelist is a newtype for an array of URLs type Whitelist []*url.URL // NewWhitelist turns an array of URLs (such as https://example.com/) into // a parsed array of *url.URLs that can be used by the IsTrustedIssuer function func NewWhitelist(issuers []string, privateList ...[]string) (Whitelist, error) { var err error list := []*url.URL{} if 0 != len(issuers) { insecure := false list, err = newWhitelist(list, issuers, insecure) if nil != err { return nil, err } } if 0 != len(privateList) && 0 != len(privateList[0]) { insecure := true list, err = newWhitelist(list, privateList[0], insecure) if nil != err { return nil, err } } return Whitelist(list), nil } func newWhitelist(list []*url.URL, issuers []string, insecure bool) (Whitelist, error) { for i := range issuers { iss := issuers[i] if "" == strings.TrimSpace(iss) { fmt.Println("[Warning] You have an empty string in your keyfetch whitelist.") continue } // Should have a valid http or https prefix // TODO support custom prefixes (i.e. app://) ? if strings.HasPrefix(iss, "http://") { if !insecure { log.Println("Oops! 
You have an insecure domain in your whitelist: ", iss) return nil, ErrInsecureDomain } } else if strings.HasPrefix(iss, "//") { // TODO return nil, errors.New("Rather than prefixing with // to support multiple protocols, add them seperately:" + iss) } else if !strings.HasPrefix(iss, "https://") { iss = "https://" + iss } // trailing slash as a boundary character, which may or may not denote a directory iss = strings.TrimRight(iss, "/") + "/" u, err := url.Parse(iss) if nil != err { return nil, err } // Strip any * prefix, for easier comparison later // *.example.com => .example.com if strings.HasPrefix(u.Host, "*.") { u.Host = u.Host[1:] } list = append(list, u) } return list, nil } /* IsTrustedIssuer returns true when the `iss` (i.e. from a token) matches one in the provided whitelist (also matches wildcard domains). You may explicitly allow insecure http (i.e. for automated testing) by including http:// Otherwise the scheme in each item of the whitelist should include the "https://" prefix. SECURITY CONSIDERATIONS (Please Read) You'll notice that *http.Request is optional. It should only be used under these three circumstances: 1) Something else guarantees http -> https redirection happens before the connection gets here AND this server directly handles TLS/SSL. 2) If you're using a load balancer or web server, and this doesn't handle TLS/SSL directly, that server is _explicitly_ configured to protect against Domain Fronting attacks. As of 2019, most web servers and load balancers do not protect against that by default. 3) If you only use it to make your automated integration testing more and it isn't enabled in production. Otherwise, DO NOT pass in *http.Request as you will introduce a 0-day vulnerability allowing an attacker to spoof any token issuer of their choice. The only reason I allowed this in a public library where non-experts would encounter it is to make testing easier. 
*/ func (w Whitelist) IsTrustedIssuer(iss string, rs ...*http.Request) bool { return isTrustedIssuer(iss, w, rs...) } // String will generate a space-delimited list of whitelisted URLs func (w Whitelist) String() string { s := []string{} for i := range w { s = append(s, w[i].String()) } return strings.Join(s, " ") }