ref(sqlmigrate): split Migration into Migration + Script

Migration{ID, Name} is the identity type returned by Up/Down/Latest/Drop.
Script{Migration, Up, Down} holds collected SQL content from Collect().

Migrator interface now takes SQL as a separate parameter:
  ExecUp(ctx, Migration, sql string)
  ExecDown(ctx, Migration, sql string)

This separates identity from content — callers that track what ran
don't need to carry around SQL strings they'll never use.

Updates shmigrate to match (ignores the sql parameter, references
files on disk instead).
This commit is contained in:
AJ ONeal 2026-04-09 03:21:46 -06:00
parent 285eb0b684
commit 75e4a883b7
No known key found for this signature in database
7 changed files with 263 additions and 224 deletions

View File

@ -3,10 +3,17 @@ module github.com/therootcompany/golib/cmd/sql-migrate/v2
go 1.26.1
require (
github.com/therootcompany/golib/database/sqlmigrate v0.0.0
github.com/jackc/pgx/v5 v5.9.1
github.com/therootcompany/golib/database/sqlmigrate v1.0.0
github.com/therootcompany/golib/database/sqlmigrate/shmigrate v0.0.0
)
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
golang.org/x/text v0.29.0 // indirect
)
replace (
github.com/therootcompany/golib/database/sqlmigrate => ../../database/sqlmigrate
github.com/therootcompany/golib/database/sqlmigrate/shmigrate => ../../database/sqlmigrate/shmigrate

26
cmd/sql-migrate/go.sum Normal file
View File

@ -0,0 +1,26 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc=
github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -939,7 +939,7 @@ func syncLog(runner *shmigrate.Migrator) {
fmt.Printf("cat %s\n", logPath)
}
func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error {
func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script, n int) error {
// fixup pending migrations before generating the script
fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations)
@ -974,7 +974,7 @@ func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrat
return nil
}
func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error {
func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script, n int) error {
// fixup applied migrations before generating the script
fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations)
@ -998,7 +998,7 @@ func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migr
fmt.Println("")
// check for missing down files before generating script
reversed := make([]string, len(status.Applied))
reversed := make([]sqlmigrate.Migration, len(status.Applied))
copy(reversed, status.Applied)
slices.Reverse(reversed)
limit := n
@ -1008,8 +1008,8 @@ func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migr
if limit > len(reversed) {
limit = len(reversed)
}
for _, name := range reversed[:limit] {
downPath := filepath.Join(state.MigrationsDir, name+".down.sql")
for _, a := range reversed[:limit] {
downPath := filepath.Join(state.MigrationsDir, a.Name+".down.sql")
if !fileExists(downPath) {
fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath))
fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n")
@ -1028,7 +1028,7 @@ func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migr
return nil
}
func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration) error {
func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script) error {
status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
if err != nil {
return err
@ -1040,21 +1040,21 @@ func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, mi
fmt.Fprintf(os.Stderr, "\n")
// show applied in reverse (most recent first)
applied := make([]string, len(status.Applied))
copy(applied, status.Applied)
slices.Reverse(applied)
appliedList := make([]sqlmigrate.Migration, len(status.Applied))
copy(appliedList, status.Applied)
slices.Reverse(appliedList)
fmt.Printf("# previous: %d\n", len(applied))
for _, mig := range applied {
fmt.Printf(" %s\n", mig)
fmt.Printf("# previous: %d\n", len(appliedList))
for _, mig := range appliedList {
fmt.Printf(" %s\n", mig.Name)
}
if len(applied) == 0 {
if len(appliedList) == 0 {
fmt.Println(" # (no previous migrations)")
}
fmt.Println("")
fmt.Printf("# pending: %d\n", len(status.Pending))
for _, mig := range status.Pending {
fmt.Printf(" %s\n", mig)
fmt.Printf(" %s\n", mig.Name)
}
if len(status.Pending) == 0 {
fmt.Println(" # (no pending migrations)")
@ -1063,7 +1063,7 @@ func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, mi
}
// fixupAll runs fixupMigration on all known migrations (applied + pending).
func fixupAll(migrationsDir string, applied []string, migrations []sqlmigrate.Migration) (fixedUp, fixedDown []string) {
func fixupAll(migrationsDir string, applied []string, migrations []sqlmigrate.Script) (fixedUp, fixedDown []string) {
seen := map[string]bool{}
var all []string
for _, name := range applied {

View File

@ -3,3 +3,5 @@ module github.com/therootcompany/golib/database/sqlmigrate/shmigrate
go 1.26.1
require github.com/therootcompany/golib/database/sqlmigrate v1.0.0
replace github.com/therootcompany/golib/database/sqlmigrate => ../

View File

@ -59,13 +59,15 @@ type Migrator struct {
var _ sqlmigrate.Migrator = (*Migrator)(nil)
// ExecUp outputs a shell command to run the .up.sql migration file.
func (r *Migrator) ExecUp(ctx context.Context, m sqlmigrate.Migration) error {
// The sql parameter is ignored — shmigrate references files on disk.
func (r *Migrator) ExecUp(ctx context.Context, m sqlmigrate.Migration, sql string) error {
r.counter++
return r.exec(m.Name, ".up.sql", fmt.Sprintf("+%d", r.counter))
}
// ExecDown outputs a shell command to run the .down.sql migration file.
func (r *Migrator) ExecDown(ctx context.Context, m sqlmigrate.Migration) error {
// The sql parameter is ignored — shmigrate references files on disk.
func (r *Migrator) ExecDown(ctx context.Context, m sqlmigrate.Migration, sql string) error {
r.counter++
return r.exec(m.Name, ".down.sql", fmt.Sprintf("-%d", r.counter))
}
@ -95,7 +97,7 @@ func (r *Migrator) exec(name, suffix, label string) error {
//
// Returns an empty slice if the file does not exist. When FS is set, reads
// from that filesystem; otherwise reads from the OS filesystem.
func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration, error) {
func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.Migration, error) {
var f io.ReadCloser
var err error
if r.FS != nil {
@ -111,7 +113,7 @@ func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration,
}
defer f.Close()
var applied []sqlmigrate.AppliedMigration
var applied []sqlmigrate.Migration
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
@ -123,9 +125,9 @@ func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration,
continue
}
if id, name, ok := strings.Cut(line, "\t"); ok {
applied = append(applied, sqlmigrate.AppliedMigration{ID: id, Name: name})
applied = append(applied, sqlmigrate.Migration{ID: id, Name: name})
} else {
applied = append(applied, sqlmigrate.AppliedMigration{Name: line})
applied = append(applied, sqlmigrate.Migration{Name: line})
}
}

View File

@ -28,41 +28,40 @@ var (
ErrInvalidN = errors.New("n must be positive or -1 for all")
)
// Migration represents a paired up/down migration.
// Migration identifies a migration by its name and optional hex ID.
type Migration struct {
Name string // e.g. "2026-04-05-001000_create-todos"
ID string // 8-char hex from INSERT INTO _migrations, parsed by Collect
Name string // e.g. "2026-04-05-001000_create-todos"
}
// Script is a Migration with its up and down SQL content, as returned by Collect.
type Script struct {
Migration
Up string // SQL content of the .up.sql file
Down string // SQL content of the .down.sql file
}
// AppliedMigration represents a migration recorded in the _migrations table.
type AppliedMigration struct {
ID string
Name string
}
// Status represents the current migration state.
type Status struct {
Applied []string
Pending []string
Applied []Migration
Pending []Migration
}
// Migrator executes migrations. Implementations handle the
// database-specific or output-specific details.
type Migrator interface {
// ExecUp runs the up migration. For database migrators this executes
// m.Up in a transaction. For shell migrators this outputs a command
// referencing the .up.sql file.
ExecUp(ctx context.Context, m Migration) error
// ExecUp runs the up migration SQL. For database migrators this
// executes the SQL in a transaction. For shell migrators this
// outputs a command referencing the .up.sql file.
ExecUp(ctx context.Context, m Migration, sql string) error
// ExecDown runs the down migration.
ExecDown(ctx context.Context, m Migration) error
// ExecDown runs the down migration SQL.
ExecDown(ctx context.Context, m Migration, sql string) error
// Applied returns all applied migrations from the _migrations table,
// sorted lexicographically by name. Returns an empty slice (not an
// error) if the migrations table or log does not exist yet.
Applied(ctx context.Context) ([]AppliedMigration, error)
Applied(ctx context.Context) ([]Migration, error)
}
// idFromInsert extracts the hex ID from an INSERT INTO _migrations line.
@ -75,8 +74,8 @@ var idFromInsert = regexp.MustCompile(
// pairs them by basename, and returns them sorted lexicographically by name.
// If subpath is "" or ".", the root of fsys is used.
// If the up SQL contains an INSERT INTO _migrations line, the hex ID
// is extracted and stored in Migration.ID.
func Collect(fsys fs.FS, subpath string) ([]Migration, error) {
// is extracted and stored in Script.ID.
func Collect(fsys fs.FS, subpath string) ([]Script, error) {
if subpath != "" && subpath != "." {
var err error
fsys, err = fs.Sub(fsys, subpath)
@ -120,7 +119,7 @@ func Collect(fsys fs.FS, subpath string) ([]Migration, error) {
return nil, fmt.Errorf("%w: %w", ErrWalkFailed, err)
}
var migrations []Migration
var ddls []Script
for name, upSQL := range ups {
downSQL, ok := downs[name]
if !ok {
@ -130,9 +129,8 @@ func Collect(fsys fs.FS, subpath string) ([]Migration, error) {
if m := idFromInsert.FindStringSubmatch(upSQL); m != nil {
id = m[1]
}
migrations = append(migrations, Migration{
Name: name,
ID: id,
ddls = append(ddls, Script{
Migration: Migration{ID: id, Name: name},
Up: upSQL,
Down: downSQL,
})
@ -143,54 +141,54 @@ func Collect(fsys fs.FS, subpath string) ([]Migration, error) {
}
}
slices.SortFunc(migrations, func(a, b Migration) int {
slices.SortFunc(ddls, func(a, b Script) int {
return strings.Compare(a.Name, b.Name)
})
return migrations, nil
return ddls, nil
}
// NamesOnly builds a Migration slice from a list of names, with empty
// NamesOnly builds a Script slice from a list of names, with empty
// Up/Down content. Useful for shell-based runners that reference files
// on disk rather than executing SQL directly.
func NamesOnly(names []string) []Migration {
migrations := make([]Migration, len(names))
func NamesOnly(names []string) []Script {
ddls := make([]Script, len(names))
for i, name := range names {
migrations[i] = Migration{Name: name}
ddls[i] = Script{Migration: Migration{Name: name}}
}
return migrations
return ddls
}
// isApplied returns true if the migration matches any applied entry by name or ID.
func isApplied(m Migration, applied []AppliedMigration) bool {
// isApplied returns true if the Script matches any applied entry by name or ID.
func isApplied(d Script, applied []Migration) bool {
for _, a := range applied {
if a.Name == m.Name {
if a.Name == d.Name {
return true
}
if m.ID != "" && a.ID != "" && a.ID == m.ID {
if d.ID != "" && a.ID != "" && a.ID == d.ID {
return true
}
}
return false
}
// findMigration looks up a migration by the applied entry's name or ID.
func findMigration(a AppliedMigration, byName map[string]Migration, byID map[string]Migration) (Migration, bool) {
if m, ok := byName[a.Name]; ok {
return m, true
// findScript looks up a Script by the applied entry's name or ID.
func findScript(a Migration, byName map[string]Script, byID map[string]Script) (Script, bool) {
if d, ok := byName[a.Name]; ok {
return d, true
}
if a.ID != "" {
if m, ok := byID[a.ID]; ok {
return m, true
if d, ok := byID[a.ID]; ok {
return d, true
}
}
return Migration{}, false
return Script{}, false
}
// Up applies up to n pending migrations using the given Runner.
// Up applies up to n pending migrations using the given Migrator.
// If n < 0, applies all pending. If n == 0, returns ErrInvalidN.
// Returns the names of applied migrations.
func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
// Returns the applied migrations.
func Up(ctx context.Context, r Migrator, ddls []Script, n int) ([]Migration, error) {
if n == 0 {
return nil, ErrInvalidN
}
@ -200,10 +198,10 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
return nil, err
}
var pending []Migration
for _, m := range migrations {
if !isApplied(m, applied) {
pending = append(pending, m)
var pending []Script
for _, d := range ddls {
if !isApplied(d, applied) {
pending = append(pending, d)
}
}
@ -214,12 +212,12 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
n = len(pending)
}
var ran []string
for _, m := range pending[:n] {
if err := r.ExecUp(ctx, m); err != nil {
return ran, fmt.Errorf("%s (up): %w", m.Name, err)
var ran []Migration
for _, d := range pending[:n] {
if err := r.ExecUp(ctx, d.Migration, d.Up); err != nil {
return ran, fmt.Errorf("%s (up): %w", d.Name, err)
}
ran = append(ran, m.Name)
ran = append(ran, d.Migration)
}
return ran, nil
@ -227,8 +225,8 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
// Down rolls back up to n applied migrations, most recent first.
// If n < 0, rolls back all applied. If n == 0, returns ErrInvalidN.
// Returns the names of rolled-back migrations.
func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
// Returns the rolled-back migrations.
func Down(ctx context.Context, r Migrator, ddls []Script, n int) ([]Migration, error) {
if n == 0 {
return nil, ErrInvalidN
}
@ -238,16 +236,16 @@ func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]str
return nil, err
}
byName := map[string]Migration{}
byID := map[string]Migration{}
for _, m := range migrations {
byName[m.Name] = m
if m.ID != "" {
byID[m.ID] = m
byName := map[string]Script{}
byID := map[string]Script{}
for _, d := range ddls {
byName[d.Name] = d
if d.ID != "" {
byID[d.ID] = d
}
}
reversed := make([]AppliedMigration, len(applied))
reversed := make([]Migration, len(applied))
copy(reversed, applied)
slices.Reverse(reversed)
@ -255,52 +253,47 @@ func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]str
n = len(reversed)
}
var ran []string
var ran []Migration
for _, a := range reversed[:n] {
m, ok := findMigration(a, byName, byID)
d, ok := findScript(a, byName, byID)
if !ok {
return ran, fmt.Errorf("%w: %s", ErrMissingDown, a.Name)
}
if err := r.ExecDown(ctx, m); err != nil {
if err := r.ExecDown(ctx, a, d.Down); err != nil {
return ran, fmt.Errorf("%s (down): %w", a.Name, err)
}
ran = append(ran, a.Name)
ran = append(ran, a)
}
return ran, nil
}
// GetStatus returns applied and pending migration lists.
func GetStatus(ctx context.Context, r Migrator, migrations []Migration) (*Status, error) {
func GetStatus(ctx context.Context, r Migrator, ddls []Script) (*Status, error) {
applied, err := r.Applied(ctx)
if err != nil {
return nil, err
}
appliedNames := make([]string, len(applied))
for i, a := range applied {
appliedNames[i] = a.Name
}
var pending []string
for _, m := range migrations {
if !isApplied(m, applied) {
pending = append(pending, m.Name)
var pending []Migration
for _, d := range ddls {
if !isApplied(d, applied) {
pending = append(pending, d.Migration)
}
}
return &Status{
Applied: appliedNames,
Applied: applied,
Pending: pending,
}, nil
}
// Latest applies all pending migrations. Equivalent to Up(ctx, r, migrations, -1).
func Latest(ctx context.Context, r Migrator, migrations []Migration) ([]string, error) {
return Up(ctx, r, migrations, -1)
// Latest applies all pending migrations. Equivalent to Up(ctx, r, ddls, -1).
func Latest(ctx context.Context, r Migrator, ddls []Script) ([]Migration, error) {
return Up(ctx, r, ddls, -1)
}
// Drop rolls back all applied migrations. Equivalent to Down(ctx, r, migrations, -1).
func Drop(ctx context.Context, r Migrator, migrations []Migration) ([]string, error) {
return Down(ctx, r, migrations, -1)
// Drop rolls back all applied migrations. Equivalent to Down(ctx, r, ddls, -1).
func Drop(ctx context.Context, r Migrator, ddls []Script) ([]Migration, error) {
return Down(ctx, r, ddls, -1)
}

View File

@ -11,45 +11,54 @@ import (
"github.com/therootcompany/golib/database/sqlmigrate"
)
// applied builds an []AppliedMigration from names (IDs empty).
func applied(names ...string) []sqlmigrate.AppliedMigration {
out := make([]sqlmigrate.AppliedMigration, len(names))
// migs builds []Migration from names (IDs empty).
func migs(names ...string) []sqlmigrate.Migration {
out := make([]sqlmigrate.Migration, len(names))
for i, n := range names {
out[i] = sqlmigrate.AppliedMigration{Name: n}
out[i] = sqlmigrate.Migration{Name: n}
}
return out
}
// names extracts just the Name field from a slice of Migration.
func names(ms []sqlmigrate.Migration) []string {
out := make([]string, len(ms))
for i, m := range ms {
out[i] = m.Name
}
return out
}
// mockMigrator tracks applied migrations in memory.
type mockMigrator struct {
applied []sqlmigrate.AppliedMigration
applied []sqlmigrate.Migration
execErr error // if set, ExecUp/ExecDown return this on every call
upCalls []string
downCalls []string
}
func (m *mockMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error {
func (m *mockMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration, _ string) error {
m.upCalls = append(m.upCalls, mig.Name)
if m.execErr != nil {
return m.execErr
}
m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID})
slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int {
m.applied = append(m.applied, mig)
slices.SortFunc(m.applied, func(a, b sqlmigrate.Migration) int {
return strings.Compare(a.Name, b.Name)
})
return nil
}
func (m *mockMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error {
func (m *mockMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration, _ string) error {
m.downCalls = append(m.downCalls, mig.Name)
if m.execErr != nil {
return m.execErr
}
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name })
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.Migration) bool { return a.Name == mig.Name })
return nil
}
func (m *mockMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) {
func (m *mockMigrator) Applied(_ context.Context) ([]sqlmigrate.Migration, error) {
return slices.Clone(m.applied), nil
}
@ -63,24 +72,24 @@ func TestCollect(t *testing.T) {
"001_first.up.sql": {Data: []byte("CREATE TABLE a;")},
"001_first.down.sql": {Data: []byte("DROP TABLE a;")},
}
migrations, err := sqlmigrate.Collect(fsys, ".")
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if len(migrations) != 2 {
t.Fatalf("got %d migrations, want 2", len(migrations))
if len(ddls) != 2 {
t.Fatalf("got %d ddls, want 2", len(ddls))
}
if migrations[0].Name != "001_first" {
t.Errorf("first = %q, want %q", migrations[0].Name, "001_first")
if ddls[0].Name != "001_first" {
t.Errorf("first = %q, want %q", ddls[0].Name, "001_first")
}
if migrations[1].Name != "002_second" {
t.Errorf("second = %q, want %q", migrations[1].Name, "002_second")
if ddls[1].Name != "002_second" {
t.Errorf("second = %q, want %q", ddls[1].Name, "002_second")
}
if migrations[0].Up != "CREATE TABLE a;" {
t.Errorf("first.Up = %q", migrations[0].Up)
if ddls[0].Up != "CREATE TABLE a;" {
t.Errorf("first.Up = %q", ddls[0].Up)
}
if migrations[0].Down != "DROP TABLE a;" {
t.Errorf("first.Down = %q", migrations[0].Down)
if ddls[0].Down != "DROP TABLE a;" {
t.Errorf("first.Down = %q", ddls[0].Down)
}
})
@ -89,12 +98,12 @@ func TestCollect(t *testing.T) {
"001_init.up.sql": {Data: []byte("CREATE TABLE a;\nINSERT INTO _migrations (name, id) VALUES ('001_init', 'abcd1234');")},
"001_init.down.sql": {Data: []byte("DROP TABLE a;\nDELETE FROM _migrations WHERE id = 'abcd1234';")},
}
migrations, err := sqlmigrate.Collect(fsys, ".")
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if migrations[0].ID != "abcd1234" {
t.Errorf("ID = %q, want %q", migrations[0].ID, "abcd1234")
if ddls[0].ID != "abcd1234" {
t.Errorf("ID = %q, want %q", ddls[0].ID, "abcd1234")
}
})
@ -103,12 +112,12 @@ func TestCollect(t *testing.T) {
"001_init.up.sql": {Data: []byte("CREATE TABLE a;")},
"001_init.down.sql": {Data: []byte("DROP TABLE a;")},
}
migrations, err := sqlmigrate.Collect(fsys, ".")
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if migrations[0].ID != "" {
t.Errorf("ID = %q, want empty", migrations[0].ID)
if ddls[0].ID != "" {
t.Errorf("ID = %q, want empty", ddls[0].ID)
}
})
@ -139,23 +148,23 @@ func TestCollect(t *testing.T) {
"README.md": {Data: []byte("# Migrations")},
"_migrations.sql": {Data: []byte("SELECT name FROM _migrations;")},
}
migrations, err := sqlmigrate.Collect(fsys, ".")
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if len(migrations) != 1 {
t.Fatalf("got %d migrations, want 1", len(migrations))
if len(ddls) != 1 {
t.Fatalf("got %d ddls, want 1", len(ddls))
}
})
t.Run("empty fs", func(t *testing.T) {
fsys := fstest.MapFS{}
migrations, err := sqlmigrate.Collect(fsys, ".")
ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil {
t.Fatal(err)
}
if len(migrations) != 0 {
t.Fatalf("got %d migrations, want 0", len(migrations))
if len(ddls) != 0 {
t.Fatalf("got %d ddls, want 0", len(ddls))
}
})
}
@ -163,16 +172,16 @@ func TestCollect(t *testing.T) {
// --- NamesOnly ---
func TestNamesOnly(t *testing.T) {
names := []string{"001_init", "002_users"}
migrations := sqlmigrate.NamesOnly(names)
if len(migrations) != 2 {
t.Fatalf("got %d, want 2", len(migrations))
ns := []string{"001_init", "002_users"}
ddls := sqlmigrate.NamesOnly(ns)
if len(ddls) != 2 {
t.Fatalf("got %d, want 2", len(ddls))
}
for i, m := range migrations {
if m.Name != names[i] {
t.Errorf("[%d].Name = %q, want %q", i, m.Name, names[i])
for i, d := range ddls {
if d.Name != ns[i] {
t.Errorf("[%d].Name = %q, want %q", i, d.Name, ns[i])
}
if m.Up != "" || m.Down != "" {
if d.Up != "" || d.Down != "" {
t.Errorf("[%d] has non-empty content", i)
}
}
@ -182,26 +191,26 @@ func TestNamesOnly(t *testing.T) {
func TestUp(t *testing.T) {
ctx := t.Context()
migrations := []sqlmigrate.Migration{
{Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
{Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
ddls := []sqlmigrate.Script{
{Migration: sqlmigrate.Migration{Name: "001_init"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
{Migration: sqlmigrate.Migration{Name: "003_posts"}, Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
}
t.Run("apply all", func(t *testing.T) {
m := &mockMigrator{}
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ran, []string{"001_init", "002_users", "003_posts"}) {
t.Errorf("applied = %v", ran)
if !slices.Equal(names(ran), []string{"001_init", "002_users", "003_posts"}) {
t.Errorf("applied = %v", names(ran))
}
})
t.Run("n=0 is error", func(t *testing.T) {
m := &mockMigrator{}
_, err := sqlmigrate.Up(ctx, m, migrations, 0)
_, err := sqlmigrate.Up(ctx, m, ddls, 0)
if !errors.Is(err, sqlmigrate.ErrInvalidN) {
t.Errorf("got %v, want ErrInvalidN", err)
}
@ -209,18 +218,18 @@ func TestUp(t *testing.T) {
t.Run("apply n", func(t *testing.T) {
m := &mockMigrator{}
ran, err := sqlmigrate.Up(ctx, m, migrations, 2)
ran, err := sqlmigrate.Up(ctx, m, ddls, 2)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ran, []string{"001_init", "002_users"}) {
t.Errorf("applied = %v", ran)
if !slices.Equal(names(ran), []string{"001_init", "002_users"}) {
t.Errorf("applied = %v", names(ran))
}
})
t.Run("none pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err != nil {
t.Fatal(err)
}
@ -230,19 +239,19 @@ func TestUp(t *testing.T) {
})
t.Run("partial pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")}
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
m := &mockMigrator{applied: migs("001_init")}
ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ran, []string{"002_users", "003_posts"}) {
t.Errorf("applied = %v", ran)
if !slices.Equal(names(ran), []string{"002_users", "003_posts"}) {
t.Errorf("applied = %v", names(ran))
}
})
t.Run("n exceeds pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")}
ran, err := sqlmigrate.Up(ctx, m, migrations, 99)
m := &mockMigrator{applied: migs("001_init")}
ran, err := sqlmigrate.Up(ctx, m, ddls, 99)
if err != nil {
t.Fatal(err)
}
@ -253,7 +262,7 @@ func TestUp(t *testing.T) {
t.Run("exec error stops and returns partial", func(t *testing.T) {
m := &failOnNthMigrator{failAt: 1}
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err == nil {
t.Fatal("expected error")
}
@ -264,50 +273,50 @@ func TestUp(t *testing.T) {
t.Run("skips migration matched by ID", func(t *testing.T) {
// DB has migration applied under old name, but same ID
m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{
m := &mockMigrator{applied: []sqlmigrate.Migration{
{Name: "001_old-name", ID: "aa11bb22"},
}}
migs := []sqlmigrate.Migration{
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
idDDLs := []sqlmigrate.Script{
{Migration: sqlmigrate.Migration{Name: "001_new-name", ID: "aa11bb22"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
}
ran, err := sqlmigrate.Up(ctx, m, migs, -1)
ran, err := sqlmigrate.Up(ctx, m, idDDLs, -1)
if err != nil {
t.Fatal(err)
}
// Only 002_users should be applied; 001 is matched by ID
if !slices.Equal(ran, []string{"002_users"}) {
t.Errorf("applied = %v, want [002_users]", ran)
if !slices.Equal(names(ran), []string{"002_users"}) {
t.Errorf("applied = %v, want [002_users]", names(ran))
}
})
}
// failOnNthMigrator fails on the Nth ExecUp call (0-indexed).
type failOnNthMigrator struct {
applied []sqlmigrate.AppliedMigration
applied []sqlmigrate.Migration
calls int
failAt int
}
func (m *failOnNthMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error {
func (m *failOnNthMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration, _ string) error {
if m.calls == m.failAt {
m.calls++
return errors.New("connection lost")
}
m.calls++
m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID})
slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int {
m.applied = append(m.applied, mig)
slices.SortFunc(m.applied, func(a, b sqlmigrate.Migration) int {
return strings.Compare(a.Name, b.Name)
})
return nil
}
func (m *failOnNthMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error {
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name })
func (m *failOnNthMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration, _ string) error {
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.Migration) bool { return a.Name == mig.Name })
return nil
}
func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) {
func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.Migration, error) {
return slices.Clone(m.applied), nil
}
@ -315,51 +324,51 @@ func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigr
func TestDown(t *testing.T) {
ctx := t.Context()
migrations := []sqlmigrate.Migration{
{Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
{Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
ddls := []sqlmigrate.Script{
{Migration: sqlmigrate.Migration{Name: "001_init"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
{Migration: sqlmigrate.Migration{Name: "003_posts"}, Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
}
t.Run("rollback all", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, -1)
m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, ddls, -1)
if err != nil {
t.Fatal(err)
}
if len(rolled) != 3 {
t.Fatalf("rolled %d, want 3", len(rolled))
}
if rolled[0] != "003_posts" {
t.Errorf("first rollback = %q, want 003_posts", rolled[0])
if rolled[0].Name != "003_posts" {
t.Errorf("first rollback = %q, want 003_posts", rolled[0].Name)
}
if rolled[2] != "001_init" {
t.Errorf("last rollback = %q, want 001_init", rolled[2])
if rolled[2].Name != "001_init" {
t.Errorf("last rollback = %q, want 001_init", rolled[2].Name)
}
})
t.Run("n=0 is error", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users")}
_, err := sqlmigrate.Down(ctx, m, migrations, 0)
m := &mockMigrator{applied: migs("001_init", "002_users")}
_, err := sqlmigrate.Down(ctx, m, ddls, 0)
if !errors.Is(err, sqlmigrate.ErrInvalidN) {
t.Errorf("got %v, want ErrInvalidN", err)
}
})
t.Run("rollback n", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 2)
m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, ddls, 2)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(rolled, []string{"003_posts", "002_users"}) {
t.Errorf("rolled = %v", rolled)
if !slices.Equal(names(rolled), []string{"003_posts", "002_users"}) {
t.Errorf("rolled = %v", names(rolled))
}
})
t.Run("none applied", func(t *testing.T) {
m := &mockMigrator{}
rolled, err := sqlmigrate.Down(ctx, m, migrations, -1)
rolled, err := sqlmigrate.Down(ctx, m, ddls, -1)
if err != nil {
t.Fatal(err)
}
@ -370,10 +379,10 @@ func TestDown(t *testing.T) {
t.Run("exec error", func(t *testing.T) {
m := &mockMigrator{
applied: applied("001_init", "002_users"),
applied: migs("001_init", "002_users"),
execErr: errors.New("permission denied"),
}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 1)
rolled, err := sqlmigrate.Down(ctx, m, ddls, 1)
if err == nil {
t.Fatal("expected error")
}
@ -383,16 +392,16 @@ func TestDown(t *testing.T) {
})
t.Run("unknown migration in applied", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "999_unknown")}
_, err := sqlmigrate.Down(ctx, m, migrations, -1)
m := &mockMigrator{applied: migs("001_init", "999_unknown")}
_, err := sqlmigrate.Down(ctx, m, ddls, -1)
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
t.Errorf("got %v, want ErrMissingDown", err)
}
})
t.Run("n exceeds applied", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 99)
m := &mockMigrator{applied: migs("001_init")}
rolled, err := sqlmigrate.Down(ctx, m, ddls, 99)
if err != nil {
t.Fatal(err)
}
@ -403,18 +412,18 @@ func TestDown(t *testing.T) {
t.Run("finds migration by ID when name changed", func(t *testing.T) {
// DB has old name, file has new name, same ID
m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{
m := &mockMigrator{applied: []sqlmigrate.Migration{
{Name: "001_old-name", ID: "aa11bb22"},
}}
migs := []sqlmigrate.Migration{
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
idDDLs := []sqlmigrate.Script{
{Migration: sqlmigrate.Migration{Name: "001_new-name", ID: "aa11bb22"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
}
rolled, err := sqlmigrate.Down(ctx, m, migs, 1)
rolled, err := sqlmigrate.Down(ctx, m, idDDLs, 1)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(rolled, []string{"001_old-name"}) {
t.Errorf("rolled = %v, want [001_old-name]", rolled)
if !slices.Equal(names(rolled), []string{"001_old-name"}) {
t.Errorf("rolled = %v, want [001_old-name]", names(rolled))
}
})
}
@ -423,15 +432,15 @@ func TestDown(t *testing.T) {
func TestGetStatus(t *testing.T) {
ctx := t.Context()
migrations := []sqlmigrate.Migration{
{Name: "001_init"},
{Name: "002_users"},
{Name: "003_posts"},
ddls := []sqlmigrate.Script{
{Migration: sqlmigrate.Migration{Name: "001_init"}},
{Migration: sqlmigrate.Migration{Name: "002_users"}},
{Migration: sqlmigrate.Migration{Name: "003_posts"}},
}
t.Run("all pending", func(t *testing.T) {
m := &mockMigrator{}
status, err := sqlmigrate.GetStatus(ctx, m, migrations)
status, err := sqlmigrate.GetStatus(ctx, m, ddls)
if err != nil {
t.Fatal(err)
}
@ -444,22 +453,22 @@ func TestGetStatus(t *testing.T) {
})
t.Run("partial", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")}
status, err := sqlmigrate.GetStatus(ctx, m, migrations)
m := &mockMigrator{applied: migs("001_init")}
status, err := sqlmigrate.GetStatus(ctx, m, ddls)
if err != nil {
t.Fatal(err)
}
if len(status.Applied) != 1 || status.Applied[0] != "001_init" {
if len(status.Applied) != 1 || status.Applied[0].Name != "001_init" {
t.Errorf("applied = %v", status.Applied)
}
if !slices.Equal(status.Pending, []string{"002_users", "003_posts"}) {
t.Errorf("pending = %v", status.Pending)
if !slices.Equal(names(status.Pending), []string{"002_users", "003_posts"}) {
t.Errorf("pending = %v", names(status.Pending))
}
})
t.Run("all applied", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
status, err := sqlmigrate.GetStatus(ctx, m, migrations)
m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
status, err := sqlmigrate.GetStatus(ctx, m, ddls)
if err != nil {
t.Fatal(err)
}