diff --git a/cmd/sql-migrate/go.mod b/cmd/sql-migrate/go.mod index dce9864..36f9ead 100644 --- a/cmd/sql-migrate/go.mod +++ b/cmd/sql-migrate/go.mod @@ -3,10 +3,17 @@ module github.com/therootcompany/golib/cmd/sql-migrate/v2 go 1.26.1 require ( - github.com/therootcompany/golib/database/sqlmigrate v0.0.0 + github.com/jackc/pgx/v5 v5.9.1 + github.com/therootcompany/golib/database/sqlmigrate v1.0.0 github.com/therootcompany/golib/database/sqlmigrate/shmigrate v0.0.0 ) +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + golang.org/x/text v0.29.0 // indirect +) + replace ( github.com/therootcompany/golib/database/sqlmigrate => ../../database/sqlmigrate github.com/therootcompany/golib/database/sqlmigrate/shmigrate => ../../database/sqlmigrate/shmigrate diff --git a/cmd/sql-migrate/go.sum b/cmd/sql-migrate/go.sum new file mode 100644 index 0000000..8e29ab9 --- /dev/null +++ b/cmd/sql-migrate/go.sum @@ -0,0 +1,26 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= +github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/cmd/sql-migrate/main.go b/cmd/sql-migrate/main.go index 3da9973..54ff2d6 100644 --- a/cmd/sql-migrate/main.go +++ b/cmd/sql-migrate/main.go @@ -939,7 +939,7 @@ func syncLog(runner *shmigrate.Migrator) { fmt.Printf("cat %s\n", logPath) } -func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error { +func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script, n int) error { // fixup pending migrations before generating the script fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations) @@ -974,7 +974,7 @@ func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrat return nil } -func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error { +func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script, n int) error { // fixup applied migrations before generating the script fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations) @@ -998,7 +998,7 @@ func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migr fmt.Println("") // check for missing down files before generating script - reversed := make([]string, len(status.Applied)) + reversed := make([]sqlmigrate.Migration, len(status.Applied)) copy(reversed, status.Applied) slices.Reverse(reversed) limit := n @@ -1008,8 +1008,8 @@ func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migr if limit > len(reversed) { limit = len(reversed) } - for _, name := range reversed[:limit] { - downPath := filepath.Join(state.MigrationsDir, name+".down.sql") + for _, a := range reversed[:limit] { + downPath := filepath.Join(state.MigrationsDir, a.Name+".down.sql") if !fileExists(downPath) { fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath)) fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n") @@ -1028,7 +1028,7 @@ func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migr return nil } -func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration) error { +func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script) error { status, err := sqlmigrate.GetStatus(ctx, runner, migrations) if err != nil { return err @@ -1040,21 +1040,21 @@ func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, mi fmt.Fprintf(os.Stderr, "\n") // show applied in reverse (most recent first) - applied := make([]string, len(status.Applied)) - copy(applied, status.Applied) - slices.Reverse(applied) + appliedList := make([]sqlmigrate.Migration, len(status.Applied)) + copy(appliedList, status.Applied) + slices.Reverse(appliedList) - fmt.Printf("# previous: %d\n", len(applied)) - for _, mig := range applied { - fmt.Printf(" %s\n", mig) + fmt.Printf("# previous: %d\n", len(appliedList)) + for _, mig := range appliedList { + fmt.Printf(" %s\n", mig.Name) } - if len(applied) == 0 { + if len(appliedList) == 0 { fmt.Println(" # (no previous migrations)") } fmt.Println("") fmt.Printf("# pending: %d\n", len(status.Pending)) for _, mig := range status.Pending { - fmt.Printf(" %s\n", mig) + fmt.Printf(" %s\n", mig.Name) } if len(status.Pending) == 0 { fmt.Println(" # (no pending migrations)") @@ -1063,7 +1063,7 @@ func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, mi } // fixupAll runs fixupMigration on all known migrations (applied + pending). -func fixupAll(migrationsDir string, applied []string, migrations []sqlmigrate.Migration) (fixedUp, fixedDown []string) { +func fixupAll(migrationsDir string, applied []string, migrations []sqlmigrate.Script) (fixedUp, fixedDown []string) { seen := map[string]bool{} var all []string for _, name := range applied { diff --git a/database/sqlmigrate/shmigrate/go.mod b/database/sqlmigrate/shmigrate/go.mod index ac35638..6726e6a 100644 --- a/database/sqlmigrate/shmigrate/go.mod +++ b/database/sqlmigrate/shmigrate/go.mod @@ -3,3 +3,5 @@ module github.com/therootcompany/golib/database/sqlmigrate/shmigrate go 1.26.1 require github.com/therootcompany/golib/database/sqlmigrate v1.0.0 + +replace github.com/therootcompany/golib/database/sqlmigrate => ../ diff --git a/database/sqlmigrate/shmigrate/shmigrate.go b/database/sqlmigrate/shmigrate/shmigrate.go index 54d6ab3..8dabc4d 100644 --- a/database/sqlmigrate/shmigrate/shmigrate.go +++ b/database/sqlmigrate/shmigrate/shmigrate.go @@ -59,13 +59,15 @@ type Migrator struct { var _ sqlmigrate.Migrator = (*Migrator)(nil) // ExecUp outputs a shell command to run the .up.sql migration file. -func (r *Migrator) ExecUp(ctx context.Context, m sqlmigrate.Migration) error { +// The sql parameter is ignored — shmigrate references files on disk. +func (r *Migrator) ExecUp(ctx context.Context, m sqlmigrate.Migration, sql string) error { r.counter++ return r.exec(m.Name, ".up.sql", fmt.Sprintf("+%d", r.counter)) } // ExecDown outputs a shell command to run the .down.sql migration file. -func (r *Migrator) ExecDown(ctx context.Context, m sqlmigrate.Migration) error { +// The sql parameter is ignored — shmigrate references files on disk. +func (r *Migrator) ExecDown(ctx context.Context, m sqlmigrate.Migration, sql string) error { r.counter++ return r.exec(m.Name, ".down.sql", fmt.Sprintf("-%d", r.counter)) } @@ -95,7 +97,7 @@ func (r *Migrator) exec(name, suffix, label string) error { // // Returns an empty slice if the file does not exist. When FS is set, reads // from that filesystem; otherwise reads from the OS filesystem. -func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration, error) { +func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.Migration, error) { var f io.ReadCloser var err error if r.FS != nil { @@ -111,7 +113,7 @@ func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration, } defer f.Close() - var applied []sqlmigrate.AppliedMigration + var applied []sqlmigrate.Migration scanner := bufio.NewScanner(f) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) @@ -123,9 +125,9 @@ func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration, continue } if id, name, ok := strings.Cut(line, "\t"); ok { - applied = append(applied, sqlmigrate.AppliedMigration{ID: id, Name: name}) + applied = append(applied, sqlmigrate.Migration{ID: id, Name: name}) } else { - applied = append(applied, sqlmigrate.AppliedMigration{Name: line}) + applied = append(applied, sqlmigrate.Migration{Name: line}) } } diff --git a/database/sqlmigrate/sqlmigrate.go b/database/sqlmigrate/sqlmigrate.go index 104e739..58ff60d 100644 --- a/database/sqlmigrate/sqlmigrate.go +++ b/database/sqlmigrate/sqlmigrate.go @@ -28,41 +28,40 @@ var ( ErrInvalidN = errors.New("n must be positive or -1 for all") ) -// Migration represents a paired up/down migration. +// Migration identifies a migration by its name and optional hex ID. type Migration struct { - Name string // e.g. "2026-04-05-001000_create-todos" ID string // 8-char hex from INSERT INTO _migrations, parsed by Collect + Name string // e.g. "2026-04-05-001000_create-todos" +} + +// Script is a Migration with its up and down SQL content, as returned by Collect. +type Script struct { + Migration Up string // SQL content of the .up.sql file Down string // SQL content of the .down.sql file } -// AppliedMigration represents a migration recorded in the _migrations table. -type AppliedMigration struct { - ID string - Name string -} - // Status represents the current migration state. type Status struct { - Applied []string - Pending []string + Applied []Migration + Pending []Migration } // Migrator executes migrations. Implementations handle the // database-specific or output-specific details. type Migrator interface { - // ExecUp runs the up migration. For database migrators this executes - // m.Up in a transaction. For shell migrators this outputs a command - // referencing the .up.sql file. - ExecUp(ctx context.Context, m Migration) error + // ExecUp runs the up migration SQL. For database migrators this + // executes the SQL in a transaction. For shell migrators this + // outputs a command referencing the .up.sql file. + ExecUp(ctx context.Context, m Migration, sql string) error - // ExecDown runs the down migration. - ExecDown(ctx context.Context, m Migration) error + // ExecDown runs the down migration SQL. + ExecDown(ctx context.Context, m Migration, sql string) error // Applied returns all applied migrations from the _migrations table, // sorted lexicographically by name. Returns an empty slice (not an // error) if the migrations table or log does not exist yet. - Applied(ctx context.Context) ([]AppliedMigration, error) + Applied(ctx context.Context) ([]Migration, error) } // idFromInsert extracts the hex ID from an INSERT INTO _migrations line. @@ -75,8 +74,8 @@ var idFromInsert = regexp.MustCompile( // pairs them by basename, and returns them sorted lexicographically by name. // If subpath is "" or ".", the root of fsys is used. // If the up SQL contains an INSERT INTO _migrations line, the hex ID -// is extracted and stored in Migration.ID. -func Collect(fsys fs.FS, subpath string) ([]Migration, error) { +// is extracted and stored in Script.ID. +func Collect(fsys fs.FS, subpath string) ([]Script, error) { if subpath != "" && subpath != "." { var err error fsys, err = fs.Sub(fsys, subpath) @@ -120,7 +119,7 @@ func Collect(fsys fs.FS, subpath string) ([]Migration, error) { return nil, fmt.Errorf("%w: %w", ErrWalkFailed, err) } - var migrations []Migration + var ddls []Script for name, upSQL := range ups { downSQL, ok := downs[name] if !ok { @@ -130,11 +129,10 @@ func Collect(fsys fs.FS, subpath string) ([]Migration, error) { if m := idFromInsert.FindStringSubmatch(upSQL); m != nil { id = m[1] } - migrations = append(migrations, Migration{ - Name: name, - ID: id, - Up: upSQL, - Down: downSQL, + ddls = append(ddls, Script{ + Migration: Migration{ID: id, Name: name}, + Up: upSQL, + Down: downSQL, }) } for name := range downs { @@ -143,54 +141,54 @@ func Collect(fsys fs.FS, subpath string) ([]Migration, error) { } } - slices.SortFunc(migrations, func(a, b Migration) int { + slices.SortFunc(ddls, func(a, b Script) int { return strings.Compare(a.Name, b.Name) }) - return migrations, nil + return ddls, nil } -// NamesOnly builds a Migration slice from a list of names, with empty +// NamesOnly builds a Script slice from a list of names, with empty // Up/Down content. Useful for shell-based runners that reference files // on disk rather than executing SQL directly. -func NamesOnly(names []string) []Migration { - migrations := make([]Migration, len(names)) +func NamesOnly(names []string) []Script { + ddls := make([]Script, len(names)) for i, name := range names { - migrations[i] = Migration{Name: name} + ddls[i] = Script{Migration: Migration{Name: name}} } - return migrations + return ddls } -// isApplied returns true if the migration matches any applied entry by name or ID. -func isApplied(m Migration, applied []AppliedMigration) bool { +// isApplied returns true if the Script matches any applied entry by name or ID. +func isApplied(d Script, applied []Migration) bool { for _, a := range applied { - if a.Name == m.Name { + if a.Name == d.Name { return true } - if m.ID != "" && a.ID != "" && a.ID == m.ID { + if d.ID != "" && a.ID != "" && a.ID == d.ID { return true } } return false } -// findMigration looks up a migration by the applied entry's name or ID. -func findMigration(a AppliedMigration, byName map[string]Migration, byID map[string]Migration) (Migration, bool) { - if m, ok := byName[a.Name]; ok { - return m, true +// findScript looks up a Script by the applied entry's name or ID. +func findScript(a Migration, byName map[string]Script, byID map[string]Script) (Script, bool) { + if d, ok := byName[a.Name]; ok { + return d, true } if a.ID != "" { - if m, ok := byID[a.ID]; ok { - return m, true + if d, ok := byID[a.ID]; ok { + return d, true } } - return Migration{}, false + return Script{}, false } -// Up applies up to n pending migrations using the given Runner. +// Up applies up to n pending migrations using the given Migrator. // If n < 0, applies all pending. If n == 0, returns ErrInvalidN. -// Returns the names of applied migrations. -func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) { +// Returns the applied migrations. +func Up(ctx context.Context, r Migrator, ddls []Script, n int) ([]Migration, error) { if n == 0 { return nil, ErrInvalidN } @@ -200,10 +198,10 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin return nil, err } - var pending []Migration - for _, m := range migrations { - if !isApplied(m, applied) { - pending = append(pending, m) + var pending []Script + for _, d := range ddls { + if !isApplied(d, applied) { + pending = append(pending, d) } } @@ -214,12 +212,12 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin n = len(pending) } - var ran []string - for _, m := range pending[:n] { - if err := r.ExecUp(ctx, m); err != nil { - return ran, fmt.Errorf("%s (up): %w", m.Name, err) + var ran []Migration + for _, d := range pending[:n] { + if err := r.ExecUp(ctx, d.Migration, d.Up); err != nil { + return ran, fmt.Errorf("%s (up): %w", d.Name, err) } - ran = append(ran, m.Name) + ran = append(ran, d.Migration) } return ran, nil @@ -227,8 +225,8 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin // Down rolls back up to n applied migrations, most recent first. // If n < 0, rolls back all applied. If n == 0, returns ErrInvalidN. -// Returns the names of rolled-back migrations. -func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) { +// Returns the rolled-back migrations. +func Down(ctx context.Context, r Migrator, ddls []Script, n int) ([]Migration, error) { if n == 0 { return nil, ErrInvalidN } @@ -238,16 +236,16 @@ func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]str return nil, err } - byName := map[string]Migration{} - byID := map[string]Migration{} - for _, m := range migrations { - byName[m.Name] = m - if m.ID != "" { - byID[m.ID] = m + byName := map[string]Script{} + byID := map[string]Script{} + for _, d := range ddls { + byName[d.Name] = d + if d.ID != "" { + byID[d.ID] = d } } - reversed := make([]AppliedMigration, len(applied)) + reversed := make([]Migration, len(applied)) copy(reversed, applied) slices.Reverse(reversed) @@ -255,52 +253,47 @@ func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]str n = len(reversed) } - var ran []string + var ran []Migration for _, a := range reversed[:n] { - m, ok := findMigration(a, byName, byID) + d, ok := findScript(a, byName, byID) if !ok { return ran, fmt.Errorf("%w: %s", ErrMissingDown, a.Name) } - if err := r.ExecDown(ctx, m); err != nil { + if err := r.ExecDown(ctx, a, d.Down); err != nil { return ran, fmt.Errorf("%s (down): %w", a.Name, err) } - ran = append(ran, a.Name) + ran = append(ran, a) } return ran, nil } // GetStatus returns applied and pending migration lists. -func GetStatus(ctx context.Context, r Migrator, migrations []Migration) (*Status, error) { +func GetStatus(ctx context.Context, r Migrator, ddls []Script) (*Status, error) { applied, err := r.Applied(ctx) if err != nil { return nil, err } - appliedNames := make([]string, len(applied)) - for i, a := range applied { - appliedNames[i] = a.Name - } - - var pending []string - for _, m := range migrations { - if !isApplied(m, applied) { - pending = append(pending, m.Name) + var pending []Migration + for _, d := range ddls { + if !isApplied(d, applied) { + pending = append(pending, d.Migration) } } return &Status{ - Applied: appliedNames, + Applied: applied, Pending: pending, }, nil } -// Latest applies all pending migrations. Equivalent to Up(ctx, r, migrations, -1). -func Latest(ctx context.Context, r Migrator, migrations []Migration) ([]string, error) { - return Up(ctx, r, migrations, -1) +// Latest applies all pending migrations. Equivalent to Up(ctx, r, ddls, -1). +func Latest(ctx context.Context, r Migrator, ddls []Script) ([]Migration, error) { + return Up(ctx, r, ddls, -1) } -// Drop rolls back all applied migrations. Equivalent to Down(ctx, r, migrations, -1). -func Drop(ctx context.Context, r Migrator, migrations []Migration) ([]string, error) { - return Down(ctx, r, migrations, -1) +// Drop rolls back all applied migrations. Equivalent to Down(ctx, r, ddls, -1). +func Drop(ctx context.Context, r Migrator, ddls []Script) ([]Migration, error) { + return Down(ctx, r, ddls, -1) } diff --git a/database/sqlmigrate/sqlmigrate_test.go b/database/sqlmigrate/sqlmigrate_test.go index 8348e7c..c770515 100644 --- a/database/sqlmigrate/sqlmigrate_test.go +++ b/database/sqlmigrate/sqlmigrate_test.go @@ -11,45 +11,54 @@ import ( "github.com/therootcompany/golib/database/sqlmigrate" ) -// applied builds an []AppliedMigration from names (IDs empty). -func applied(names ...string) []sqlmigrate.AppliedMigration { - out := make([]sqlmigrate.AppliedMigration, len(names)) +// migs builds []Migration from names (IDs empty). +func migs(names ...string) []sqlmigrate.Migration { + out := make([]sqlmigrate.Migration, len(names)) for i, n := range names { - out[i] = sqlmigrate.AppliedMigration{Name: n} + out[i] = sqlmigrate.Migration{Name: n} + } + return out +} + +// names extracts just the Name field from a slice of Migration. +func names(ms []sqlmigrate.Migration) []string { + out := make([]string, len(ms)) + for i, m := range ms { + out[i] = m.Name } return out } // mockMigrator tracks applied migrations in memory. type mockMigrator struct { - applied []sqlmigrate.AppliedMigration + applied []sqlmigrate.Migration execErr error // if set, ExecUp/ExecDown return this on every call upCalls []string downCalls []string } -func (m *mockMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error { +func (m *mockMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration, _ string) error { m.upCalls = append(m.upCalls, mig.Name) if m.execErr != nil { return m.execErr } - m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID}) - slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int { + m.applied = append(m.applied, mig) + slices.SortFunc(m.applied, func(a, b sqlmigrate.Migration) int { return strings.Compare(a.Name, b.Name) }) return nil } -func (m *mockMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error { +func (m *mockMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration, _ string) error { m.downCalls = append(m.downCalls, mig.Name) if m.execErr != nil { return m.execErr } - m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name }) + m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.Migration) bool { return a.Name == mig.Name }) return nil } -func (m *mockMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) { +func (m *mockMigrator) Applied(_ context.Context) ([]sqlmigrate.Migration, error) { return slices.Clone(m.applied), nil } @@ -63,24 +72,24 @@ func TestCollect(t *testing.T) { "001_first.up.sql": {Data: []byte("CREATE TABLE a;")}, "001_first.down.sql": {Data: []byte("DROP TABLE a;")}, } - migrations, err := sqlmigrate.Collect(fsys, ".") + ddls, err := sqlmigrate.Collect(fsys, ".") if err != nil { t.Fatal(err) } - if len(migrations) != 2 { - t.Fatalf("got %d migrations, want 2", len(migrations)) + if len(ddls) != 2 { + t.Fatalf("got %d ddls, want 2", len(ddls)) } - if migrations[0].Name != "001_first" { - t.Errorf("first = %q, want %q", migrations[0].Name, "001_first") + if ddls[0].Name != "001_first" { + t.Errorf("first = %q, want %q", ddls[0].Name, "001_first") } - if migrations[1].Name != "002_second" { - t.Errorf("second = %q, want %q", migrations[1].Name, "002_second") + if ddls[1].Name != "002_second" { + t.Errorf("second = %q, want %q", ddls[1].Name, "002_second") } - if migrations[0].Up != "CREATE TABLE a;" { - t.Errorf("first.Up = %q", migrations[0].Up) + if ddls[0].Up != "CREATE TABLE a;" { + t.Errorf("first.Up = %q", ddls[0].Up) } - if migrations[0].Down != "DROP TABLE a;" { - t.Errorf("first.Down = %q", migrations[0].Down) + if ddls[0].Down != "DROP TABLE a;" { + t.Errorf("first.Down = %q", ddls[0].Down) } }) @@ -89,12 +98,12 @@ func TestCollect(t *testing.T) { "001_init.up.sql": {Data: []byte("CREATE TABLE a;\nINSERT INTO _migrations (name, id) VALUES ('001_init', 'abcd1234');")}, "001_init.down.sql": {Data: []byte("DROP TABLE a;\nDELETE FROM _migrations WHERE id = 'abcd1234';")}, } - migrations, err := sqlmigrate.Collect(fsys, ".") + ddls, err := sqlmigrate.Collect(fsys, ".") if err != nil { t.Fatal(err) } - if migrations[0].ID != "abcd1234" { - t.Errorf("ID = %q, want %q", migrations[0].ID, "abcd1234") + if ddls[0].ID != "abcd1234" { + t.Errorf("ID = %q, want %q", ddls[0].ID, "abcd1234") } }) @@ -103,12 +112,12 @@ func TestCollect(t *testing.T) { "001_init.up.sql": {Data: []byte("CREATE TABLE a;")}, "001_init.down.sql": {Data: []byte("DROP TABLE a;")}, } - migrations, err := sqlmigrate.Collect(fsys, ".") + ddls, err := sqlmigrate.Collect(fsys, ".") if err != nil { t.Fatal(err) } - if migrations[0].ID != "" { - t.Errorf("ID = %q, want empty", migrations[0].ID) + if ddls[0].ID != "" { + t.Errorf("ID = %q, want empty", ddls[0].ID) } }) @@ -139,23 +148,23 @@ func TestCollect(t *testing.T) { "README.md": {Data: []byte("# Migrations")}, "_migrations.sql": {Data: []byte("SELECT name FROM _migrations;")}, } - migrations, err := sqlmigrate.Collect(fsys, ".") + ddls, err := sqlmigrate.Collect(fsys, ".") if err != nil { t.Fatal(err) } - if len(migrations) != 1 { - t.Fatalf("got %d migrations, want 1", len(migrations)) + if len(ddls) != 1 { + t.Fatalf("got %d ddls, want 1", len(ddls)) } }) t.Run("empty fs", func(t *testing.T) { fsys := fstest.MapFS{} - migrations, err := sqlmigrate.Collect(fsys, ".") + ddls, err := sqlmigrate.Collect(fsys, ".") if err != nil { t.Fatal(err) } - if len(migrations) != 0 { - t.Fatalf("got %d migrations, want 0", len(migrations)) + if len(ddls) != 0 { + t.Fatalf("got %d ddls, want 0", len(ddls)) } }) } @@ -163,16 +172,16 @@ func TestCollect(t *testing.T) { // --- NamesOnly --- func TestNamesOnly(t *testing.T) { - names := []string{"001_init", "002_users"} - migrations := sqlmigrate.NamesOnly(names) - if len(migrations) != 2 { - t.Fatalf("got %d, want 2", len(migrations)) + ns := []string{"001_init", "002_users"} + ddls := sqlmigrate.NamesOnly(ns) + if len(ddls) != 2 { + t.Fatalf("got %d, want 2", len(ddls)) } - for i, m := range migrations { - if m.Name != names[i] { - t.Errorf("[%d].Name = %q, want %q", i, m.Name, names[i]) + for i, d := range ddls { + if d.Name != ns[i] { + t.Errorf("[%d].Name = %q, want %q", i, d.Name, ns[i]) } - if m.Up != "" || m.Down != "" { + if d.Up != "" || d.Down != "" { t.Errorf("[%d] has non-empty content", i) } } @@ -182,26 +191,26 @@ func TestNamesOnly(t *testing.T) { func TestUp(t *testing.T) { ctx := t.Context() - migrations := []sqlmigrate.Migration{ - {Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, - {Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, - {Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"}, + ddls := []sqlmigrate.Script{ + {Migration: sqlmigrate.Migration{Name: "001_init"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, + {Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, + {Migration: sqlmigrate.Migration{Name: "003_posts"}, Up: "CREATE TABLE c;", Down: "DROP TABLE c;"}, } t.Run("apply all", func(t *testing.T) { m := &mockMigrator{} - ran, err := sqlmigrate.Up(ctx, m, migrations, -1) + ran, err := sqlmigrate.Up(ctx, m, ddls, -1) if err != nil { t.Fatal(err) } - if !slices.Equal(ran, []string{"001_init", "002_users", "003_posts"}) { - t.Errorf("applied = %v", ran) + if !slices.Equal(names(ran), []string{"001_init", "002_users", "003_posts"}) { + t.Errorf("applied = %v", names(ran)) } }) t.Run("n=0 is error", func(t *testing.T) { m := &mockMigrator{} - _, err := sqlmigrate.Up(ctx, m, migrations, 0) + _, err := sqlmigrate.Up(ctx, m, ddls, 0) if !errors.Is(err, sqlmigrate.ErrInvalidN) { t.Errorf("got %v, want ErrInvalidN", err) } @@ -209,18 +218,18 @@ func TestUp(t *testing.T) { t.Run("apply n", func(t *testing.T) { m := &mockMigrator{} - ran, err := sqlmigrate.Up(ctx, m, migrations, 2) + ran, err := sqlmigrate.Up(ctx, m, ddls, 2) if err != nil { t.Fatal(err) } - if !slices.Equal(ran, []string{"001_init", "002_users"}) { - t.Errorf("applied = %v", ran) + if !slices.Equal(names(ran), []string{"001_init", "002_users"}) { + t.Errorf("applied = %v", names(ran)) } }) t.Run("none pending", func(t *testing.T) { - m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} - ran, err := sqlmigrate.Up(ctx, m, migrations, -1) + m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")} + ran, err := sqlmigrate.Up(ctx, m, ddls, -1) if err != nil { t.Fatal(err) } @@ -230,19 +239,19 @@ func TestUp(t *testing.T) { }) t.Run("partial pending", func(t *testing.T) { - m := &mockMigrator{applied: applied("001_init")} - ran, err := sqlmigrate.Up(ctx, m, migrations, -1) + m := &mockMigrator{applied: migs("001_init")} + ran, err := sqlmigrate.Up(ctx, m, ddls, -1) if err != nil { t.Fatal(err) } - if !slices.Equal(ran, []string{"002_users", "003_posts"}) { - t.Errorf("applied = %v", ran) + if !slices.Equal(names(ran), []string{"002_users", "003_posts"}) { + t.Errorf("applied = %v", names(ran)) } }) t.Run("n exceeds pending", func(t *testing.T) { - m := &mockMigrator{applied: applied("001_init")} - ran, err := sqlmigrate.Up(ctx, m, migrations, 99) + m := &mockMigrator{applied: migs("001_init")} + ran, err := sqlmigrate.Up(ctx, m, ddls, 99) if err != nil { t.Fatal(err) } @@ -253,7 +262,7 @@ func TestUp(t *testing.T) { t.Run("exec error stops and returns partial", func(t *testing.T) { m := &failOnNthMigrator{failAt: 1} - ran, err := sqlmigrate.Up(ctx, m, migrations, -1) + ran, err := sqlmigrate.Up(ctx, m, ddls, -1) if err == nil { t.Fatal("expected error") } @@ -264,50 +273,50 @@ func TestUp(t *testing.T) { t.Run("skips migration matched by ID", func(t *testing.T) { // DB has migration applied under old name, but same ID - m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{ + m := &mockMigrator{applied: []sqlmigrate.Migration{ {Name: "001_old-name", ID: "aa11bb22"}, }} - migs := []sqlmigrate.Migration{ - {Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, - {Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, + idDDLs := []sqlmigrate.Script{ + {Migration: sqlmigrate.Migration{Name: "001_new-name", ID: "aa11bb22"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, + {Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, } - ran, err := sqlmigrate.Up(ctx, m, migs, -1) + ran, err := sqlmigrate.Up(ctx, m, idDDLs, -1) if err != nil { t.Fatal(err) } // Only 002_users should be applied; 001 is matched by ID - if !slices.Equal(ran, []string{"002_users"}) { - t.Errorf("applied = %v, want [002_users]", ran) + if !slices.Equal(names(ran), []string{"002_users"}) { + t.Errorf("applied = %v, want [002_users]", names(ran)) } }) } // failOnNthMigrator fails on the Nth ExecUp call (0-indexed). type failOnNthMigrator struct { - applied []sqlmigrate.AppliedMigration + applied []sqlmigrate.Migration calls int failAt int } -func (m *failOnNthMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error { +func (m *failOnNthMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration, _ string) error { if m.calls == m.failAt { m.calls++ return errors.New("connection lost") } m.calls++ - m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID}) - slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int { + m.applied = append(m.applied, mig) + slices.SortFunc(m.applied, func(a, b sqlmigrate.Migration) int { return strings.Compare(a.Name, b.Name) }) return nil } -func (m *failOnNthMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error { - m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name }) +func (m *failOnNthMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration, _ string) error { + m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.Migration) bool { return a.Name == mig.Name }) return nil } -func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) { +func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.Migration, error) { return slices.Clone(m.applied), nil } @@ -315,51 +324,51 @@ func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigr func TestDown(t *testing.T) { ctx := t.Context() - migrations := []sqlmigrate.Migration{ - {Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, - {Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, - {Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"}, + ddls := []sqlmigrate.Script{ + {Migration: sqlmigrate.Migration{Name: "001_init"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, + {Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, + {Migration: sqlmigrate.Migration{Name: "003_posts"}, Up: "CREATE TABLE c;", Down: "DROP TABLE c;"}, } t.Run("rollback all", func(t *testing.T) { - m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} - rolled, err := sqlmigrate.Down(ctx, m, migrations, -1) + m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")} + rolled, err := sqlmigrate.Down(ctx, m, ddls, -1) if err != nil { t.Fatal(err) } if len(rolled) != 3 { t.Fatalf("rolled %d, want 3", len(rolled)) } - if rolled[0] != "003_posts" { - t.Errorf("first rollback = %q, want 003_posts", rolled[0]) + if rolled[0].Name != "003_posts" { + t.Errorf("first rollback = %q, want 003_posts", rolled[0].Name) } - if rolled[2] != "001_init" { - t.Errorf("last rollback = %q, want 001_init", rolled[2]) + if rolled[2].Name != "001_init" { + t.Errorf("last rollback = %q, want 001_init", rolled[2].Name) } }) t.Run("n=0 is error", func(t *testing.T) { - m := &mockMigrator{applied: applied("001_init", "002_users")} - _, err := sqlmigrate.Down(ctx, m, migrations, 0) + m := &mockMigrator{applied: migs("001_init", "002_users")} + _, err := sqlmigrate.Down(ctx, m, ddls, 0) if !errors.Is(err, sqlmigrate.ErrInvalidN) { t.Errorf("got %v, want ErrInvalidN", err) } }) t.Run("rollback n", func(t *testing.T) { - m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} - rolled, err := sqlmigrate.Down(ctx, m, migrations, 2) + m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")} + rolled, err := sqlmigrate.Down(ctx, m, ddls, 2) if err != nil { t.Fatal(err) } - if !slices.Equal(rolled, []string{"003_posts", "002_users"}) { - t.Errorf("rolled = %v", rolled) + if !slices.Equal(names(rolled), []string{"003_posts", "002_users"}) { + t.Errorf("rolled = %v", names(rolled)) } }) t.Run("none applied", func(t *testing.T) { m := &mockMigrator{} - rolled, err := sqlmigrate.Down(ctx, m, migrations, -1) + rolled, err := sqlmigrate.Down(ctx, m, ddls, -1) if err != nil { t.Fatal(err) } @@ -370,10 +379,10 @@ func TestDown(t *testing.T) { t.Run("exec error", func(t *testing.T) { m := &mockMigrator{ - applied: applied("001_init", "002_users"), + applied: migs("001_init", "002_users"), execErr: errors.New("permission denied"), } - rolled, err := sqlmigrate.Down(ctx, m, migrations, 1) + rolled, err := sqlmigrate.Down(ctx, m, ddls, 1) if err == nil { t.Fatal("expected error") } @@ -383,16 +392,16 @@ func TestDown(t *testing.T) { }) t.Run("unknown migration in applied", func(t *testing.T) { - m := &mockMigrator{applied: applied("001_init", "999_unknown")} - _, err := sqlmigrate.Down(ctx, m, migrations, -1) + m := &mockMigrator{applied: migs("001_init", "999_unknown")} + _, err := sqlmigrate.Down(ctx, m, ddls, -1) if !errors.Is(err, sqlmigrate.ErrMissingDown) { t.Errorf("got %v, want ErrMissingDown", err) } }) t.Run("n exceeds applied", func(t *testing.T) { - m := &mockMigrator{applied: applied("001_init")} - rolled, err := sqlmigrate.Down(ctx, m, migrations, 99) + m := &mockMigrator{applied: migs("001_init")} + rolled, err := sqlmigrate.Down(ctx, m, ddls, 99) if err != nil { t.Fatal(err) } @@ -403,18 +412,18 @@ func TestDown(t *testing.T) { t.Run("finds migration by ID when name changed", func(t *testing.T) { // DB has old name, file has new name, same ID - m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{ + m := &mockMigrator{applied: []sqlmigrate.Migration{ {Name: "001_old-name", ID: "aa11bb22"}, }} - migs := []sqlmigrate.Migration{ - {Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, + idDDLs := []sqlmigrate.Script{ + {Migration: sqlmigrate.Migration{Name: "001_new-name", ID: "aa11bb22"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, } - rolled, err := sqlmigrate.Down(ctx, m, migs, 1) + rolled, err := sqlmigrate.Down(ctx, m, idDDLs, 1) if err != nil { t.Fatal(err) } - if !slices.Equal(rolled, []string{"001_old-name"}) { - t.Errorf("rolled = %v, want [001_old-name]", rolled) + if !slices.Equal(names(rolled), []string{"001_old-name"}) { + t.Errorf("rolled = %v, want [001_old-name]", names(rolled)) } }) } @@ -423,15 +432,15 @@ func TestDown(t *testing.T) { func TestGetStatus(t *testing.T) { ctx := t.Context() - migrations := []sqlmigrate.Migration{ - {Name: "001_init"}, - {Name: "002_users"}, - {Name: "003_posts"}, + ddls := []sqlmigrate.Script{ + {Migration: sqlmigrate.Migration{Name: "001_init"}}, + {Migration: sqlmigrate.Migration{Name: "002_users"}}, + {Migration: sqlmigrate.Migration{Name: "003_posts"}}, } t.Run("all pending", func(t *testing.T) { m := &mockMigrator{} - status, err := sqlmigrate.GetStatus(ctx, m, migrations) + status, err := sqlmigrate.GetStatus(ctx, m, ddls) if err != nil { t.Fatal(err) } @@ -444,22 +453,22 @@ func TestGetStatus(t *testing.T) { }) t.Run("partial", func(t *testing.T) { - m := &mockMigrator{applied: applied("001_init")} - status, err := sqlmigrate.GetStatus(ctx, m, migrations) + m := &mockMigrator{applied: migs("001_init")} + status, err := sqlmigrate.GetStatus(ctx, m, ddls) if err != nil { t.Fatal(err) } - if len(status.Applied) != 1 || status.Applied[0] != "001_init" { + if len(status.Applied) != 1 || status.Applied[0].Name != "001_init" { t.Errorf("applied = %v", status.Applied) } - if !slices.Equal(status.Pending, []string{"002_users", "003_posts"}) { - t.Errorf("pending = %v", status.Pending) + if !slices.Equal(names(status.Pending), []string{"002_users", "003_posts"}) { + t.Errorf("pending = %v", names(status.Pending)) } }) t.Run("all applied", func(t *testing.T) { - m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} - status, err := sqlmigrate.GetStatus(ctx, m, migrations) + m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")} + status, err := sqlmigrate.GetStatus(ctx, m, ddls) if err != nil { t.Fatal(err) }