From a6e3c042fe4095779a2a19c51178b0caf5c3cbc6 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Sat, 30 May 2020 17:14:40 -0600 Subject: [PATCH] WIP: authenticated management routes --- mplexer/cmd/mgmt/acmeroutes.go | 113 +++++++++++ mplexer/cmd/mgmt/devices.go | 90 +++++++++ mplexer/cmd/mgmt/mgmt.go | 170 +++------------- mplexer/cmd/mgmt/route.go | 181 ++++++++++++++++++ mplexer/cmd/sqlstore/sqlstore.go | 2 +- mplexer/mgmt/authstore/authstore.go | 28 +++ .../{ => mgmt}/authstore/authstore_test.go | 15 +- mplexer/{ => mgmt}/authstore/init.sql | 0 mplexer/{ => mgmt}/authstore/insert.sql | 0 mplexer/{ => mgmt}/authstore/postgresql.go | 46 +---- 10 files changed, 462 insertions(+), 183 deletions(-) create mode 100644 mplexer/cmd/mgmt/acmeroutes.go create mode 100644 mplexer/cmd/mgmt/devices.go create mode 100644 mplexer/cmd/mgmt/route.go create mode 100644 mplexer/mgmt/authstore/authstore.go rename mplexer/{ => mgmt}/authstore/authstore_test.go (70%) rename mplexer/{ => mgmt}/authstore/init.sql (100%) rename mplexer/{ => mgmt}/authstore/insert.sql (100%) rename mplexer/{ => mgmt}/authstore/postgresql.go (76%) diff --git a/mplexer/cmd/mgmt/acmeroutes.go b/mplexer/cmd/mgmt/acmeroutes.go new file mode 100644 index 0000000..a2fb73d --- /dev/null +++ b/mplexer/cmd/mgmt/acmeroutes.go @@ -0,0 +1,113 @@ +package main + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/go-acme/lego/v3/challenge" + "github.com/go-chi/chi" +) + +// A Challenge has the data necessary to create an ACME DNS-01 Key Authorization Digest. +type Challenge struct { + Domain string `json:"domain"` + Token string `json:"token"` + KeyAuth string `json:"key_authorization"` + error chan error +} + +type acmeProvider struct { + BaseURL string + provider challenge.Provider +} + +func (p *acmeProvider) Present(domain, token, keyAuth string) error { + return p.provider.Present(domain, token, keyAuth) +} + +func (p *acmeProvider) CleanUp(domain, token, keyAuth string) error { + return p.provider.CleanUp(domain, token, keyAuth) +} + +func handleDNSRoutes(r chi.Router) { + r.Route("/dns", func(r chi.Router) { + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + valid, _ := ctx.Value(MWKey("valid")).(bool) + + if !valid { + // misdirection + time.Sleep(250 * time.Millisecond) + w.Write([]byte("{\"success\":true}\n")) + //http.Error(w, `{"error":"could not verify token"}`, http.StatusBadRequest) + return + } + /* + if nil != err2 { + // a little misdirection there + msg := `{"error":"internal server error"}` + http.Error(w, msg, http.StatusInternalServerError) + return + } + */ + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }) + + r.Post("/{domain}", func(w http.ResponseWriter, r *http.Request) { + + ch := Challenge{} + + // TODO prevent slow loris + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&ch) + if nil != err || "" == ch.Token || "" == ch.KeyAuth { + msg := `{"error":"expected json in the format {\"token\":\"xxx\",\"key_authorization\":\"yyy\"}"}` + http.Error(w, msg, http.StatusUnprocessableEntity) + return + } + + domain := chi.URLParam(r, "domain") + //domain := chi.URLParam(r, "*") + ch.Domain = domain + + // TODO some additional error checking before the handoff + //ch.error = make(chan error, 1) + ch.error = make(chan error) + presenters <- &ch + err = <-ch.error + if nil != err || "" == ch.Token || "" == ch.KeyAuth { + msg := `{"error":"expected json in the format {\"token\":\"xxx\",\"key_authorization\":\"yyy\"}"}` + http.Error(w, msg, http.StatusUnprocessableEntity) + return + } + + w.Write([]byte("{\"success\":true}\n")) + }) + + // TODO ugly Delete, but whatever + r.Delete("/{domain}/{token}/{keyAuth}", func(w http.ResponseWriter, r *http.Request) { + + ch := Challenge{ + Domain: chi.URLParam(r, "domain"), + Token: chi.URLParam(r, "token"), + KeyAuth: chi.URLParam(r, "keyAuth"), + error: make(chan error), + //error: make(chan error, 1), + } + + cleanups <- &ch + err := <-ch.error + if nil != err || "" == ch.Token || "" == ch.KeyAuth { + msg := `{"error":"expected json in the format {\"token\":\"xxx\",\"key_authorization\":\"yyy\"}"}` + http.Error(w, msg, http.StatusUnprocessableEntity) + return + } + + w.Write([]byte("{\"success\":true}\n")) + }) + }) +} diff --git a/mplexer/cmd/mgmt/devices.go b/mplexer/cmd/mgmt/devices.go new file mode 100644 index 0000000..b5f974a --- /dev/null +++ b/mplexer/cmd/mgmt/devices.go @@ -0,0 +1,90 @@ +package main + +import ( + "encoding/json" + "log" + "net/http" + "time" + + "git.coolaj86.com/coolaj86/go-telebitd/mplexer/mgmt/authstore" + "github.com/go-chi/chi" +) + +func handleDeviceRoutes(r chi.Router) { + r.Route("/devices", func(r chi.Router) { + // TODO needs admin auth + // r.Use() // must have slug '*' + + r.Post("/", func(w http.ResponseWriter, r *http.Request) { + + auth := &authstore.Authorization{} + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&auth) + // Slug is mandatory, ID and MachinePPID must NOT be set + epoch := time.Time{} + if nil != err || "" != auth.ID || "" != auth.MachinePPID || + "" == auth.Slug || "" == auth.SharedKey || + epoch != auth.CreatedAt || epoch != auth.UpdatedAt || epoch != auth.DeletedAt { + result, _ := json.Marshal(&authstore.Authorization{}) + msg, _ := json.Marshal(&struct { + Error string `json:"error"` + }{ + Error: "expected JSON in the format " + string(result), + }) + http.Error(w, string(msg), http.StatusUnprocessableEntity) + return + } + + err = store.Add(auth) + if nil != err { + msg := `{"error":"not really sure what happened, but it didn't go well (check the logs)"}` + log.Printf("/api/devices/\n", auth.Slug) + log.Println(err) + http.Error(w, msg, http.StatusInternalServerError) + return + } + + //auth.SharedKey = "[redacted]" + result, _ := json.Marshal(auth) + w.Write(result) + }) + + r.Get("/{slug}", func(w http.ResponseWriter, r *http.Request) { + slug := chi.URLParam(r, "slug") + // TODO store should be concurrency-safe + auth, err := store.Get(slug) + if nil != err { + msg := `{"error":"not really sure what happened, but it didn't go well (check the logs)"}` + log.Printf("/api/devices/%s\n", slug) + log.Println(err) + http.Error(w, msg, http.StatusInternalServerError) + return + } + + // Redact private data + if "" != auth.MachinePPID { + auth.MachinePPID = "[redacted]" + } + if "" != auth.SharedKey { + auth.SharedKey = "[redacted]" + } + result, _ := json.Marshal(auth) + w.Write(result) + }) + + r.Delete("/{slug}", func(w http.ResponseWriter, r *http.Request) { + slug := chi.URLParam(r, "slug") + auth, err := store.Get(slug) + if nil == auth { + msg := `{"error":"not really sure what happened, but it didn't go well (check the logs)"}` + log.Printf("/api/devices/%s\n", slug) + log.Println(err) + http.Error(w, msg, http.StatusInternalServerError) + return + } + + w.Write([]byte("{\"success\":true}\n")) + }) + }) + +} diff --git a/mplexer/cmd/mgmt/mgmt.go b/mplexer/cmd/mgmt/mgmt.go index ab3d322..425bd13 100644 --- a/mplexer/cmd/mgmt/mgmt.go +++ b/mplexer/cmd/mgmt/mgmt.go @@ -3,21 +3,18 @@ package main import ( - "context" - "encoding/json" "flag" "fmt" + "log" "net/http" "os" "strings" - "time" - jwt "github.com/dgrijalva/jwt-go" + "git.coolaj86.com/coolaj86/go-telebitd/mplexer/mgmt/authstore" + "github.com/go-acme/lego/v3/challenge" "github.com/go-acme/lego/v3/providers/dns/duckdns" "github.com/go-acme/lego/v3/providers/dns/godaddy" - "github.com/go-chi/chi" - "github.com/go-chi/chi/middleware" _ "github.com/joho/godotenv/autoload" ) @@ -32,15 +29,21 @@ var ( type MWKey string +var store authstore.Store +var provider challenge.Provider = nil // TODO is this concurrency-safe? +var secret *string + func main() { var err error - var provider challenge.Provider = nil // TODO is this concurrency-safe? - var presenters = make(chan *Challenge) - var cleanups = make(chan *Challenge) addr := flag.String("address", "", "IPv4 or IPv6 bind address") port := flag.String("port", "3000", "port to listen to") - secret := flag.String("secret", "", "a >= 16-character random string for JWT key signing") // SECRET + dbURL := flag.String( + "db-url", + "postgres://postgres:postgres@localhost/postgres", + "database (postgres) connection url", + ) + secret = flag.String("secret", "", "a >= 16-character random string for JWT key signing") flag.Parse() if "" != os.Getenv("GODADDY_API_KEY") { @@ -66,142 +69,25 @@ func main() { return } - r := chi.NewRouter() - r.Use(middleware.Logger) - r.Use(middleware.Timeout(15 * time.Second)) - r.Use(middleware.Recoverer) + connStr := *dbURL + // TODO url.Parse + if strings.Contains(connStr, "@localhost/") { + connStr += "?sslmode=disable" + } else { + connStr += "?sslmode=required" + } + initSQL := "./init.sql" - r.Route("/api/dns", func(r chi.Router) { - r.Use(func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var tokenString string - if auth := strings.Split(r.Header.Get("Authorization"), " "); len(auth) > 1 { - // TODO handle Basic auth tokens as well - tokenString = auth[1] - } - if "" == tokenString { - tokenString = r.URL.Query().Get("access_token") - } - - // TODO check expiration and such - tok, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - return []byte(*secret), nil - }) - if nil != err { - fmt.Println("validation error:", tokenString, err) - http.Error(w, "{\"error\":\"could not verify token\"}", http.StatusBadRequest) - return - } - - ctx := context.WithValue(r.Context(), MWKey("token"), tok) - - next.ServeHTTP(w, r.WithContext(ctx)) - }) - }) - - r.Post("/{domain}", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - ch := Challenge{} - - // TODO prevent slow loris - decoder := json.NewDecoder(r.Body) - err := decoder.Decode(&ch) - if nil != err || "" == ch.Token || "" == ch.KeyAuth { - msg := `{"error":"expected json in the format {\"token\":\"xxx\",\"key_authorization\":\"yyy\"}"}` - http.Error(w, msg, http.StatusUnprocessableEntity) - return - } - - domain := chi.URLParam(r, "domain") - //domain := chi.URLParam(r, "*") - ch.Domain = domain - - // TODO some additional error checking before the handoff - //ch.error = make(chan error, 1) - ch.error = make(chan error) - presenters <- &ch - err = <-ch.error - if nil != err || "" == ch.Token || "" == ch.KeyAuth { - msg := `{"error":"expected json in the format {\"token\":\"xxx\",\"key_authorization\":\"yyy\"}"}` - http.Error(w, msg, http.StatusUnprocessableEntity) - return - } - - w.Write([]byte("{\"success\":true}\n")) - }) - - // TODO ugly Delete, but whatever - r.Delete("/{domain}/{token}/{keyAuth}", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - ch := Challenge{ - Domain: chi.URLParam(r, "domain"), - Token: chi.URLParam(r, "token"), - KeyAuth: chi.URLParam(r, "keyAuth"), - error: make(chan error), - //error: make(chan error, 1), - } - - cleanups <- &ch - err = <-ch.error - if nil != err || "" == ch.Token || "" == ch.KeyAuth { - msg := `{"error":"expected json in the format {\"token\":\"xxx\",\"key_authorization\":\"yyy\"}"}` - http.Error(w, msg, http.StatusUnprocessableEntity) - return - } - - w.Write([]byte("{\"success\":true}\n")) - }) - }) - - r.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("welcome\n")) - }) - - go func() { - for { - // TODO make parallel? - // TODO make cancellable? - ch := <-presenters - err := provider.Present(ch.Domain, ch.Token, ch.KeyAuth) - ch.error <- err - } - }() - - go func() { - for { - // TODO make parallel? - // TODO make cancellable? - ch := <-cleanups - ch.error <- provider.CleanUp(ch.Domain, ch.Token, ch.KeyAuth) - } - }() + store, err = authstore.NewStore(connStr, initSQL) + if nil != err { + log.Fatal("connection error", err) + return + } + defer store.Close() bind := *addr + ":" + *port fmt.Println("Listening on", bind) - fmt.Fprintf(os.Stderr, "failed:", http.ListenAndServe(bind, r)) -} - -// A Challenge has the data necessary to create an ACME DNS-01 Key Authorization Digest. -type Challenge struct { - Domain string `json:"domain"` - Token string `json:"token"` - KeyAuth string `json:"key_authorization"` - error chan error -} - -type acmeProvider struct { - BaseURL string - provider challenge.Provider -} - -func (p *acmeProvider) Present(domain, token, keyAuth string) error { - return p.provider.Present(domain, token, keyAuth) -} - -func (p *acmeProvider) CleanUp(domain, token, keyAuth string) error { - return p.provider.CleanUp(domain, token, keyAuth) + fmt.Fprintf(os.Stderr, "failed:", http.ListenAndServe(bind, routeAll())) } // newDuckDNSProvider is for the sake of demoing the tunnel diff --git a/mplexer/cmd/mgmt/route.go b/mplexer/cmd/mgmt/route.go new file mode 100644 index 0000000..203b575 --- /dev/null +++ b/mplexer/cmd/mgmt/route.go @@ -0,0 +1,181 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "strings" + "time" + + "git.coolaj86.com/coolaj86/go-telebitd/mplexer/mgmt/authstore" + "github.com/dgrijalva/jwt-go" + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" +) + +type MgmtClaims struct { + Slug string `json:"slug"` + jwt.StandardClaims +} + +var presenters = make(chan *Challenge) +var cleanups = make(chan *Challenge) + +func routeAll() chi.Router { + + go func() { + for { + // TODO make parallel? + // TODO make cancellable? + ch := <-presenters + err := provider.Present(ch.Domain, ch.Token, ch.KeyAuth) + ch.error <- err + } + }() + + go func() { + for { + // TODO make parallel? + // TODO make cancellable? + ch := <-cleanups + ch.error <- provider.CleanUp(ch.Domain, ch.Token, ch.KeyAuth) + } + }() + + r := chi.NewRouter() + r.Use(middleware.Logger) + r.Use(middleware.Timeout(15 * time.Second)) + r.Use(middleware.Recoverer) + + r.Route("/api", func(r chi.Router) { + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var tokenString string + if auth := strings.Split(r.Header.Get("Authorization"), " "); len(auth) > 1 { + // TODO handle Basic auth tokens as well + tokenString = auth[1] + } + + //var err2 error = nil + tok, err := jwt.ParseWithClaims( + tokenString, + &MgmtClaims{}, + func(token *jwt.Token) (interface{}, error) { + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, fmt.Errorf("missing jwt header 'kid' (key id)") + } + auth, err := store.Get(kid) + if nil != err { + return nil, fmt.Errorf("invalid jwt header 'kid' (key id)") + } + + claims := token.Claims.(*MgmtClaims) + jti := claims.Id + if "" == jti { + return nil, fmt.Errorf("missing jwt payload 'jti' (jwt id / nonce)") + } + iat := claims.IssuedAt + if 0 == iat { + return nil, fmt.Errorf("missing jwt payload 'iat' (issued at)") + } + exp := claims.ExpiresAt + if 0 == exp { + return nil, fmt.Errorf("missing jwt payload 'exp' (expires at)") + } + + if "" != claims.Slug { + return nil, fmt.Errorf("extra jwt payload 'slug' (unknown)") + } + claims.Slug = auth.Slug + + /* + // a little misdirection there + mac := hmac.New(sha256.New, auth.MachinePPID) + _ = mac.Write([]byte(auth.SharedKey)) + _ = mac.Write([]byte(fmt.Sprintf("%d", exp))) + return []byte(auth.SharedKey), nil + */ + + return []byte(auth.MachinePPID), nil + }, + ) + + var ctx context.Context + if nil != tok { + ctx = context.WithValue(r.Context(), MWKey("token"), tok) + if tok.Valid { + ctx = context.WithValue(r.Context(), MWKey("valid"), nil != tok) + } + } + if nil != err { + ctx = context.WithValue(r.Context(), MWKey("error"), nil != tok) + } + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }) + + handleDNSRoutes(r) + handleDeviceRoutes(r) + + r.Route("/register-device", func(r chi.Router) { + // r.Use() // must NOT have slug '*' + + r.Post("/{otp}", func(w http.ResponseWriter, r *http.Request) { + sharedKey := chi.URLParam(r, "otp") + original, err := store.Get(sharedKey) + if "" != original.MachinePPID { + msg := `{"error":"the presented key has already been used"}` + log.Printf("/api/register-device/\n") + log.Println(err) + http.Error(w, msg, http.StatusInternalServerError) + return + } + + auth := &authstore.Authorization{} + decoder := json.NewDecoder(r.Body) + err = decoder.Decode(&auth) + // MachinePPID and PublicKey are required. ID must NOT be set. Slug is ignored. + epoch := time.Time{} + auth.SharedKey = sharedKey + if nil != err || "" != auth.ID || "" == auth.MachinePPID || + "" == auth.PublicKey || "" == auth.SharedKey || + epoch != auth.CreatedAt || epoch != auth.UpdatedAt || epoch != auth.DeletedAt { + msg, _ := json.Marshal(&struct { + Error string `json:"error"` + }{ + Error: "expected JSON in the format {\"machine_ppid\":\"\",\"public_key\":\"\"}", + }) + http.Error(w, string(msg), http.StatusUnprocessableEntity) + return + } + + // TODO hash the PPID and check against the Public Key? + original.PublicKey = auth.PublicKey + original.MachinePPID = auth.MachinePPID + err = store.Set(original) + if nil != err { + msg := `{"error":"not really sure what happened, but it didn't go well (check the logs)"}` + log.Printf("/api/register-device/\n") + log.Println(err) + http.Error(w, msg, http.StatusInternalServerError) + return + } + + result, _ := json.Marshal(auth) + w.Write(result) + }) + }) + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("welcome\n")) + }) + }) + + return r +} diff --git a/mplexer/cmd/sqlstore/sqlstore.go b/mplexer/cmd/sqlstore/sqlstore.go index 8a6ef03..f8697ed 100644 --- a/mplexer/cmd/sqlstore/sqlstore.go +++ b/mplexer/cmd/sqlstore/sqlstore.go @@ -4,7 +4,7 @@ import ( "fmt" "log" - "git.coolaj86.com/coolaj86/go-telebitd/mplexer/authstore" + "git.coolaj86.com/coolaj86/go-telebitd/mplexer/mgmt/authstore" ) func main() { diff --git a/mplexer/mgmt/authstore/authstore.go b/mplexer/mgmt/authstore/authstore.go new file mode 100644 index 0000000..97e7531 --- /dev/null +++ b/mplexer/mgmt/authstore/authstore.go @@ -0,0 +1,28 @@ +package authstore + +import ( + "time" +) + +type Authorization struct { + ID string `db:"id,omitempty" json:"-"` + + MachinePPID string `db:"machine_ppid,omitempty" json:"machine_ppid,omitempty"` + PublicKey string `db:"public_key,omitempty" json:"public_key,omitempty"` + SharedKey string `db:"shared_key,omitempty" json:"shared_key"` + Slug string `db:"slug,omitempty" json:"slug"` + + CreatedAt time.Time `db:"created_at,omitempty" json:"created_at,omitempty"` + UpdatedAt time.Time `db:"updated_at,omitempty" json:"updated_at,omitempty"` + DeletedAt time.Time `db:"deleted_at,omitempty" json:"-"` +} + +type Store interface { + Add(auth *Authorization) error + Set(auth *Authorization) error + Get(id string) (*Authorization, error) + GetBySlug(id string) (*Authorization, error) + GetByPub(id string) (*Authorization, error) + Delete(auth *Authorization) error + Close() error +} diff --git a/mplexer/authstore/authstore_test.go b/mplexer/mgmt/authstore/authstore_test.go similarity index 70% rename from mplexer/authstore/authstore_test.go rename to mplexer/mgmt/authstore/authstore_test.go index a17d790..339bdf6 100644 --- a/mplexer/authstore/authstore_test.go +++ b/mplexer/mgmt/authstore/authstore_test.go @@ -2,13 +2,24 @@ package authstore import ( "fmt" + "strings" "testing" ) func TestStore(t *testing.T) { - // Note: output is cached + // Note: test output is cached (running twice will not result in two records) - store, err := NewStore(nil) + connStr := "postgres://postgres:postgres@localhost/postgres" + if strings.Contains(connStr, "@localhost/") { + connStr += "?sslmode=disable" + } else { + connStr += "?sslmode=required" + } + initSQL := "./init.sql" + + // TODO url.Parse + + store, err := NewStore(connStr, initSQL) if nil != err { t.Fatal("connection error", err) return diff --git a/mplexer/authstore/init.sql b/mplexer/mgmt/authstore/init.sql similarity index 100% rename from mplexer/authstore/init.sql rename to mplexer/mgmt/authstore/init.sql diff --git a/mplexer/authstore/insert.sql b/mplexer/mgmt/authstore/insert.sql similarity index 100% rename from mplexer/authstore/insert.sql rename to mplexer/mgmt/authstore/insert.sql diff --git a/mplexer/authstore/postgresql.go b/mplexer/mgmt/authstore/postgresql.go similarity index 76% rename from mplexer/authstore/postgresql.go rename to mplexer/mgmt/authstore/postgresql.go index fb0883c..24b6da4 100644 --- a/mplexer/authstore/postgresql.go +++ b/mplexer/mgmt/authstore/postgresql.go @@ -5,58 +5,24 @@ import ( "database/sql" "fmt" "io/ioutil" - "strings" "time" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" ) -type Authorization struct { - ID string `db:"id,omitempty"` - Slug string `db:"slug,omitempty"` - MachinePPID string `db:"machine_ppid,omitempty"` - PublicKey string `db:"public_key,omitempty"` - SharedKey string `db:"shared_key,omitempty"` - CreatedAt time.Time `db:"created_at,omitempty"` - UpdatedAt time.Time `db:"updated_at,omitempty"` - DeletedAt time.Time `db:"deleted_at,omitempty"` -} - -type Store interface { - Add(auth *Authorization) error - Set(auth *Authorization) error - Get(id string) (*Authorization, error) - GetBySlug(id string) (*Authorization, error) - GetByPub(id string) (*Authorization, error) - Delete(auth *Authorization) error - Close() error -} - -type StoreConfig interface { - Type() string - URL() string -} - -func NewStore(c StoreConfig) (Store, error) { +func NewStore(pgURL, initSQL string) (Store, error) { // https://godoc.org/github.com/lib/pq - connStr := "postgres://postgres:postgres@localhost/postgres" - if strings.Contains(connStr, "@localhost/") { - connStr += "?sslmode=disable" - } else { - connStr += "?sslmode=required" - } - // TODO url.Parse dbtype := "postgres" - sqlBytes, err := ioutil.ReadFile("./init.sql") + sqlBytes, err := ioutil.ReadFile(initSQL) if nil != err { return nil, err } ctx, done := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) defer done() - db, err := sql.Open(dbtype, connStr) + db, err := sql.Open(dbtype, pgURL) if err := db.PingContext(ctx); nil != err { return nil, err } @@ -139,7 +105,11 @@ func (s *PGStore) Set(auth *Authorization) error { func (s *PGStore) Get(id string) (*Authorization, error) { ctx, done := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) defer done() - query := `SELECT * FROM authorizations WHERE deleted_at = '1970-01-01 00:00:00' AND (slug = $1 OR public_key = $1)` + query := ` + SELECT * FROM authorizations + WHERE deleted_at = '1970-01-01 00:00:00' + AND (slug = $1 OR public_key = $1 OR shared_key = $1) + ` row := s.dbx.QueryRowxContext(ctx, query, id) if nil != row { auth := &Authorization{}