golib/database/sqlmigrate/sqlmigrate_test.go
AJ ONeal 17bbc881a9
database/sqlmigrate: allow optional schema prefix in INSERT INTO _migrations
Update idFromInsert regex to match schema-qualified table references
such as INSERT INTO authz._migrations, in addition to the existing
unqualified INSERT INTO _migrations form.
2026-04-17 02:30:41 -06:00

497 lines
14 KiB
Go

package sqlmigrate_test
import (
"context"
"errors"
"slices"
"strings"
"testing"
"testing/fstest"
"github.com/therootcompany/golib/database/sqlmigrate"
)
// migs builds []Migration from names (IDs empty).
func migs(names ...string) []sqlmigrate.Migration {
out := make([]sqlmigrate.Migration, len(names))
for i, n := range names {
out[i] = sqlmigrate.Migration{Name: n}
}
return out
}
// names extracts just the Name field from a slice of Migration.
func names(ms []sqlmigrate.Migration) []string {
out := make([]string, len(ms))
for i, m := range ms {
out[i] = m.Name
}
return out
}
// mockMigrator tracks applied migrations in memory.
type mockMigrator struct {
applied []sqlmigrate.Migration
execErr error // if set, ExecUp/ExecDown return this on every call
upCalls []string
downCalls []string
}
func (m *mockMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration, _ string) error {
m.upCalls = append(m.upCalls, mig.Name)
if m.execErr != nil {
return m.execErr
}
m.applied = append(m.applied, mig)
slices.SortFunc(m.applied, func(a, b sqlmigrate.Migration) int {
return strings.Compare(a.Name, b.Name)
})
return nil
}
func (m *mockMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration, _ string) error {
m.downCalls = append(m.downCalls, mig.Name)
if m.execErr != nil {
return m.execErr
}
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.Migration) bool { return a.Name == mig.Name })
return nil
}
func (m *mockMigrator) Applied(_ context.Context) ([]sqlmigrate.Migration, error) {
return slices.Clone(m.applied), nil
}
// --- Collect ---
func TestCollect(t *testing.T) {
t.Run("pairs and sorts", func(t *testing.T) {
fsys := fstest.MapFS{
"002_second.up.sql": {Data: []byte("CREATE TABLE b;")},
"002_second.down.sql": {Data: []byte("DROP TABLE b;")},
"001_first.up.sql": {Data: []byte("CREATE TABLE a;")},
"001_first.down.sql": {Data: []byte("DROP TABLE a;")},
}
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if len(ddls) != 2 {
t.Fatalf("got %d ddls, want 2", len(ddls))
}
if ddls[0].Name != "001_first" {
t.Errorf("first = %q, want %q", ddls[0].Name, "001_first")
}
if ddls[1].Name != "002_second" {
t.Errorf("second = %q, want %q", ddls[1].Name, "002_second")
}
if ddls[0].Up != "CREATE TABLE a;" {
t.Errorf("first.Up = %q", ddls[0].Up)
}
if ddls[0].Down != "DROP TABLE a;" {
t.Errorf("first.Down = %q", ddls[0].Down)
}
})
t.Run("parses ID from INSERT", func(t *testing.T) {
fsys := fstest.MapFS{
"001_init.up.sql": {Data: []byte("CREATE TABLE a;\nINSERT INTO _migrations (name, id) VALUES ('001_init', 'abcd1234');")},
"001_init.down.sql": {Data: []byte("DROP TABLE a;\nDELETE FROM _migrations WHERE id = 'abcd1234';")},
}
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if ddls[0].ID != "abcd1234" {
t.Errorf("ID = %q, want %q", ddls[0].ID, "abcd1234")
}
})
t.Run("parses ID from schema-prefixed INSERT", func(t *testing.T) {
fsys := fstest.MapFS{
"001_init.up.sql": {Data: []byte("CREATE TABLE authz._migrations (id CHAR(8) PRIMARY KEY, name VARCHAR(80));\nINSERT INTO authz._migrations (name, id) VALUES ('001_init', 'abcd1234');")},
"001_init.down.sql": {Data: []byte("DROP TABLE authz._migrations;")},
}
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if ddls[0].ID != "abcd1234" {
t.Errorf("ID = %q, want %q", ddls[0].ID, "abcd1234")
}
})
t.Run("no ID when no INSERT", func(t *testing.T) {
fsys := fstest.MapFS{
"001_init.up.sql": {Data: []byte("CREATE TABLE a;")},
"001_init.down.sql": {Data: []byte("DROP TABLE a;")},
}
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if ddls[0].ID != "" {
t.Errorf("ID = %q, want empty", ddls[0].ID)
}
})
t.Run("missing down", func(t *testing.T) {
fsys := fstest.MapFS{
"001_only-up.up.sql": {Data: []byte("CREATE TABLE x;")},
}
_, err := sqlmigrate.Collect(fsys, ".")
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
t.Errorf("got %v, want ErrMissingDown", err)
}
})
t.Run("missing up", func(t *testing.T) {
fsys := fstest.MapFS{
"001_only-down.down.sql": {Data: []byte("DROP TABLE x;")},
}
_, err := sqlmigrate.Collect(fsys, ".")
if !errors.Is(err, sqlmigrate.ErrMissingUp) {
t.Errorf("got %v, want ErrMissingUp", err)
}
})
t.Run("ignores non-sql files", func(t *testing.T) {
fsys := fstest.MapFS{
"001_init.up.sql": {Data: []byte("UP")},
"001_init.down.sql": {Data: []byte("DOWN")},
"README.md": {Data: []byte("# Migrations")},
"_migrations.sql": {Data: []byte("SELECT name FROM _migrations;")},
}
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if len(ddls) != 1 {
t.Fatalf("got %d ddls, want 1", len(ddls))
}
})
t.Run("empty fs", func(t *testing.T) {
fsys := fstest.MapFS{}
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if len(ddls) != 0 {
t.Fatalf("got %d ddls, want 0", len(ddls))
}
})
}
// --- NamesOnly ---
func TestNamesOnly(t *testing.T) {
ns := []string{"001_init", "002_users"}
ddls := sqlmigrate.NamesOnly(ns)
if len(ddls) != 2 {
t.Fatalf("got %d, want 2", len(ddls))
}
for i, d := range ddls {
if d.Name != ns[i] {
t.Errorf("[%d].Name = %q, want %q", i, d.Name, ns[i])
}
if d.Up != "" || d.Down != "" {
t.Errorf("[%d] has non-empty content", i)
}
}
}
// --- Up ---
func TestUp(t *testing.T) {
ctx := t.Context()
ddls := []sqlmigrate.Script{
{Migration: sqlmigrate.Migration{Name: "001_init"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
{Migration: sqlmigrate.Migration{Name: "003_posts"}, Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
}
t.Run("apply all", func(t *testing.T) {
m := &mockMigrator{}
ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(names(ran), []string{"001_init", "002_users", "003_posts"}) {
t.Errorf("applied = %v", names(ran))
}
})
t.Run("n=0 is error", func(t *testing.T) {
m := &mockMigrator{}
_, err := sqlmigrate.Up(ctx, m, ddls, 0)
if !errors.Is(err, sqlmigrate.ErrInvalidN) {
t.Errorf("got %v, want ErrInvalidN", err)
}
})
t.Run("apply n", func(t *testing.T) {
m := &mockMigrator{}
ran, err := sqlmigrate.Up(ctx, m, ddls, 2)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(names(ran), []string{"001_init", "002_users"}) {
t.Errorf("applied = %v", names(ran))
}
})
t.Run("none pending", func(t *testing.T) {
m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err != nil {
t.Fatal(err)
}
if len(ran) != 0 {
t.Fatalf("applied %d, want 0", len(ran))
}
})
t.Run("partial pending", func(t *testing.T) {
m := &mockMigrator{applied: migs("001_init")}
ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(names(ran), []string{"002_users", "003_posts"}) {
t.Errorf("applied = %v", names(ran))
}
})
t.Run("n exceeds pending", func(t *testing.T) {
m := &mockMigrator{applied: migs("001_init")}
ran, err := sqlmigrate.Up(ctx, m, ddls, 99)
if err != nil {
t.Fatal(err)
}
if len(ran) != 2 {
t.Fatalf("applied %d, want 2", len(ran))
}
})
t.Run("exec error stops and returns partial", func(t *testing.T) {
m := &failOnNthMigrator{failAt: 1}
ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err == nil {
t.Fatal("expected error")
}
if len(ran) != 1 {
t.Errorf("applied %d before error, want 1", len(ran))
}
})
t.Run("skips migration matched by ID", func(t *testing.T) {
// DB has migration applied under old name, but same ID
m := &mockMigrator{applied: []sqlmigrate.Migration{
{Name: "001_old-name", ID: "aa11bb22"},
}}
idDDLs := []sqlmigrate.Script{
{Migration: sqlmigrate.Migration{Name: "001_new-name", ID: "aa11bb22"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
}
ran, err := sqlmigrate.Up(ctx, m, idDDLs, -1)
if err != nil {
t.Fatal(err)
}
// Only 002_users should be applied; 001 is matched by ID
if !slices.Equal(names(ran), []string{"002_users"}) {
t.Errorf("applied = %v, want [002_users]", names(ran))
}
})
}
// failOnNthMigrator fails on the Nth ExecUp call (0-indexed).
type failOnNthMigrator struct {
applied []sqlmigrate.Migration
calls int
failAt int
}
func (m *failOnNthMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration, _ string) error {
if m.calls == m.failAt {
m.calls++
return errors.New("connection lost")
}
m.calls++
m.applied = append(m.applied, mig)
slices.SortFunc(m.applied, func(a, b sqlmigrate.Migration) int {
return strings.Compare(a.Name, b.Name)
})
return nil
}
func (m *failOnNthMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration, _ string) error {
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.Migration) bool { return a.Name == mig.Name })
return nil
}
func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.Migration, error) {
return slices.Clone(m.applied), nil
}
// --- Down ---
func TestDown(t *testing.T) {
ctx := t.Context()
ddls := []sqlmigrate.Script{
{Migration: sqlmigrate.Migration{Name: "001_init"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
{Migration: sqlmigrate.Migration{Name: "003_posts"}, Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
}
t.Run("rollback all", func(t *testing.T) {
m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, ddls, -1)
if err != nil {
t.Fatal(err)
}
if len(rolled) != 3 {
t.Fatalf("rolled %d, want 3", len(rolled))
}
if rolled[0].Name != "003_posts" {
t.Errorf("first rollback = %q, want 003_posts", rolled[0].Name)
}
if rolled[2].Name != "001_init" {
t.Errorf("last rollback = %q, want 001_init", rolled[2].Name)
}
})
t.Run("n=0 is error", func(t *testing.T) {
m := &mockMigrator{applied: migs("001_init", "002_users")}
_, err := sqlmigrate.Down(ctx, m, ddls, 0)
if !errors.Is(err, sqlmigrate.ErrInvalidN) {
t.Errorf("got %v, want ErrInvalidN", err)
}
})
t.Run("rollback n", func(t *testing.T) {
m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, ddls, 2)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(names(rolled), []string{"003_posts", "002_users"}) {
t.Errorf("rolled = %v", names(rolled))
}
})
t.Run("none applied", func(t *testing.T) {
m := &mockMigrator{}
rolled, err := sqlmigrate.Down(ctx, m, ddls, -1)
if err != nil {
t.Fatal(err)
}
if len(rolled) != 0 {
t.Fatalf("rolled %d, want 0", len(rolled))
}
})
t.Run("exec error", func(t *testing.T) {
m := &mockMigrator{
applied: migs("001_init", "002_users"),
execErr: errors.New("permission denied"),
}
rolled, err := sqlmigrate.Down(ctx, m, ddls, 1)
if err == nil {
t.Fatal("expected error")
}
if len(rolled) != 0 {
t.Errorf("rolled %d on error, want 0", len(rolled))
}
})
t.Run("unknown migration in applied", func(t *testing.T) {
m := &mockMigrator{applied: migs("001_init", "999_unknown")}
_, err := sqlmigrate.Down(ctx, m, ddls, -1)
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
t.Errorf("got %v, want ErrMissingDown", err)
}
})
t.Run("n exceeds applied", func(t *testing.T) {
m := &mockMigrator{applied: migs("001_init")}
rolled, err := sqlmigrate.Down(ctx, m, ddls, 99)
if err != nil {
t.Fatal(err)
}
if len(rolled) != 1 {
t.Fatalf("rolled %d, want 1", len(rolled))
}
})
t.Run("finds migration by ID when name changed", func(t *testing.T) {
// DB has old name, file has new name, same ID
m := &mockMigrator{applied: []sqlmigrate.Migration{
{Name: "001_old-name", ID: "aa11bb22"},
}}
idDDLs := []sqlmigrate.Script{
{Migration: sqlmigrate.Migration{Name: "001_new-name", ID: "aa11bb22"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
}
rolled, err := sqlmigrate.Down(ctx, m, idDDLs, 1)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(names(rolled), []string{"001_old-name"}) {
t.Errorf("rolled = %v, want [001_old-name]", names(rolled))
}
})
}
// --- GetStatus ---
func TestGetStatus(t *testing.T) {
ctx := t.Context()
ddls := []sqlmigrate.Script{
{Migration: sqlmigrate.Migration{Name: "001_init"}},
{Migration: sqlmigrate.Migration{Name: "002_users"}},
{Migration: sqlmigrate.Migration{Name: "003_posts"}},
}
t.Run("all pending", func(t *testing.T) {
m := &mockMigrator{}
status, err := sqlmigrate.GetStatus(ctx, m, ddls)
if err != nil {
t.Fatal(err)
}
if len(status.Applied) != 0 {
t.Errorf("applied = %d, want 0", len(status.Applied))
}
if len(status.Pending) != 3 {
t.Errorf("pending = %d, want 3", len(status.Pending))
}
})
t.Run("partial", func(t *testing.T) {
m := &mockMigrator{applied: migs("001_init")}
status, err := sqlmigrate.GetStatus(ctx, m, ddls)
if err != nil {
t.Fatal(err)
}
if len(status.Applied) != 1 || status.Applied[0].Name != "001_init" {
t.Errorf("applied = %v", status.Applied)
}
if !slices.Equal(names(status.Pending), []string{"002_users", "003_posts"}) {
t.Errorf("pending = %v", names(status.Pending))
}
})
t.Run("all applied", func(t *testing.T) {
m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
status, err := sqlmigrate.GetStatus(ctx, m, ddls)
if err != nil {
t.Fatal(err)
}
if len(status.Applied) != 3 {
t.Errorf("applied = %d, want 3", len(status.Applied))
}
if len(status.Pending) != 0 {
t.Errorf("pending = %d, want 0", len(status.Pending))
}
})
}