provide a typesafe PublicKey interface

This commit is contained in:
AJ ONeal 2019-02-08 23:53:29 +00:00
parent 211016b05e
commit 517865f334
3 changed files with 130 additions and 31 deletions

View File

@ -12,12 +12,13 @@ import (
var EInvalidJWKURL = errors.New("url does not lead to valid JWKs") var EInvalidJWKURL = errors.New("url does not lead to valid JWKs")
func FetchOIDCPublicKeys(host string) ([]crypto.PublicKey, error) { // FetchOIDCPublicKeys fetches baseURL + ".well-known/openid-configuration" and then returns FetchPublicKeys(jwks_uri).
func FetchOIDCPublicKeys(baseURL string) (map[string]PublicKey, error) {
oidcConf := struct { oidcConf := struct {
JWKSURI string `json:"jwks_uri"` JWKSURI string `json:"jwks_uri"`
}{} }{}
// must come in as https://<domain>/ // must come in as https://<domain>/
url := host + ".well-known/openid-configuration" url := baseURL + ".well-known/openid-configuration"
err := safeFetch(url, func(body io.Reader) error { err := safeFetch(url, func(body io.Reader) error {
return json.NewDecoder(body).Decode(&oidcConf) return json.NewDecoder(body).Decode(&oidcConf)
}) })
@ -28,8 +29,9 @@ func FetchOIDCPublicKeys(host string) ([]crypto.PublicKey, error) {
return FetchPublicKeys(oidcConf.JWKSURI) return FetchPublicKeys(oidcConf.JWKSURI)
} }
func FetchPublicKeys(jwksurl string) ([]crypto.PublicKey, error) { // FetchPublicKeys returns a map of keys identified by their kid or thumbprint (if kid is not specified)
var keys []crypto.PublicKey func FetchPublicKeys(jwksurl string) (map[string]PublicKey, error) {
keys := map[string]PublicKey{}
resp := struct { resp := struct {
Keys []map[string]interface{} `json:"keys"` Keys []map[string]interface{} `json:"keys"`
}{ }{
@ -59,13 +61,14 @@ func FetchPublicKeys(jwksurl string) ([]crypto.PublicKey, error) {
if key, err := NewJWKPublicKey(n); nil != err { if key, err := NewJWKPublicKey(n); nil != err {
return nil, err return nil, err
} else { } else {
keys = append(keys, key) keys[key.Thumbprint()] = key
} }
} }
return keys, nil return keys, nil
} }
// FetchPublicKey retrieves a JWK from a URL that specifies only one
func FetchPublicKey(url string) (crypto.PublicKey, error) { func FetchPublicKey(url string) (crypto.PublicKey, error) {
var m map[string]string var m map[string]string
if err := safeFetch(url, func(body io.Reader) error { if err := safeFetch(url, func(body io.Reader) error {
@ -79,6 +82,7 @@ func FetchPublicKey(url string) (crypto.PublicKey, error) {
type decodeFunc func(io.Reader) error type decodeFunc func(io.Reader) error
// TODO: also limit the body size
func safeFetch(url string, decoder decodeFunc) error { func safeFetch(url string, decoder decodeFunc) error {
var netTransport = &http.Transport{ var netTransport = &http.Transport{
Dial: (&net.Dialer{ Dial: (&net.Dialer{

View File

@ -20,8 +20,8 @@ func TestFetchOIDCPublicKeys(t *testing.T) {
t.Fatal(url, err) t.Fatal(url, err)
} }
for i := range keys { for kid := range keys {
switch key := keys[i].(type) { switch key := keys[kid].Key().(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
_ = ThumbprintRSAPublicKey(key) _ = ThumbprintRSAPublicKey(key)
case *ecdsa.PublicKey: case *ecdsa.PublicKey:

View File

@ -14,7 +14,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"math/big" "math/big"
"time"
) )
var EInvalidPrivateKey = errors.New("PrivateKey must be of type rsa.PrivateKey or ecdsa.PrivateKey") var EInvalidPrivateKey = errors.New("PrivateKey must be of type rsa.PrivateKey or ecdsa.PrivateKey")
@ -24,14 +26,91 @@ var EParseJWK = errors.New("JWK is missing required base64-encoded JSON fields")
var EInvalidKeyType = errors.New("The JWK's 'kty' must be either 'RSA' or 'EC'") var EInvalidKeyType = errors.New("The JWK's 'kty' must be either 'RSA' or 'EC'")
var EInvalidCurve = errors.New("The JWK's 'crv' must be either of the NIST standards 'P-256' or 'P-384'") var EInvalidCurve = errors.New("The JWK's 'crv' must be either of the NIST standards 'P-256' or 'P-384'")
// PrivateKey acts as the missing would-be interface crypto.PrivateKey const EDevSwapPrivatePublic = "[Developer Error] You passed either crypto.PrivateKey or crypto.PublicKey where the other was expected."
const EDevBadKeyType = "[Developer Error] crypto.PublicKey and crypto.PrivateKey are somewhat deceptive. They're actually empty interfaces that accept any object, even non-crypto objects. You passed an object of type '%T' by mistake."
// PrivateKey is a zero-cost typesafe substitue for crypto.PrivateKey
type PrivateKey interface { type PrivateKey interface {
Public() crypto.PublicKey Public() crypto.PublicKey
} }
func MarshalPublicJWK(key crypto.PublicKey) []byte { // PublicKey thinly veils crypto.PublicKey for type safety
type PublicKey interface {
crypto.PublicKey
Thumbprint() string
Key() crypto.PublicKey
}
type ECPublicKey struct {
PublicKey *ecdsa.PublicKey // empty interface
KID string
Expiry time.Time
}
type RSAPublicKey struct {
PublicKey *rsa.PublicKey // empty interface
KID string
Expiry time.Time
}
func (p *ECPublicKey) Thumbprint() string {
return ThumbprintUntypedPublicKey(p.PublicKey)
}
func (p *ECPublicKey) Key() crypto.PublicKey {
return p.PublicKey
}
func (p *RSAPublicKey) Thumbprint() string {
return ThumbprintUntypedPublicKey(p.PublicKey)
}
func (p *RSAPublicKey) Key() crypto.PublicKey {
return p.PublicKey
}
// TypesafePublicKey wraps a crypto.PublicKey to make it typesafe.
func NewPublicKey(pub crypto.PublicKey, exp time.Time, kid ...string) PublicKey {
var k PublicKey
switch p := pub.(type) {
case *ecdsa.PublicKey:
eckey := &ECPublicKey{
PublicKey: p,
}
if 0 != len(kid) {
eckey.KID = kid[0]
} else {
eckey.KID = k.Thumbprint()
}
eckey.Expiry = exp
k = eckey
case *rsa.PublicKey:
rsakey := &RSAPublicKey{
PublicKey: p,
}
if 0 != len(kid) {
rsakey.KID = kid[0]
} else {
rsakey.KID = k.Thumbprint()
}
rsakey.Expiry = exp
k = rsakey
case *ecdsa.PrivateKey:
panic(errors.New(EDevSwapPrivatePublic))
case *rsa.PrivateKey:
panic(errors.New(EDevSwapPrivatePublic))
case *dsa.PublicKey:
panic(EInvalidPublicKey)
case *dsa.PrivateKey:
panic(EInvalidPublicKey)
default:
panic(errors.New(fmt.Sprintf(EDevBadKeyType, pub)))
}
return k
}
func MarshalJWKPublicKey(key PublicKey) []byte {
// thumbprint keys are alphabetically sorted and only include the necessary public parts // thumbprint keys are alphabetically sorted and only include the necessary public parts
switch k := key.(type) { switch k := key.Key().(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
return MarshalRSAPublicKey(k) return MarshalRSAPublicKey(k)
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
@ -40,11 +119,16 @@ func MarshalPublicJWK(key crypto.PublicKey) []byte {
panic(EInvalidPublicKey) panic(EInvalidPublicKey)
default: default:
// this is unreachable because we know the types that we pass in // this is unreachable because we know the types that we pass in
log.Printf("keytype: %t, %+v\n", key, key)
panic(EInvalidPublicKey) panic(EInvalidPublicKey)
} }
} }
func ThumbprintPublicKey(pub crypto.PublicKey) string { func ThumbprintPublicKey(pub *PublicKey) string {
return ThumbprintUntypedPublicKey(pub)
}
func ThumbprintUntypedPublicKey(pub crypto.PublicKey) string {
switch p := pub.(type) { switch p := pub.(type) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
return ThumbprintECPublicKey(p) return ThumbprintECPublicKey(p)
@ -128,6 +212,10 @@ func ParsePrivateKey(block []byte) (PrivateKey, error) {
return nil, EParsePrivateKey return nil, EParsePrivateKey
} }
func ParsePrivateKeyString(block string) (PrivateKey, error) {
return ParsePrivateKey([]byte(block))
}
func parsePrivateKey(der []byte) (PrivateKey, error) { func parsePrivateKey(der []byte) (PrivateKey, error) {
var key PrivateKey var key PrivateKey
@ -167,7 +255,7 @@ func parsePrivateKey(der []byte) (PrivateKey, error) {
return key, nil return key, nil
} }
func NewJWKPublicKey(m map[string]string) (crypto.PublicKey, error) { func NewJWKPublicKey(m map[string]string) (PublicKey, error) {
switch m["kty"] { switch m["kty"] {
case "RSA": case "RSA":
return parseRSAPublicKey(m) return parseRSAPublicKey(m)
@ -178,19 +266,19 @@ func NewJWKPublicKey(m map[string]string) (crypto.PublicKey, error) {
} }
} }
func ParseJWKPublicKey(b []byte) (crypto.PublicKey, error) { func ParseJWKPublicKey(b []byte) (PublicKey, error) {
return newJWKPublicKey(b) return newJWKPublicKey(b)
} }
func ParseJWKPublicKeyString(s string) (crypto.PublicKey, error) { func ParseJWKPublicKeyString(s string) (PublicKey, error) {
return newJWKPublicKey(s) return newJWKPublicKey(s)
} }
func DecodeJWKPublicKey(r io.Reader) (crypto.PublicKey, error) { func DecodeJWKPublicKey(r io.Reader) (PublicKey, error) {
return newJWKPublicKey(r) return newJWKPublicKey(r)
} }
func newJWKPublicKey(data interface{}) (crypto.PublicKey, error) { func newJWKPublicKey(data interface{}) (PublicKey, error) {
var m map[string]string var m map[string]string
switch d := data.(type) { switch d := data.(type) {
@ -232,29 +320,32 @@ func ParseJWKPrivateKey(b []byte) (PrivateKey, error) {
} }
} }
func parseRSAPublicKey(m map[string]string) (pub *rsa.PublicKey, err error) { func parseRSAPublicKey(m map[string]string) (*RSAPublicKey, error) {
// TODO grab expiry?
kid, _ := m["kid"]
n, _ := base64.RawURLEncoding.DecodeString(m["n"]) n, _ := base64.RawURLEncoding.DecodeString(m["n"])
e, _ := base64.RawURLEncoding.DecodeString(m["e"]) e, _ := base64.RawURLEncoding.DecodeString(m["e"])
if 0 == len(n) || 0 == len(e) { if 0 == len(n) || 0 == len(e) {
err = EParseJWK return nil, EParseJWK
return
} }
ni := &big.Int{} ni := &big.Int{}
ni.SetBytes(n) ni.SetBytes(n)
ei := &big.Int{} ei := &big.Int{}
ei.SetBytes(e) ei.SetBytes(e)
pub = &rsa.PublicKey{ pub := &rsa.PublicKey{
N: ni, N: ni,
E: int(ei.Int64()), E: int(ei.Int64()),
} }
return
return &RSAPublicKey{
PublicKey: pub,
KID: kid,
}, nil
} }
func parseRSAPrivateKey(m map[string]string) (key *rsa.PrivateKey, err error) { func parseRSAPrivateKey(m map[string]string) (key *rsa.PrivateKey, err error) {
var pub *rsa.PublicKey pub, err := parseRSAPublicKey(m)
pub, err = parseRSAPublicKey(m)
if nil != err { if nil != err {
return return
} }
@ -283,7 +374,7 @@ func parseRSAPrivateKey(m map[string]string) (key *rsa.PrivateKey, err error) {
qinvi.SetBytes(qinv) qinvi.SetBytes(qinv)
key = &rsa.PrivateKey{ key = &rsa.PrivateKey{
PublicKey: *pub, PublicKey: *pub.PublicKey,
D: di, D: di,
Primes: []*big.Int{pi, qi}, Primes: []*big.Int{pi, qi},
Precomputed: rsa.PrecomputedValues{ Precomputed: rsa.PrecomputedValues{
@ -296,7 +387,9 @@ func parseRSAPrivateKey(m map[string]string) (key *rsa.PrivateKey, err error) {
return return
} }
func parseECPublicKey(m map[string]string) (pub *ecdsa.PublicKey, err error) { func parseECPublicKey(m map[string]string) (*ECPublicKey, error) {
// TODO grab expiry?
kid, _ := m["kid"]
x, _ := base64.RawURLEncoding.DecodeString(m["x"]) x, _ := base64.RawURLEncoding.DecodeString(m["x"])
y, _ := base64.RawURLEncoding.DecodeString(m["y"]) y, _ := base64.RawURLEncoding.DecodeString(m["y"])
if 0 == len(x) || 0 == len(y) || 0 == len(m["crv"]) { if 0 == len(x) || 0 == len(y) || 0 == len(m["crv"]) {
@ -318,17 +411,19 @@ func parseECPublicKey(m map[string]string) (pub *ecdsa.PublicKey, err error) {
case "P-521": case "P-521":
crv = elliptic.P521() crv = elliptic.P521()
default: default:
err = EInvalidCurve return nil, EInvalidCurve
return
} }
pub = &ecdsa.PublicKey{ pub := &ecdsa.PublicKey{
Curve: crv, Curve: crv,
X: xi, X: xi,
Y: yi, Y: yi,
} }
return return &ECPublicKey{
PublicKey: pub,
KID: kid,
}, nil
} }
func parseECPrivateKey(m map[string]string) (*ecdsa.PrivateKey, error) { func parseECPrivateKey(m map[string]string) (*ecdsa.PrivateKey, error) {
@ -345,7 +440,7 @@ func parseECPrivateKey(m map[string]string) (*ecdsa.PrivateKey, error) {
di.SetBytes(d) di.SetBytes(d)
return &ecdsa.PrivateKey{ return &ecdsa.PrivateKey{
PublicKey: *pub, PublicKey: *pub.PublicKey,
D: di, D: di,
}, nil }, nil
} }