From 67ad7a9fa2c85b27e24085ada37d5737d63f745a Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Thu, 9 Apr 2026 10:51:41 -0600 Subject: [PATCH] fix(litemigrate,mymigrate,msmigrate): take *sql.Conn instead of *sql.DB Same issue as pgmigrate: *sql.DB is a connection pool, so each call may land on a different connection. Migrations need a pinned connection for session state (SET search_path, temp tables, etc.) to persist across sequential calls. *sql.Conn (from db.Conn(ctx)) pins one underlying connection for its lifetime. --- .../sqlmigrate/litemigrate/litemigrate.go | 16 ++++++----- database/sqlmigrate/msmigrate/msmigrate.go | 16 ++++++----- database/sqlmigrate/mymigrate/mymigrate.go | 27 ++++++++++--------- 3 files changed, 33 insertions(+), 26 deletions(-) diff --git a/database/sqlmigrate/litemigrate/litemigrate.go b/database/sqlmigrate/litemigrate/litemigrate.go index 16ddc96..c1fc8e0 100644 --- a/database/sqlmigrate/litemigrate/litemigrate.go +++ b/database/sqlmigrate/litemigrate/litemigrate.go @@ -4,6 +4,7 @@ // import _ "modernc.org/sqlite" // // db, err := sql.Open("sqlite", "app.db?_pragma=foreign_keys(1)") +// conn, err := db.Conn(ctx) // // SQLite disables foreign key enforcement by default. The _pragma DSN // parameter enables it on every connection the pool opens. @@ -18,14 +19,15 @@ import ( "github.com/therootcompany/golib/database/sqlmigrate" ) -// Migrator implements sqlmigrate.Migrator using a *sql.DB with SQLite. +// Migrator implements sqlmigrate.Migrator using a *sql.Conn with SQLite. type Migrator struct { - DB *sql.DB + Conn *sql.Conn } -// New creates a Migrator from the given database handle. -func New(db *sql.DB) *Migrator { - return &Migrator{DB: db} +// New creates a Migrator from the given connection. +// Use db.Conn(ctx) to obtain a *sql.Conn from a *sql.DB. +func New(conn *sql.Conn) *Migrator { + return &Migrator{Conn: conn} } var _ sqlmigrate.Migrator = (*Migrator)(nil) @@ -41,7 +43,7 @@ func (m *Migrator) ExecDown(ctx context.Context, mig sqlmigrate.Migration, sql s } func (m *Migrator) execInTx(ctx context.Context, sqlStr string) error { - tx, err := m.DB.BeginTx(ctx, nil) + tx, err := m.Conn.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("%w: begin: %w", sqlmigrate.ErrExecFailed, err) } @@ -61,7 +63,7 @@ func (m *Migrator) execInTx(ctx context.Context, sqlStr string) error { // Applied returns all applied migrations from the _migrations table. // Returns an empty slice if the table does not exist. func (m *Migrator) Applied(ctx context.Context) ([]sqlmigrate.Migration, error) { - rows, err := m.DB.QueryContext(ctx, "SELECT id, name FROM _migrations ORDER BY name") + rows, err := m.Conn.QueryContext(ctx, "SELECT id, name FROM _migrations ORDER BY name") if err != nil { // SQLite reports "no such table: _migrations" — stable across versions if strings.Contains(err.Error(), "no such table") { diff --git a/database/sqlmigrate/msmigrate/msmigrate.go b/database/sqlmigrate/msmigrate/msmigrate.go index 6eb25e2..1af4ab3 100644 --- a/database/sqlmigrate/msmigrate/msmigrate.go +++ b/database/sqlmigrate/msmigrate/msmigrate.go @@ -2,6 +2,7 @@ // using database/sql with github.com/microsoft/go-mssqldb. // // db, err := sql.Open("sqlserver", "sqlserver://user:pass@host:1433?database=mydb") +// conn, err := db.Conn(ctx) package msmigrate import ( @@ -15,14 +16,15 @@ import ( "github.com/therootcompany/golib/database/sqlmigrate" ) -// Migrator implements sqlmigrate.Migrator using a *sql.DB with SQL Server. +// Migrator implements sqlmigrate.Migrator using a *sql.Conn with SQL Server. type Migrator struct { - DB *sql.DB + Conn *sql.Conn } -// New creates a Migrator from the given database handle. -func New(db *sql.DB) *Migrator { - return &Migrator{DB: db} +// New creates a Migrator from the given connection. +// Use db.Conn(ctx) to obtain a *sql.Conn from a *sql.DB. +func New(conn *sql.Conn) *Migrator { + return &Migrator{Conn: conn} } var _ sqlmigrate.Migrator = (*Migrator)(nil) @@ -38,7 +40,7 @@ func (m *Migrator) ExecDown(ctx context.Context, mig sqlmigrate.Migration, sql s } func (m *Migrator) execInTx(ctx context.Context, sqlStr string) error { - tx, err := m.DB.BeginTx(ctx, nil) + tx, err := m.Conn.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("%w: begin: %w", sqlmigrate.ErrExecFailed, err) } @@ -58,7 +60,7 @@ func (m *Migrator) execInTx(ctx context.Context, sqlStr string) error { // Applied returns all applied migrations from the _migrations table. // Returns an empty slice if the table does not exist (SQL Server error 208). func (m *Migrator) Applied(ctx context.Context) ([]sqlmigrate.Migration, error) { - rows, err := m.DB.QueryContext(ctx, "SELECT id, name FROM _migrations ORDER BY name") + rows, err := m.Conn.QueryContext(ctx, "SELECT id, name FROM _migrations ORDER BY name") if err != nil { // SQL Server error 208: "Invalid object name '_migrations'" if msErr, ok := errors.AsType[mssql.Error](err); ok && msErr.Number == 208 { diff --git a/database/sqlmigrate/mymigrate/mymigrate.go b/database/sqlmigrate/mymigrate/mymigrate.go index 1270c44..fcc9da1 100644 --- a/database/sqlmigrate/mymigrate/mymigrate.go +++ b/database/sqlmigrate/mymigrate/mymigrate.go @@ -1,12 +1,14 @@ // Package mymigrate implements sqlmigrate.Migrator for MySQL and MariaDB // using database/sql with github.com/go-sql-driver/mysql. // -// The *sql.DB must be opened with multiStatements=true in the DSN; -// without it, multi-statement migration files will silently execute only -// the first statement. The multiStatements requirement is validated lazily -// on the first ExecUp or ExecDown call: +// The *sql.Conn must originate from a *sql.DB opened with +// multiStatements=true in the DSN; without it, multi-statement migration +// files will silently execute only the first statement. The +// multiStatements requirement is validated lazily on the first ExecUp or +// ExecDown call: // // db, err := sql.Open("mysql", "user:pass@tcp(host:3306)/dbname?multiStatements=true") +// conn, err := db.Conn(ctx) // // MySQL and MariaDB do not support transactional DDL. Statements like // CREATE TABLE and ALTER TABLE cause an implicit commit, so if a migration @@ -25,17 +27,18 @@ import ( "github.com/therootcompany/golib/database/sqlmigrate" ) -// Migrator implements sqlmigrate.Migrator using a *sql.DB with MySQL/MariaDB. +// Migrator implements sqlmigrate.Migrator using a *sql.Conn with MySQL/MariaDB. type Migrator struct { - DB *sql.DB + Conn *sql.Conn validated bool } -// New creates a Migrator from the given database handle. +// New creates a Migrator from the given connection. +// Use db.Conn(ctx) to obtain a *sql.Conn from a *sql.DB. // The multiStatements=true DSN requirement is validated lazily on the // first ExecUp or ExecDown call. -func New(db *sql.DB) *Migrator { - return &Migrator{DB: db} +func New(conn *sql.Conn) *Migrator { + return &Migrator{Conn: conn} } var _ sqlmigrate.Migrator = (*Migrator)(nil) @@ -56,7 +59,7 @@ func (m *Migrator) exec(ctx context.Context, sqlStr string) error { if !m.validated { // Probe for multi-statement support. Without it, migration files // that contain more than one statement silently execute only the first. - if _, err := m.DB.ExecContext(ctx, "DO 1; DO 1"); err != nil { + if _, err := m.Conn.ExecContext(ctx, "DO 1; DO 1"); err != nil { return fmt.Errorf( "%w: mymigrate: migration requires multiStatements=true in the MySQL DSN", sqlmigrate.ErrExecFailed, @@ -65,7 +68,7 @@ func (m *Migrator) exec(ctx context.Context, sqlStr string) error { m.validated = true } - tx, err := m.DB.BeginTx(ctx, nil) + tx, err := m.Conn.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("%w: begin: %w", sqlmigrate.ErrExecFailed, err) } @@ -85,7 +88,7 @@ func (m *Migrator) exec(ctx context.Context, sqlStr string) error { // Applied returns all applied migrations from the _migrations table. // Returns an empty slice if the table does not exist (MySQL error 1146). func (m *Migrator) Applied(ctx context.Context) ([]sqlmigrate.Migration, error) { - rows, err := m.DB.QueryContext(ctx, "SELECT id, name FROM _migrations ORDER BY name") + rows, err := m.Conn.QueryContext(ctx, "SELECT id, name FROM _migrations ORDER BY name") if err != nil { if mysqlErr, ok := errors.AsType[*mysql.MySQLError](err); ok && mysqlErr.Number == 1146 { return nil, nil