From 85147ca776715c9a2f03cc8dc244787d4c9df2c8 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Tue, 12 Oct 2021 11:36:19 -0600 Subject: [PATCH] bugfix: error instead of panic for invalid connection string --- cmd/mgmt/mgmt.go | 10 +--------- cmd/sqlstore/sqlstore.go | 16 ++++++++++++---- .../mgmt/authstore/authstore_postgres_test.go | 8 -------- internal/mgmt/authstore/postgresql.go | 16 +++++++++++++++- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/cmd/mgmt/mgmt.go b/cmd/mgmt/mgmt.go index 896c1d6..4fad3be 100644 --- a/cmd/mgmt/mgmt.go +++ b/cmd/mgmt/mgmt.go @@ -143,15 +143,7 @@ func main() { return } - connStr := dbURL - // TODO url.Parse - if strings.Contains(connStr, "@localhost/") || strings.Contains(connStr, "@localhost:") { - connStr += "?sslmode=disable" - } else { - connStr += "?sslmode=required" - } - - store, err = authstore.NewStore(connStr, mgmt.InitSQL) + store, err = authstore.NewStore(dbURL, mgmt.InitSQL) if nil != err { log.Fatal("connection error", err) return diff --git a/cmd/sqlstore/sqlstore.go b/cmd/sqlstore/sqlstore.go index 2dac26b..e5c1337 100644 --- a/cmd/sqlstore/sqlstore.go +++ b/cmd/sqlstore/sqlstore.go @@ -10,10 +10,18 @@ import ( func main() { connStr := "postgres://postgres:postgres@localhost:5432/postgres" - if strings.Contains(connStr, "@localhost/") || strings.Contains(connStr, "@localhost:") { - connStr += "?sslmode=disable" - } else { - connStr += "?sslmode=required" + + if !strings.Contains(connStr, "sslmode=") { + sep := "?" + if strings.Contains(connStr, sep) { + sep = "&" + } + if strings.Contains(connStr, "@localhost/") || + strings.Contains(connStr, "@localhost:") { + connStr += sep + "sslmode=disable" + } else { + connStr += sep + "sslmode=required" + } } store, err := authstore.NewStore(connStr, initSQL) diff --git a/internal/mgmt/authstore/authstore_postgres_test.go b/internal/mgmt/authstore/authstore_postgres_test.go index b114a26..a9f6a59 100644 --- a/internal/mgmt/authstore/authstore_postgres_test.go +++ b/internal/mgmt/authstore/authstore_postgres_test.go @@ -1,14 +1,6 @@ package authstore -import "strings" - var connStr = "postgres://postgres:postgres@localhost/postgres" func init() { - // TODO url.Parse - if strings.Contains(connStr, "@localhost/") || strings.Contains(connStr, "@localhost:") { - connStr += "?sslmode=disable" - } else { - connStr += "?sslmode=required" - } } diff --git a/internal/mgmt/authstore/postgresql.go b/internal/mgmt/authstore/postgresql.go index 8e5129a..5336701 100644 --- a/internal/mgmt/authstore/postgresql.go +++ b/internal/mgmt/authstore/postgresql.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "io/ioutil" + "strings" "time" "git.rootprojects.org/root/telebit/assets/files" @@ -16,9 +17,22 @@ import ( var initSQL = "./postgres.init.sql" -func NewStore(pgURL, initSQL string) (Store, error) { +func NewStore(dbURL, initSQL string) (Store, error) { // https://godoc.org/github.com/lib/pq + // TODO url.Parse + if !strings.Contains(dbURL, "sslmode=") { + sep := "?" + if strings.Contains(connStr, sep) { + sep = "&" + } + if strings.Contains(connStr, "@localhost/") || strings.Contains(connStr, "@localhost:") { + connStr += sep + "sslmode=disable" + } else { + connStr += sep + "sslmode=required" + } + } + f, err := files.Open(initSQL) if nil != err { return nil, err