From 517865f3343bcfada0f19f95e0b657422005870d Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Fri, 8 Feb 2019 23:53:29 +0000 Subject: [PATCH] provide a typesafe PublicKey interface --- fetch.go | 14 +++-- fetch_test.go | 4 +- keypairs.go | 143 +++++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 130 insertions(+), 31 deletions(-) diff --git a/fetch.go b/fetch.go index 4f66e02..0224f70 100644 --- a/fetch.go +++ b/fetch.go @@ -12,12 +12,13 @@ import ( var EInvalidJWKURL = errors.New("url does not lead to valid JWKs") -func FetchOIDCPublicKeys(host string) ([]crypto.PublicKey, error) { +// FetchOIDCPublicKeys fetches baseURL + ".well-known/openid-configuration" and then returns FetchPublicKeys(jwks_uri). +func FetchOIDCPublicKeys(baseURL string) (map[string]PublicKey, error) { oidcConf := struct { JWKSURI string `json:"jwks_uri"` }{} // must come in as https:/// - url := host + ".well-known/openid-configuration" + url := baseURL + ".well-known/openid-configuration" err := safeFetch(url, func(body io.Reader) error { return json.NewDecoder(body).Decode(&oidcConf) }) @@ -28,8 +29,9 @@ func FetchOIDCPublicKeys(host string) ([]crypto.PublicKey, error) { return FetchPublicKeys(oidcConf.JWKSURI) } -func FetchPublicKeys(jwksurl string) ([]crypto.PublicKey, error) { - var keys []crypto.PublicKey +// FetchPublicKeys returns a map of keys identified by their kid or thumbprint (if kid is not specified) +func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) { + keys := map[string]PublicKey{} resp := struct { Keys []map[string]interface{} `json:"keys"` }{ @@ -59,13 +61,14 @@ func FetchPublicKeys(jwksurl string) ([]crypto.PublicKey, error) { if key, err := NewJWKPublicKey(n); nil != err { return nil, err } else { - keys = append(keys, key) + keys[key.Thumbprint()] = key } } return keys, nil } +// FetchPublicKey retrieves a JWK from a URL that specifies only one func FetchPublicKey(url string) (crypto.PublicKey, error) { var m map[string]string if err := safeFetch(url, func(body io.Reader) error { @@ -79,6 +82,7 @@ func FetchPublicKey(url string) (crypto.PublicKey, error) { type decodeFunc func(io.Reader) error +// TODO: also limit the body size func safeFetch(url string, decoder decodeFunc) error { var netTransport = &http.Transport{ Dial: (&net.Dialer{ diff --git a/fetch_test.go b/fetch_test.go index d1ba9da..3947efe 100644 --- a/fetch_test.go +++ b/fetch_test.go @@ -20,8 +20,8 @@ func TestFetchOIDCPublicKeys(t *testing.T) { t.Fatal(url, err) } - for i := range keys { - switch key := keys[i].(type) { + for kid := range keys { + switch key := keys[kid].Key().(type) { case *rsa.PublicKey: _ = ThumbprintRSAPublicKey(key) case *ecdsa.PublicKey: diff --git a/keypairs.go b/keypairs.go index 9710716..8a00e71 100644 --- a/keypairs.go +++ b/keypairs.go @@ -14,7 +14,9 @@ import ( "errors" "fmt" "io" + "log" "math/big" + "time" ) var EInvalidPrivateKey = errors.New("PrivateKey must be of type rsa.PrivateKey or ecdsa.PrivateKey") @@ -24,14 +26,91 @@ var EParseJWK = errors.New("JWK is missing required base64-encoded JSON fields") var EInvalidKeyType = errors.New("The JWK's 'kty' must be either 'RSA' or 'EC'") var EInvalidCurve = errors.New("The JWK's 'crv' must be either of the NIST standards 'P-256' or 'P-384'") -// PrivateKey acts as the missing would-be interface crypto.PrivateKey +const EDevSwapPrivatePublic = "[Developer Error] You passed either crypto.PrivateKey or crypto.PublicKey where the other was expected." + +const EDevBadKeyType = "[Developer Error] crypto.PublicKey and crypto.PrivateKey are somewhat deceptive. They're actually empty interfaces that accept any object, even non-crypto objects. You passed an object of type '%T' by mistake." + +// PrivateKey is a zero-cost typesafe substitue for crypto.PrivateKey type PrivateKey interface { Public() crypto.PublicKey } -func MarshalPublicJWK(key crypto.PublicKey) []byte { +// PublicKey thinly veils crypto.PublicKey for type safety +type PublicKey interface { + crypto.PublicKey + Thumbprint() string + Key() crypto.PublicKey +} + +type ECPublicKey struct { + PublicKey *ecdsa.PublicKey // empty interface + KID string + Expiry time.Time +} + +type RSAPublicKey struct { + PublicKey *rsa.PublicKey // empty interface + KID string + Expiry time.Time +} + +func (p *ECPublicKey) Thumbprint() string { + return ThumbprintUntypedPublicKey(p.PublicKey) +} +func (p *ECPublicKey) Key() crypto.PublicKey { + return p.PublicKey +} +func (p *RSAPublicKey) Thumbprint() string { + return ThumbprintUntypedPublicKey(p.PublicKey) +} +func (p *RSAPublicKey) Key() crypto.PublicKey { + return p.PublicKey +} + +// TypesafePublicKey wraps a crypto.PublicKey to make it typesafe. +func NewPublicKey(pub crypto.PublicKey, exp time.Time, kid ...string) PublicKey { + var k PublicKey + switch p := pub.(type) { + case *ecdsa.PublicKey: + eckey := &ECPublicKey{ + PublicKey: p, + } + if 0 != len(kid) { + eckey.KID = kid[0] + } else { + eckey.KID = k.Thumbprint() + } + eckey.Expiry = exp + k = eckey + case *rsa.PublicKey: + rsakey := &RSAPublicKey{ + PublicKey: p, + } + if 0 != len(kid) { + rsakey.KID = kid[0] + } else { + rsakey.KID = k.Thumbprint() + } + rsakey.Expiry = exp + k = rsakey + case *ecdsa.PrivateKey: + panic(errors.New(EDevSwapPrivatePublic)) + case *rsa.PrivateKey: + panic(errors.New(EDevSwapPrivatePublic)) + case *dsa.PublicKey: + panic(EInvalidPublicKey) + case *dsa.PrivateKey: + panic(EInvalidPublicKey) + default: + panic(errors.New(fmt.Sprintf(EDevBadKeyType, pub))) + } + + return k +} + +func MarshalJWKPublicKey(key PublicKey) []byte { // thumbprint keys are alphabetically sorted and only include the necessary public parts - switch k := key.(type) { + switch k := key.Key().(type) { case *rsa.PublicKey: return MarshalRSAPublicKey(k) case *ecdsa.PublicKey: @@ -40,11 +119,16 @@ func MarshalPublicJWK(key crypto.PublicKey) []byte { panic(EInvalidPublicKey) default: // this is unreachable because we know the types that we pass in + log.Printf("keytype: %t, %+v\n", key, key) panic(EInvalidPublicKey) } } -func ThumbprintPublicKey(pub crypto.PublicKey) string { +func ThumbprintPublicKey(pub *PublicKey) string { + return ThumbprintUntypedPublicKey(pub) +} + +func ThumbprintUntypedPublicKey(pub crypto.PublicKey) string { switch p := pub.(type) { case *ecdsa.PublicKey: return ThumbprintECPublicKey(p) @@ -128,6 +212,10 @@ func ParsePrivateKey(block []byte) (PrivateKey, error) { return nil, EParsePrivateKey } +func ParsePrivateKeyString(block string) (PrivateKey, error) { + return ParsePrivateKey([]byte(block)) +} + func parsePrivateKey(der []byte) (PrivateKey, error) { var key PrivateKey @@ -167,7 +255,7 @@ func parsePrivateKey(der []byte) (PrivateKey, error) { return key, nil } -func NewJWKPublicKey(m map[string]string) (crypto.PublicKey, error) { +func NewJWKPublicKey(m map[string]string) (PublicKey, error) { switch m["kty"] { case "RSA": return parseRSAPublicKey(m) @@ -178,19 +266,19 @@ func NewJWKPublicKey(m map[string]string) (crypto.PublicKey, error) { } } -func ParseJWKPublicKey(b []byte) (crypto.PublicKey, error) { +func ParseJWKPublicKey(b []byte) (PublicKey, error) { return newJWKPublicKey(b) } -func ParseJWKPublicKeyString(s string) (crypto.PublicKey, error) { +func ParseJWKPublicKeyString(s string) (PublicKey, error) { return newJWKPublicKey(s) } -func DecodeJWKPublicKey(r io.Reader) (crypto.PublicKey, error) { +func DecodeJWKPublicKey(r io.Reader) (PublicKey, error) { return newJWKPublicKey(r) } -func newJWKPublicKey(data interface{}) (crypto.PublicKey, error) { +func newJWKPublicKey(data interface{}) (PublicKey, error) { var m map[string]string switch d := data.(type) { @@ -232,29 +320,32 @@ func ParseJWKPrivateKey(b []byte) (PrivateKey, error) { } } -func parseRSAPublicKey(m map[string]string) (pub *rsa.PublicKey, err error) { +func parseRSAPublicKey(m map[string]string) (*RSAPublicKey, error) { + // TODO grab expiry? + kid, _ := m["kid"] n, _ := base64.RawURLEncoding.DecodeString(m["n"]) e, _ := base64.RawURLEncoding.DecodeString(m["e"]) if 0 == len(n) || 0 == len(e) { - err = EParseJWK - return + return nil, EParseJWK } ni := &big.Int{} ni.SetBytes(n) ei := &big.Int{} ei.SetBytes(e) - pub = &rsa.PublicKey{ + pub := &rsa.PublicKey{ N: ni, E: int(ei.Int64()), } - return + + return &RSAPublicKey{ + PublicKey: pub, + KID: kid, + }, nil } func parseRSAPrivateKey(m map[string]string) (key *rsa.PrivateKey, err error) { - var pub *rsa.PublicKey - - pub, err = parseRSAPublicKey(m) + pub, err := parseRSAPublicKey(m) if nil != err { return } @@ -283,7 +374,7 @@ func parseRSAPrivateKey(m map[string]string) (key *rsa.PrivateKey, err error) { qinvi.SetBytes(qinv) key = &rsa.PrivateKey{ - PublicKey: *pub, + PublicKey: *pub.PublicKey, D: di, Primes: []*big.Int{pi, qi}, Precomputed: rsa.PrecomputedValues{ @@ -296,7 +387,9 @@ func parseRSAPrivateKey(m map[string]string) (key *rsa.PrivateKey, err error) { return } -func parseECPublicKey(m map[string]string) (pub *ecdsa.PublicKey, err error) { +func parseECPublicKey(m map[string]string) (*ECPublicKey, error) { + // TODO grab expiry? + kid, _ := m["kid"] x, _ := base64.RawURLEncoding.DecodeString(m["x"]) y, _ := base64.RawURLEncoding.DecodeString(m["y"]) if 0 == len(x) || 0 == len(y) || 0 == len(m["crv"]) { @@ -318,17 +411,19 @@ func parseECPublicKey(m map[string]string) (pub *ecdsa.PublicKey, err error) { case "P-521": crv = elliptic.P521() default: - err = EInvalidCurve - return + return nil, EInvalidCurve } - pub = &ecdsa.PublicKey{ + pub := &ecdsa.PublicKey{ Curve: crv, X: xi, Y: yi, } - return + return &ECPublicKey{ + PublicKey: pub, + KID: kid, + }, nil } func parseECPrivateKey(m map[string]string) (*ecdsa.PrivateKey, error) { @@ -345,7 +440,7 @@ func parseECPrivateKey(m map[string]string) (*ecdsa.PrivateKey, error) { di.SetBytes(d) return &ecdsa.PrivateKey{ - PublicKey: *pub, + PublicKey: *pub.PublicKey, D: di, }, nil }