From aebef71a9559fbb5283a4383c6c3bee8051be1b5 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Fri, 10 Apr 2026 01:07:58 -0600 Subject: [PATCH] test(sqlmigrate): add ordering, end-to-end, rollback, and dialect-specific tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Across all four backends: - TestAppliedOrdering: insert rows out of order, verify Applied() returns them sorted by name. Guards against the ORDER BY clause being dropped or the query returning rows in arbitrary order. - TestEndToEndCycle: Collect → Up → Applied → Down → Applied via the sqlmigrate orchestrator with real migration files. Catches wiring bugs between Migrator and orchestrator that the in-package mockMigrator tests cannot. - TestDMLRollback: multi-statement DML migration where the last statement fails, verifies earlier INSERTs are rolled back. MySQL note: DML-only because MySQL implicitly commits DDL. Dialect-specific: - mymigrate TestMultiStatementsRequired: strip multiStatements=true from the DSN, verify ExecUp fails with a clear error mentioning multiStatements (rather than silently running only the first statement of a multi-statement migration). - litemigrate TestForeignKeyEnforcement: verifies FK constraints are enforced when the DSN includes _pragma=foreign_keys(1). Test fixture fix: cleanup closures now use context.Background() instead of the test context. t.Context() is canceled before t.Cleanup runs, so DB cleanup silently failed. Previously the _migrations cleanup appeared to work because the next test's connect() re-ran DROP TABLE at setup, but domain tables (test_*) leaked across runs. New tests also pre-clean at setup for self-healing after interrupted runs. --- .../litemigrate/litemigrate_test.go | 184 ++++++++++++++- .../sqlmigrate/msmigrate/msmigrate_test.go | 168 +++++++++++++- .../sqlmigrate/mymigrate/mymigrate_test.go | 216 +++++++++++++++++- .../sqlmigrate/pgmigrate/pgmigrate_test.go | 152 +++++++++++- 4 files changed, 715 insertions(+), 5 deletions(-) diff --git a/database/sqlmigrate/litemigrate/litemigrate_test.go b/database/sqlmigrate/litemigrate/litemigrate_test.go index 8b58522..1f0bf2d 100644 --- a/database/sqlmigrate/litemigrate/litemigrate_test.go +++ b/database/sqlmigrate/litemigrate/litemigrate_test.go @@ -2,10 +2,13 @@ package litemigrate_test import ( "database/sql" + "errors" "testing" + "testing/fstest" _ "modernc.org/sqlite" + "github.com/therootcompany/golib/database/sqlmigrate" "github.com/therootcompany/golib/database/sqlmigrate/litemigrate" ) @@ -13,7 +16,13 @@ import ( // The cleanup closes both the conn and the underlying *sql.DB. func openMem(t *testing.T) *sql.Conn { t.Helper() - db, err := sql.Open("sqlite", ":memory:") + return openMemDSN(t, ":memory:") +} + +// openMemDSN opens an in-memory SQLite with the given DSN (for pragmas). +func openMemDSN(t *testing.T, dsn string) *sql.Conn { + t.Helper() + db, err := sql.Open("sqlite", dsn) if err != nil { t.Fatalf("open: %v", err) } @@ -122,3 +131,176 @@ func TestAppliedAfterDropTable(t *testing.T) { t.Errorf("Applied() len = %d, want 0", len(applied)) } } + +// TestAppliedOrdering verifies Applied sorts by name (ascending), regardless +// of insertion order. Guards against the ORDER BY clause being removed or +// the underlying query returning rows in arbitrary order. +func TestAppliedOrdering(t *testing.T) { + conn := openMem(t) + ctx := t.Context() + + if _, err := conn.ExecContext(ctx, ` + CREATE TABLE _migrations (id TEXT, name TEXT); + INSERT INTO _migrations (id, name) VALUES ('ccc33333', '003_posts'); + INSERT INTO _migrations (id, name) VALUES ('aaa11111', '001_init'); + INSERT INTO _migrations (id, name) VALUES ('bbb22222', '002_users'); + `); err != nil { + t.Fatalf("setup: %v", err) + } + + m := litemigrate.New(conn) + applied, err := m.Applied(ctx) + if err != nil { + t.Fatalf("Applied() error = %v", err) + } + wantNames := []string{"001_init", "002_users", "003_posts"} + if len(applied) != len(wantNames) { + t.Fatalf("Applied() len = %d, want %d", len(applied), len(wantNames)) + } + for i, w := range wantNames { + if applied[i].Name != w { + t.Errorf("applied[%d].Name = %q, want %q", i, applied[i].Name, w) + } + } +} + +// TestEndToEndCycle runs a real Collect → Up → Applied → Down → Applied +// cycle through the sqlmigrate orchestrator. Catches wiring bugs between +// Migrator and the orchestrator that the in-package mockMigrator tests +// cannot. +func TestEndToEndCycle(t *testing.T) { + conn := openMem(t) + ctx := t.Context() + + fsys := fstest.MapFS{ + "001_init.up.sql": {Data: []byte(` + CREATE TABLE _migrations (id TEXT, name TEXT); + CREATE TABLE test_widgets (n INTEGER); + INSERT INTO _migrations (name, id) VALUES ('001_init', 'aaaa1111'); + `)}, + "001_init.down.sql": {Data: []byte(` + DROP TABLE test_widgets; + DROP TABLE _migrations; + `)}, + "002_gadgets.up.sql": {Data: []byte(` + CREATE TABLE test_gadgets (n INTEGER); + INSERT INTO _migrations (name, id) VALUES ('002_gadgets', 'bbbb2222'); + `)}, + "002_gadgets.down.sql": {Data: []byte(` + DROP TABLE test_gadgets; + DELETE FROM _migrations WHERE id = 'bbbb2222'; + `)}, + } + ddls, err := sqlmigrate.Collect(fsys, ".") + if err != nil { + t.Fatalf("Collect: %v", err) + } + + m := litemigrate.New(conn) + + ran, err := sqlmigrate.Up(ctx, m, ddls, -1) + if err != nil { + t.Fatalf("Up: %v", err) + } + if len(ran) != 2 { + t.Fatalf("ran = %d, want 2", len(ran)) + } + + applied, err := m.Applied(ctx) + if err != nil { + t.Fatalf("Applied: %v", err) + } + if len(applied) != 2 { + t.Fatalf("applied = %d, want 2", len(applied)) + } + if applied[0].ID != "aaaa1111" || applied[1].ID != "bbbb2222" { + t.Errorf("applied IDs = %+v, want [aaaa1111 bbbb2222]", applied) + } + + // Verify the domain tables actually exist + for _, tbl := range []string{"test_widgets", "test_gadgets"} { + if _, err := conn.ExecContext(ctx, "SELECT COUNT(*) FROM "+tbl); err != nil { + t.Errorf("expected table %q to exist: %v", tbl, err) + } + } + + rolled, err := sqlmigrate.Down(ctx, m, ddls, -1) + if err != nil { + t.Fatalf("Down: %v", err) + } + if len(rolled) != 2 { + t.Fatalf("rolled = %d, want 2", len(rolled)) + } + + applied, err = m.Applied(ctx) + if err != nil { + t.Fatalf("Applied after Down: %v", err) + } + if len(applied) != 0 { + t.Errorf("applied after Down = %d, want 0", len(applied)) + } +} + +// TestDMLRollback verifies that when a migration contains multiple DML +// statements and one fails, earlier statements in the same migration are +// rolled back. Uses an INSERT into a nonexistent table as the failure +// trigger so the test is portable across backends. +func TestDMLRollback(t *testing.T) { + conn := openMem(t) + ctx := t.Context() + + if _, err := conn.ExecContext(ctx, `CREATE TABLE test_rollback (n INTEGER)`); err != nil { + t.Fatalf("create: %v", err) + } + + m := litemigrate.New(conn) + err := m.ExecUp(ctx, sqlmigrate.Migration{Name: "rollback"}, ` + INSERT INTO test_rollback (n) VALUES (1); + INSERT INTO test_rollback (n) VALUES (2); + INSERT INTO nonexistent_table (n) VALUES (3); + `) + if err == nil { + t.Fatal("ExecUp() = nil, want error") + } + if !errors.Is(err, sqlmigrate.ErrExecFailed) { + t.Errorf("err = %v, want ErrExecFailed", err) + } + + var count int + if err := conn.QueryRowContext(ctx, "SELECT COUNT(*) FROM test_rollback").Scan(&count); err != nil { + t.Fatalf("count: %v", err) + } + if count != 0 { + t.Errorf("test_rollback count = %d, want 0 (rows should have been rolled back)", count) + } +} + +// TestForeignKeyEnforcement verifies that SQLite foreign key constraints +// are enforced when _pragma=foreign_keys(1) is in the DSN. Without the +// pragma, SQLite silently allows FK violations. +func TestForeignKeyEnforcement(t *testing.T) { + conn := openMemDSN(t, ":memory:?_pragma=foreign_keys(1)") + ctx := t.Context() + + if _, err := conn.ExecContext(ctx, ` + CREATE TABLE test_parent (id INTEGER PRIMARY KEY); + CREATE TABLE test_child ( + id INTEGER PRIMARY KEY, + parent_id INTEGER, + FOREIGN KEY (parent_id) REFERENCES test_parent(id) + ); + `); err != nil { + t.Fatalf("schema: %v", err) + } + + m := litemigrate.New(conn) + err := m.ExecUp(ctx, sqlmigrate.Migration{Name: "fk"}, ` + INSERT INTO test_child (id, parent_id) VALUES (1, 999); + `) + if err == nil { + t.Fatal("ExecUp() = nil, want FK violation error") + } + if !errors.Is(err, sqlmigrate.ErrExecFailed) { + t.Errorf("err = %v, want ErrExecFailed", err) + } +} diff --git a/database/sqlmigrate/msmigrate/msmigrate_test.go b/database/sqlmigrate/msmigrate/msmigrate_test.go index 417530b..baaed00 100644 --- a/database/sqlmigrate/msmigrate/msmigrate_test.go +++ b/database/sqlmigrate/msmigrate/msmigrate_test.go @@ -1,12 +1,16 @@ package msmigrate_test import ( + "context" "database/sql" + "errors" "os" "testing" + "testing/fstest" _ "github.com/microsoft/go-mssqldb" + "github.com/therootcompany/golib/database/sqlmigrate" "github.com/therootcompany/golib/database/sqlmigrate/msmigrate" ) @@ -40,8 +44,10 @@ func connect(t *testing.T) *sql.Conn { if _, err := conn.ExecContext(ctx, "DROP TABLE IF EXISTS _migrations"); err != nil { t.Fatalf("pre-cleanup _migrations: %v", err) } + // Cleanup uses a fresh context because t.Context() is canceled + // before cleanup runs, which would silently fail the DROP. t.Cleanup(func() { - _, _ = conn.ExecContext(ctx, "DROP TABLE IF EXISTS _migrations") + _, _ = conn.ExecContext(context.Background(), "DROP TABLE IF EXISTS _migrations") }) return conn @@ -143,3 +149,163 @@ func TestAppliedAfterDropTable(t *testing.T) { t.Errorf("Applied() len = %d, want 0", len(applied)) } } + +// TestAppliedOrdering verifies Applied sorts by name (ascending), regardless +// of insertion order. Guards against the ORDER BY clause being removed or +// the underlying query returning rows in arbitrary order. +func TestAppliedOrdering(t *testing.T) { + conn := connect(t) + ctx := t.Context() + + if _, err := conn.ExecContext(ctx, `CREATE TABLE _migrations (id NVARCHAR(16), name NVARCHAR(255))`); err != nil { + t.Fatalf("create: %v", err) + } + if _, err := conn.ExecContext(ctx, + `INSERT INTO _migrations (id, name) VALUES ('ccc33333', '003_posts'), ('aaa11111', '001_init'), ('bbb22222', '002_users')`, + ); err != nil { + t.Fatalf("insert: %v", err) + } + + m := msmigrate.New(conn) + applied, err := m.Applied(ctx) + if err != nil { + t.Fatalf("Applied() error = %v", err) + } + wantNames := []string{"001_init", "002_users", "003_posts"} + if len(applied) != len(wantNames) { + t.Fatalf("Applied() len = %d, want %d", len(applied), len(wantNames)) + } + for i, w := range wantNames { + if applied[i].Name != w { + t.Errorf("applied[%d].Name = %q, want %q", i, applied[i].Name, w) + } + } +} + +// TestEndToEndCycle runs a real Collect → Up → Applied → Down → Applied +// cycle through the sqlmigrate orchestrator. Catches wiring bugs between +// Migrator and the orchestrator that the in-package mockMigrator tests +// cannot. +func TestEndToEndCycle(t *testing.T) { + conn := connect(t) + ctx := t.Context() + // Pre-clean and post-clean domain tables. Pre-clean handles leftovers + // from interrupted runs; post-clean uses context.Background() because + // t.Context() is canceled before cleanup runs. + for _, tbl := range []string{"test_widgets", "test_gadgets"} { + if _, err := conn.ExecContext(ctx, "DROP TABLE IF EXISTS "+tbl); err != nil { + t.Fatalf("pre-clean %s: %v", tbl, err) + } + } + t.Cleanup(func() { + _, _ = conn.ExecContext(context.Background(), "DROP TABLE IF EXISTS test_widgets") + _, _ = conn.ExecContext(context.Background(), "DROP TABLE IF EXISTS test_gadgets") + }) + + fsys := fstest.MapFS{ + "001_init.up.sql": {Data: []byte(` + CREATE TABLE _migrations (id NVARCHAR(16), name NVARCHAR(255)); + CREATE TABLE test_widgets (n INT); + INSERT INTO _migrations (name, id) VALUES ('001_init', 'aaaa1111'); + `)}, + "001_init.down.sql": {Data: []byte(` + DROP TABLE test_widgets; + DROP TABLE _migrations; + `)}, + "002_gadgets.up.sql": {Data: []byte(` + CREATE TABLE test_gadgets (n INT); + INSERT INTO _migrations (name, id) VALUES ('002_gadgets', 'bbbb2222'); + `)}, + "002_gadgets.down.sql": {Data: []byte(` + DROP TABLE test_gadgets; + DELETE FROM _migrations WHERE id = 'bbbb2222'; + `)}, + } + ddls, err := sqlmigrate.Collect(fsys, ".") + if err != nil { + t.Fatalf("Collect: %v", err) + } + + m := msmigrate.New(conn) + + ran, err := sqlmigrate.Up(ctx, m, ddls, -1) + if err != nil { + t.Fatalf("Up: %v", err) + } + if len(ran) != 2 { + t.Fatalf("ran = %d, want 2", len(ran)) + } + + applied, err := m.Applied(ctx) + if err != nil { + t.Fatalf("Applied: %v", err) + } + if len(applied) != 2 { + t.Fatalf("applied = %d, want 2", len(applied)) + } + if applied[0].ID != "aaaa1111" || applied[1].ID != "bbbb2222" { + t.Errorf("applied IDs = %+v, want [aaaa1111 bbbb2222]", applied) + } + + for _, tbl := range []string{"test_widgets", "test_gadgets"} { + if _, err := conn.ExecContext(ctx, "SELECT COUNT(*) FROM "+tbl); err != nil { + t.Errorf("expected table %q to exist: %v", tbl, err) + } + } + + rolled, err := sqlmigrate.Down(ctx, m, ddls, -1) + if err != nil { + t.Fatalf("Down: %v", err) + } + if len(rolled) != 2 { + t.Fatalf("rolled = %d, want 2", len(rolled)) + } + + applied, err = m.Applied(ctx) + if err != nil { + t.Fatalf("Applied after Down: %v", err) + } + if len(applied) != 0 { + t.Errorf("applied after Down = %d, want 0", len(applied)) + } +} + +// TestDMLRollback verifies that when a migration contains multiple DML +// statements and one fails, earlier statements in the same migration are +// rolled back. Uses an INSERT into a nonexistent table as the failure +// trigger so the test is portable across backends. +func TestDMLRollback(t *testing.T) { + conn := connect(t) + ctx := t.Context() + if _, err := conn.ExecContext(ctx, "DROP TABLE IF EXISTS test_rollback"); err != nil { + t.Fatalf("pre-clean: %v", err) + } + t.Cleanup(func() { + _, _ = conn.ExecContext(context.Background(), "DROP TABLE IF EXISTS test_rollback") + }) + + if _, err := conn.ExecContext(ctx, `CREATE TABLE test_rollback (n INT)`); err != nil { + t.Fatalf("create: %v", err) + } + + m := msmigrate.New(conn) + err := m.ExecUp(ctx, sqlmigrate.Migration{Name: "rollback"}, ` + INSERT INTO test_rollback (n) VALUES (1); + INSERT INTO test_rollback (n) VALUES (2); + INSERT INTO nonexistent_table (n) VALUES (3); + `) + if err == nil { + t.Fatal("ExecUp() = nil, want error") + } + if !errors.Is(err, sqlmigrate.ErrExecFailed) { + t.Errorf("err = %v, want ErrExecFailed", err) + } + + var count int + if err := conn.QueryRowContext(ctx, "SELECT COUNT(*) FROM test_rollback").Scan(&count); err != nil { + t.Fatalf("count: %v", err) + } + if count != 0 { + t.Errorf("test_rollback count = %d, want 0 (rows should have been rolled back)", count) + } +} diff --git a/database/sqlmigrate/mymigrate/mymigrate_test.go b/database/sqlmigrate/mymigrate/mymigrate_test.go index 72cb02f..4cf5b17 100644 --- a/database/sqlmigrate/mymigrate/mymigrate_test.go +++ b/database/sqlmigrate/mymigrate/mymigrate_test.go @@ -1,12 +1,17 @@ package mymigrate_test import ( + "context" "database/sql" + "errors" "os" + "strings" "testing" + "testing/fstest" _ "github.com/go-sql-driver/mysql" + "github.com/therootcompany/golib/database/sqlmigrate" "github.com/therootcompany/golib/database/sqlmigrate/mymigrate" ) @@ -40,8 +45,10 @@ func connect(t *testing.T) *sql.Conn { if _, err := conn.ExecContext(ctx, "DROP TABLE IF EXISTS _migrations"); err != nil { t.Fatalf("pre-cleanup _migrations: %v", err) } + // Cleanup uses a fresh context because t.Context() is canceled + // before cleanup runs, which would silently fail the DROP. t.Cleanup(func() { - _, _ = conn.ExecContext(ctx, "DROP TABLE IF EXISTS _migrations") + _, _ = conn.ExecContext(context.Background(), "DROP TABLE IF EXISTS _migrations") }) return conn @@ -143,3 +150,210 @@ func TestAppliedAfterDropTable(t *testing.T) { t.Errorf("Applied() len = %d, want 0", len(applied)) } } + +// TestAppliedOrdering verifies Applied sorts by name (ascending), regardless +// of insertion order. Guards against the ORDER BY clause being removed or +// the underlying query returning rows in arbitrary order. +func TestAppliedOrdering(t *testing.T) { + conn := connect(t) + ctx := t.Context() + + if _, err := conn.ExecContext(ctx, `CREATE TABLE _migrations (id VARCHAR(16), name VARCHAR(255))`); err != nil { + t.Fatalf("create: %v", err) + } + if _, err := conn.ExecContext(ctx, + `INSERT INTO _migrations (id, name) VALUES ('ccc33333', '003_posts'), ('aaa11111', '001_init'), ('bbb22222', '002_users')`, + ); err != nil { + t.Fatalf("insert: %v", err) + } + + m := mymigrate.New(conn) + applied, err := m.Applied(ctx) + if err != nil { + t.Fatalf("Applied() error = %v", err) + } + wantNames := []string{"001_init", "002_users", "003_posts"} + if len(applied) != len(wantNames) { + t.Fatalf("Applied() len = %d, want %d", len(applied), len(wantNames)) + } + for i, w := range wantNames { + if applied[i].Name != w { + t.Errorf("applied[%d].Name = %q, want %q", i, applied[i].Name, w) + } + } +} + +// TestEndToEndCycle runs a real Collect → Up → Applied → Down → Applied +// cycle through the sqlmigrate orchestrator. Catches wiring bugs between +// Migrator and the orchestrator that the in-package mockMigrator tests +// cannot. Note: MySQL implicitly commits DDL, so rolling back a migration +// mid-exec cannot be tested here — see TestDMLRollback for the DML case. +func TestEndToEndCycle(t *testing.T) { + conn := connect(t) + ctx := t.Context() + // Pre-clean and post-clean domain tables. Pre-clean handles leftovers + // from interrupted runs; post-clean uses context.Background() because + // t.Context() is canceled before cleanup runs. + for _, tbl := range []string{"test_widgets", "test_gadgets"} { + if _, err := conn.ExecContext(ctx, "DROP TABLE IF EXISTS "+tbl); err != nil { + t.Fatalf("pre-clean %s: %v", tbl, err) + } + } + t.Cleanup(func() { + _, _ = conn.ExecContext(context.Background(), "DROP TABLE IF EXISTS test_widgets") + _, _ = conn.ExecContext(context.Background(), "DROP TABLE IF EXISTS test_gadgets") + }) + + fsys := fstest.MapFS{ + "001_init.up.sql": {Data: []byte(` + CREATE TABLE _migrations (id VARCHAR(16), name VARCHAR(255)); + CREATE TABLE test_widgets (n INT); + INSERT INTO _migrations (name, id) VALUES ('001_init', 'aaaa1111'); + `)}, + "001_init.down.sql": {Data: []byte(` + DROP TABLE test_widgets; + DROP TABLE _migrations; + `)}, + "002_gadgets.up.sql": {Data: []byte(` + CREATE TABLE test_gadgets (n INT); + INSERT INTO _migrations (name, id) VALUES ('002_gadgets', 'bbbb2222'); + `)}, + "002_gadgets.down.sql": {Data: []byte(` + DROP TABLE test_gadgets; + DELETE FROM _migrations WHERE id = 'bbbb2222'; + `)}, + } + ddls, err := sqlmigrate.Collect(fsys, ".") + if err != nil { + t.Fatalf("Collect: %v", err) + } + + m := mymigrate.New(conn) + + ran, err := sqlmigrate.Up(ctx, m, ddls, -1) + if err != nil { + t.Fatalf("Up: %v", err) + } + if len(ran) != 2 { + t.Fatalf("ran = %d, want 2", len(ran)) + } + + applied, err := m.Applied(ctx) + if err != nil { + t.Fatalf("Applied: %v", err) + } + if len(applied) != 2 { + t.Fatalf("applied = %d, want 2", len(applied)) + } + if applied[0].ID != "aaaa1111" || applied[1].ID != "bbbb2222" { + t.Errorf("applied IDs = %+v, want [aaaa1111 bbbb2222]", applied) + } + + for _, tbl := range []string{"test_widgets", "test_gadgets"} { + if _, err := conn.ExecContext(ctx, "SELECT COUNT(*) FROM "+tbl); err != nil { + t.Errorf("expected table %q to exist: %v", tbl, err) + } + } + + rolled, err := sqlmigrate.Down(ctx, m, ddls, -1) + if err != nil { + t.Fatalf("Down: %v", err) + } + if len(rolled) != 2 { + t.Fatalf("rolled = %d, want 2", len(rolled)) + } + + applied, err = m.Applied(ctx) + if err != nil { + t.Fatalf("Applied after Down: %v", err) + } + if len(applied) != 0 { + t.Errorf("applied after Down = %d, want 0", len(applied)) + } +} + +// TestDMLRollback verifies that when a migration contains multiple DML +// statements and one fails, earlier statements in the same migration are +// rolled back. MySQL can roll back DML in a transaction but silently +// commits DDL; this test is DML-only. Uses an INSERT into a nonexistent +// table as the failure trigger so the test is portable across backends. +func TestDMLRollback(t *testing.T) { + conn := connect(t) + ctx := t.Context() + if _, err := conn.ExecContext(ctx, "DROP TABLE IF EXISTS test_rollback"); err != nil { + t.Fatalf("pre-clean: %v", err) + } + t.Cleanup(func() { + _, _ = conn.ExecContext(context.Background(), "DROP TABLE IF EXISTS test_rollback") + }) + + if _, err := conn.ExecContext(ctx, `CREATE TABLE test_rollback (n INT) ENGINE=InnoDB`); err != nil { + t.Fatalf("create: %v", err) + } + + m := mymigrate.New(conn) + err := m.ExecUp(ctx, sqlmigrate.Migration{Name: "rollback"}, ` + INSERT INTO test_rollback (n) VALUES (1); + INSERT INTO test_rollback (n) VALUES (2); + INSERT INTO nonexistent_table (n) VALUES (3); + `) + if err == nil { + t.Fatal("ExecUp() = nil, want error") + } + if !errors.Is(err, sqlmigrate.ErrExecFailed) { + t.Errorf("err = %v, want ErrExecFailed", err) + } + + var count int + if err := conn.QueryRowContext(ctx, "SELECT COUNT(*) FROM test_rollback").Scan(&count); err != nil { + t.Fatalf("count: %v", err) + } + if count != 0 { + t.Errorf("test_rollback count = %d, want 0 (rows should have been rolled back)", count) + } +} + +// TestMultiStatementsRequired verifies that mymigrate rejects a DSN that +// does not have multiStatements=true, with a clear error that mentions +// multiStatements. Without the lazy validation, a multi-statement +// migration would silently run only the first statement, corrupting the +// schema. +func TestMultiStatementsRequired(t *testing.T) { + dsn := os.Getenv("MYSQL_TEST_DSN") + if dsn == "" { + t.Skip("MYSQL_TEST_DSN not set") + } + // Strip multiStatements=true from the DSN so the lazy probe fires. + stripped := dsn + for _, pat := range []string{"multiStatements=true&", "&multiStatements=true", "multiStatements=true"} { + stripped = strings.ReplaceAll(stripped, pat, "") + } + if stripped == dsn { + t.Fatalf("MYSQL_TEST_DSN does not contain multiStatements=true; cannot run this test") + } + + ctx := t.Context() + db, err := sql.Open("mysql", stripped) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatalf("conn: %v", err) + } + t.Cleanup(func() { _ = conn.Close() }) + + m := mymigrate.New(conn) + err = m.ExecUp(ctx, sqlmigrate.Migration{Name: "probe"}, "SELECT 1") + if err == nil { + t.Fatal("ExecUp() = nil, want error for missing multiStatements") + } + if !errors.Is(err, sqlmigrate.ErrExecFailed) { + t.Errorf("err = %v, want ErrExecFailed", err) + } + if !strings.Contains(err.Error(), "multiStatements") { + t.Errorf("err message = %q, want to mention multiStatements", err.Error()) + } +} diff --git a/database/sqlmigrate/pgmigrate/pgmigrate_test.go b/database/sqlmigrate/pgmigrate/pgmigrate_test.go index 0083465..4f4e986 100644 --- a/database/sqlmigrate/pgmigrate/pgmigrate_test.go +++ b/database/sqlmigrate/pgmigrate/pgmigrate_test.go @@ -1,11 +1,15 @@ package pgmigrate_test import ( + "context" + "errors" "os" "testing" + "testing/fstest" "github.com/jackc/pgx/v5" + "github.com/therootcompany/golib/database/sqlmigrate" "github.com/therootcompany/golib/database/sqlmigrate/pgmigrate" ) @@ -24,7 +28,7 @@ func connect(t *testing.T) *pgx.Conn { if err != nil { t.Fatalf("connect: %v", err) } - t.Cleanup(func() { _ = conn.Close(ctx) }) + t.Cleanup(func() { _ = conn.Close(context.Background()) }) // Use a per-test schema so concurrent tests don't collide and // _migrations is guaranteed not to exist on entry. @@ -35,8 +39,10 @@ func connect(t *testing.T) *pgx.Conn { if _, err := conn.Exec(ctx, "CREATE SCHEMA "+schema); err != nil { t.Fatalf("create schema: %v", err) } + // Cleanup uses a fresh context because t.Context() is canceled + // before cleanup runs, which would silently fail the DROP SCHEMA. t.Cleanup(func() { - _, _ = conn.Exec(ctx, "DROP SCHEMA IF EXISTS "+schema+" CASCADE") + _, _ = conn.Exec(context.Background(), "DROP SCHEMA IF EXISTS "+schema+" CASCADE") }) if _, err := conn.Exec(ctx, "SET search_path TO "+schema); err != nil { t.Fatalf("set search_path: %v", err) @@ -158,3 +164,145 @@ func TestAppliedAfterDropTable(t *testing.T) { t.Errorf("Applied() len = %d, want 0", len(applied)) } } + +// TestAppliedOrdering verifies Applied sorts by name (ascending), regardless +// of insertion order. Guards against the ORDER BY clause being removed or +// the underlying query returning rows in arbitrary order. +func TestAppliedOrdering(t *testing.T) { + conn := connect(t) + ctx := t.Context() + + if _, err := conn.Exec(ctx, ` + CREATE TABLE _migrations (id TEXT, name TEXT); + INSERT INTO _migrations (id, name) VALUES ('ccc33333', '003_posts'); + INSERT INTO _migrations (id, name) VALUES ('aaa11111', '001_init'); + INSERT INTO _migrations (id, name) VALUES ('bbb22222', '002_users'); + `); err != nil { + t.Fatalf("setup: %v", err) + } + + m := pgmigrate.New(conn) + applied, err := m.Applied(ctx) + if err != nil { + t.Fatalf("Applied() error = %v", err) + } + wantNames := []string{"001_init", "002_users", "003_posts"} + if len(applied) != len(wantNames) { + t.Fatalf("Applied() len = %d, want %d", len(applied), len(wantNames)) + } + for i, w := range wantNames { + if applied[i].Name != w { + t.Errorf("applied[%d].Name = %q, want %q", i, applied[i].Name, w) + } + } +} + +// TestEndToEndCycle runs a real Collect → Up → Applied → Down → Applied +// cycle through the sqlmigrate orchestrator. Catches wiring bugs between +// Migrator and the orchestrator that the in-package mockMigrator tests +// cannot. +func TestEndToEndCycle(t *testing.T) { + conn := connect(t) + ctx := t.Context() + + fsys := fstest.MapFS{ + "001_init.up.sql": {Data: []byte(` + CREATE TABLE _migrations (id TEXT, name TEXT); + CREATE TABLE test_widgets (n INTEGER); + INSERT INTO _migrations (name, id) VALUES ('001_init', 'aaaa1111'); + `)}, + "001_init.down.sql": {Data: []byte(` + DROP TABLE test_widgets; + DROP TABLE _migrations; + `)}, + "002_gadgets.up.sql": {Data: []byte(` + CREATE TABLE test_gadgets (n INTEGER); + INSERT INTO _migrations (name, id) VALUES ('002_gadgets', 'bbbb2222'); + `)}, + "002_gadgets.down.sql": {Data: []byte(` + DROP TABLE test_gadgets; + DELETE FROM _migrations WHERE id = 'bbbb2222'; + `)}, + } + ddls, err := sqlmigrate.Collect(fsys, ".") + if err != nil { + t.Fatalf("Collect: %v", err) + } + + m := pgmigrate.New(conn) + + ran, err := sqlmigrate.Up(ctx, m, ddls, -1) + if err != nil { + t.Fatalf("Up: %v", err) + } + if len(ran) != 2 { + t.Fatalf("ran = %d, want 2", len(ran)) + } + + applied, err := m.Applied(ctx) + if err != nil { + t.Fatalf("Applied: %v", err) + } + if len(applied) != 2 { + t.Fatalf("applied = %d, want 2", len(applied)) + } + if applied[0].ID != "aaaa1111" || applied[1].ID != "bbbb2222" { + t.Errorf("applied IDs = %+v, want [aaaa1111 bbbb2222]", applied) + } + + for _, tbl := range []string{"test_widgets", "test_gadgets"} { + if _, err := conn.Exec(ctx, "SELECT COUNT(*) FROM "+tbl); err != nil { + t.Errorf("expected table %q to exist: %v", tbl, err) + } + } + + rolled, err := sqlmigrate.Down(ctx, m, ddls, -1) + if err != nil { + t.Fatalf("Down: %v", err) + } + if len(rolled) != 2 { + t.Fatalf("rolled = %d, want 2", len(rolled)) + } + + applied, err = m.Applied(ctx) + if err != nil { + t.Fatalf("Applied after Down: %v", err) + } + if len(applied) != 0 { + t.Errorf("applied after Down = %d, want 0", len(applied)) + } +} + +// TestDMLRollback verifies that when a migration contains multiple DML +// statements and one fails, earlier statements in the same migration are +// rolled back. Uses an INSERT into a nonexistent table as the failure +// trigger so the test is portable across backends. +func TestDMLRollback(t *testing.T) { + conn := connect(t) + ctx := t.Context() + + if _, err := conn.Exec(ctx, `CREATE TABLE test_rollback (n INTEGER)`); err != nil { + t.Fatalf("create: %v", err) + } + + m := pgmigrate.New(conn) + err := m.ExecUp(ctx, sqlmigrate.Migration{Name: "rollback"}, ` + INSERT INTO test_rollback (n) VALUES (1); + INSERT INTO test_rollback (n) VALUES (2); + INSERT INTO nonexistent_table (n) VALUES (3); + `) + if err == nil { + t.Fatal("ExecUp() = nil, want error") + } + if !errors.Is(err, sqlmigrate.ErrExecFailed) { + t.Errorf("err = %v, want ErrExecFailed", err) + } + + var count int + if err := conn.QueryRow(ctx, "SELECT COUNT(*) FROM test_rollback").Scan(&count); err != nil { + t.Fatalf("count: %v", err) + } + if count != 0 { + t.Errorf("test_rollback count = %d, want 0 (rows should have been rolled back)", count) + } +}