ref(sqlmigrate): split Migration into Migration + Script

Migration{ID, Name} is the identity type returned by Up/Down/Latest/Drop.
Script{Migration, Up, Down} holds collected SQL content from Collect().

Migrator interface now takes SQL as a separate parameter:
  ExecUp(ctx, Migration, sql string)
  ExecDown(ctx, Migration, sql string)

This separates identity from content — callers that track what ran
don't need to carry around SQL strings they'll never use.

Updates shmigrate to match (ignores the sql parameter, references
files on disk instead).
This commit is contained in:
AJ ONeal 2026-04-09 03:21:46 -06:00
parent 285eb0b684
commit 75e4a883b7
No known key found for this signature in database
7 changed files with 263 additions and 224 deletions

View File

@ -3,10 +3,17 @@ module github.com/therootcompany/golib/cmd/sql-migrate/v2
go 1.26.1 go 1.26.1
require ( require (
github.com/therootcompany/golib/database/sqlmigrate v0.0.0 github.com/jackc/pgx/v5 v5.9.1
github.com/therootcompany/golib/database/sqlmigrate v1.0.0
github.com/therootcompany/golib/database/sqlmigrate/shmigrate v0.0.0 github.com/therootcompany/golib/database/sqlmigrate/shmigrate v0.0.0
) )
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
golang.org/x/text v0.29.0 // indirect
)
replace ( replace (
github.com/therootcompany/golib/database/sqlmigrate => ../../database/sqlmigrate github.com/therootcompany/golib/database/sqlmigrate => ../../database/sqlmigrate
github.com/therootcompany/golib/database/sqlmigrate/shmigrate => ../../database/sqlmigrate/shmigrate github.com/therootcompany/golib/database/sqlmigrate/shmigrate => ../../database/sqlmigrate/shmigrate

26
cmd/sql-migrate/go.sum Normal file
View File

@ -0,0 +1,26 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc=
github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -939,7 +939,7 @@ func syncLog(runner *shmigrate.Migrator) {
fmt.Printf("cat %s\n", logPath) fmt.Printf("cat %s\n", logPath)
} }
func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error { func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script, n int) error {
// fixup pending migrations before generating the script // fixup pending migrations before generating the script
fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations) fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations)
@ -974,7 +974,7 @@ func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrat
return nil return nil
} }
func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error { func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script, n int) error {
// fixup applied migrations before generating the script // fixup applied migrations before generating the script
fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations) fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations)
@ -998,7 +998,7 @@ func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migr
fmt.Println("") fmt.Println("")
// check for missing down files before generating script // check for missing down files before generating script
reversed := make([]string, len(status.Applied)) reversed := make([]sqlmigrate.Migration, len(status.Applied))
copy(reversed, status.Applied) copy(reversed, status.Applied)
slices.Reverse(reversed) slices.Reverse(reversed)
limit := n limit := n
@ -1008,8 +1008,8 @@ func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migr
if limit > len(reversed) { if limit > len(reversed) {
limit = len(reversed) limit = len(reversed)
} }
for _, name := range reversed[:limit] { for _, a := range reversed[:limit] {
downPath := filepath.Join(state.MigrationsDir, name+".down.sql") downPath := filepath.Join(state.MigrationsDir, a.Name+".down.sql")
if !fileExists(downPath) { if !fileExists(downPath) {
fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath)) fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath))
fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n") fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n")
@ -1028,7 +1028,7 @@ func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migr
return nil return nil
} }
func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration) error { func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script) error {
status, err := sqlmigrate.GetStatus(ctx, runner, migrations) status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
if err != nil { if err != nil {
return err return err
@ -1040,21 +1040,21 @@ func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, mi
fmt.Fprintf(os.Stderr, "\n") fmt.Fprintf(os.Stderr, "\n")
// show applied in reverse (most recent first) // show applied in reverse (most recent first)
applied := make([]string, len(status.Applied)) appliedList := make([]sqlmigrate.Migration, len(status.Applied))
copy(applied, status.Applied) copy(appliedList, status.Applied)
slices.Reverse(applied) slices.Reverse(appliedList)
fmt.Printf("# previous: %d\n", len(applied)) fmt.Printf("# previous: %d\n", len(appliedList))
for _, mig := range applied { for _, mig := range appliedList {
fmt.Printf(" %s\n", mig) fmt.Printf(" %s\n", mig.Name)
} }
if len(applied) == 0 { if len(appliedList) == 0 {
fmt.Println(" # (no previous migrations)") fmt.Println(" # (no previous migrations)")
} }
fmt.Println("") fmt.Println("")
fmt.Printf("# pending: %d\n", len(status.Pending)) fmt.Printf("# pending: %d\n", len(status.Pending))
for _, mig := range status.Pending { for _, mig := range status.Pending {
fmt.Printf(" %s\n", mig) fmt.Printf(" %s\n", mig.Name)
} }
if len(status.Pending) == 0 { if len(status.Pending) == 0 {
fmt.Println(" # (no pending migrations)") fmt.Println(" # (no pending migrations)")
@ -1063,7 +1063,7 @@ func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, mi
} }
// fixupAll runs fixupMigration on all known migrations (applied + pending). // fixupAll runs fixupMigration on all known migrations (applied + pending).
func fixupAll(migrationsDir string, applied []string, migrations []sqlmigrate.Migration) (fixedUp, fixedDown []string) { func fixupAll(migrationsDir string, applied []string, migrations []sqlmigrate.Script) (fixedUp, fixedDown []string) {
seen := map[string]bool{} seen := map[string]bool{}
var all []string var all []string
for _, name := range applied { for _, name := range applied {

View File

@ -3,3 +3,5 @@ module github.com/therootcompany/golib/database/sqlmigrate/shmigrate
go 1.26.1 go 1.26.1
require github.com/therootcompany/golib/database/sqlmigrate v1.0.0 require github.com/therootcompany/golib/database/sqlmigrate v1.0.0
replace github.com/therootcompany/golib/database/sqlmigrate => ../

View File

@ -59,13 +59,15 @@ type Migrator struct {
var _ sqlmigrate.Migrator = (*Migrator)(nil) var _ sqlmigrate.Migrator = (*Migrator)(nil)
// ExecUp outputs a shell command to run the .up.sql migration file. // ExecUp outputs a shell command to run the .up.sql migration file.
func (r *Migrator) ExecUp(ctx context.Context, m sqlmigrate.Migration) error { // The sql parameter is ignored — shmigrate references files on disk.
func (r *Migrator) ExecUp(ctx context.Context, m sqlmigrate.Migration, sql string) error {
r.counter++ r.counter++
return r.exec(m.Name, ".up.sql", fmt.Sprintf("+%d", r.counter)) return r.exec(m.Name, ".up.sql", fmt.Sprintf("+%d", r.counter))
} }
// ExecDown outputs a shell command to run the .down.sql migration file. // ExecDown outputs a shell command to run the .down.sql migration file.
func (r *Migrator) ExecDown(ctx context.Context, m sqlmigrate.Migration) error { // The sql parameter is ignored — shmigrate references files on disk.
func (r *Migrator) ExecDown(ctx context.Context, m sqlmigrate.Migration, sql string) error {
r.counter++ r.counter++
return r.exec(m.Name, ".down.sql", fmt.Sprintf("-%d", r.counter)) return r.exec(m.Name, ".down.sql", fmt.Sprintf("-%d", r.counter))
} }
@ -95,7 +97,7 @@ func (r *Migrator) exec(name, suffix, label string) error {
// //
// Returns an empty slice if the file does not exist. When FS is set, reads // Returns an empty slice if the file does not exist. When FS is set, reads
// from that filesystem; otherwise reads from the OS filesystem. // from that filesystem; otherwise reads from the OS filesystem.
func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration, error) { func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.Migration, error) {
var f io.ReadCloser var f io.ReadCloser
var err error var err error
if r.FS != nil { if r.FS != nil {
@ -111,7 +113,7 @@ func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration,
} }
defer f.Close() defer f.Close()
var applied []sqlmigrate.AppliedMigration var applied []sqlmigrate.Migration
scanner := bufio.NewScanner(f) scanner := bufio.NewScanner(f)
for scanner.Scan() { for scanner.Scan() {
line := strings.TrimSpace(scanner.Text()) line := strings.TrimSpace(scanner.Text())
@ -123,9 +125,9 @@ func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration,
continue continue
} }
if id, name, ok := strings.Cut(line, "\t"); ok { if id, name, ok := strings.Cut(line, "\t"); ok {
applied = append(applied, sqlmigrate.AppliedMigration{ID: id, Name: name}) applied = append(applied, sqlmigrate.Migration{ID: id, Name: name})
} else { } else {
applied = append(applied, sqlmigrate.AppliedMigration{Name: line}) applied = append(applied, sqlmigrate.Migration{Name: line})
} }
} }

View File

@ -28,41 +28,40 @@ var (
ErrInvalidN = errors.New("n must be positive or -1 for all") ErrInvalidN = errors.New("n must be positive or -1 for all")
) )
// Migration represents a paired up/down migration. // Migration identifies a migration by its name and optional hex ID.
type Migration struct { type Migration struct {
Name string // e.g. "2026-04-05-001000_create-todos"
ID string // 8-char hex from INSERT INTO _migrations, parsed by Collect ID string // 8-char hex from INSERT INTO _migrations, parsed by Collect
Name string // e.g. "2026-04-05-001000_create-todos"
}
// Script is a Migration with its up and down SQL content, as returned by Collect.
type Script struct {
Migration
Up string // SQL content of the .up.sql file Up string // SQL content of the .up.sql file
Down string // SQL content of the .down.sql file Down string // SQL content of the .down.sql file
} }
// AppliedMigration represents a migration recorded in the _migrations table.
type AppliedMigration struct {
ID string
Name string
}
// Status represents the current migration state. // Status represents the current migration state.
type Status struct { type Status struct {
Applied []string Applied []Migration
Pending []string Pending []Migration
} }
// Migrator executes migrations. Implementations handle the // Migrator executes migrations. Implementations handle the
// database-specific or output-specific details. // database-specific or output-specific details.
type Migrator interface { type Migrator interface {
// ExecUp runs the up migration. For database migrators this executes // ExecUp runs the up migration SQL. For database migrators this
// m.Up in a transaction. For shell migrators this outputs a command // executes the SQL in a transaction. For shell migrators this
// referencing the .up.sql file. // outputs a command referencing the .up.sql file.
ExecUp(ctx context.Context, m Migration) error ExecUp(ctx context.Context, m Migration, sql string) error
// ExecDown runs the down migration. // ExecDown runs the down migration SQL.
ExecDown(ctx context.Context, m Migration) error ExecDown(ctx context.Context, m Migration, sql string) error
// Applied returns all applied migrations from the _migrations table, // Applied returns all applied migrations from the _migrations table,
// sorted lexicographically by name. Returns an empty slice (not an // sorted lexicographically by name. Returns an empty slice (not an
// error) if the migrations table or log does not exist yet. // error) if the migrations table or log does not exist yet.
Applied(ctx context.Context) ([]AppliedMigration, error) Applied(ctx context.Context) ([]Migration, error)
} }
// idFromInsert extracts the hex ID from an INSERT INTO _migrations line. // idFromInsert extracts the hex ID from an INSERT INTO _migrations line.
@ -75,8 +74,8 @@ var idFromInsert = regexp.MustCompile(
// pairs them by basename, and returns them sorted lexicographically by name. // pairs them by basename, and returns them sorted lexicographically by name.
// If subpath is "" or ".", the root of fsys is used. // If subpath is "" or ".", the root of fsys is used.
// If the up SQL contains an INSERT INTO _migrations line, the hex ID // If the up SQL contains an INSERT INTO _migrations line, the hex ID
// is extracted and stored in Migration.ID. // is extracted and stored in Script.ID.
func Collect(fsys fs.FS, subpath string) ([]Migration, error) { func Collect(fsys fs.FS, subpath string) ([]Script, error) {
if subpath != "" && subpath != "." { if subpath != "" && subpath != "." {
var err error var err error
fsys, err = fs.Sub(fsys, subpath) fsys, err = fs.Sub(fsys, subpath)
@ -120,7 +119,7 @@ func Collect(fsys fs.FS, subpath string) ([]Migration, error) {
return nil, fmt.Errorf("%w: %w", ErrWalkFailed, err) return nil, fmt.Errorf("%w: %w", ErrWalkFailed, err)
} }
var migrations []Migration var ddls []Script
for name, upSQL := range ups { for name, upSQL := range ups {
downSQL, ok := downs[name] downSQL, ok := downs[name]
if !ok { if !ok {
@ -130,11 +129,10 @@ func Collect(fsys fs.FS, subpath string) ([]Migration, error) {
if m := idFromInsert.FindStringSubmatch(upSQL); m != nil { if m := idFromInsert.FindStringSubmatch(upSQL); m != nil {
id = m[1] id = m[1]
} }
migrations = append(migrations, Migration{ ddls = append(ddls, Script{
Name: name, Migration: Migration{ID: id, Name: name},
ID: id, Up: upSQL,
Up: upSQL, Down: downSQL,
Down: downSQL,
}) })
} }
for name := range downs { for name := range downs {
@ -143,54 +141,54 @@ func Collect(fsys fs.FS, subpath string) ([]Migration, error) {
} }
} }
slices.SortFunc(migrations, func(a, b Migration) int { slices.SortFunc(ddls, func(a, b Script) int {
return strings.Compare(a.Name, b.Name) return strings.Compare(a.Name, b.Name)
}) })
return migrations, nil return ddls, nil
} }
// NamesOnly builds a Migration slice from a list of names, with empty // NamesOnly builds a Script slice from a list of names, with empty
// Up/Down content. Useful for shell-based runners that reference files // Up/Down content. Useful for shell-based runners that reference files
// on disk rather than executing SQL directly. // on disk rather than executing SQL directly.
func NamesOnly(names []string) []Migration { func NamesOnly(names []string) []Script {
migrations := make([]Migration, len(names)) ddls := make([]Script, len(names))
for i, name := range names { for i, name := range names {
migrations[i] = Migration{Name: name} ddls[i] = Script{Migration: Migration{Name: name}}
} }
return migrations return ddls
} }
// isApplied returns true if the migration matches any applied entry by name or ID. // isApplied returns true if the Script matches any applied entry by name or ID.
func isApplied(m Migration, applied []AppliedMigration) bool { func isApplied(d Script, applied []Migration) bool {
for _, a := range applied { for _, a := range applied {
if a.Name == m.Name { if a.Name == d.Name {
return true return true
} }
if m.ID != "" && a.ID != "" && a.ID == m.ID { if d.ID != "" && a.ID != "" && a.ID == d.ID {
return true return true
} }
} }
return false return false
} }
// findMigration looks up a migration by the applied entry's name or ID. // findScript looks up a Script by the applied entry's name or ID.
func findMigration(a AppliedMigration, byName map[string]Migration, byID map[string]Migration) (Migration, bool) { func findScript(a Migration, byName map[string]Script, byID map[string]Script) (Script, bool) {
if m, ok := byName[a.Name]; ok { if d, ok := byName[a.Name]; ok {
return m, true return d, true
} }
if a.ID != "" { if a.ID != "" {
if m, ok := byID[a.ID]; ok { if d, ok := byID[a.ID]; ok {
return m, true return d, true
} }
} }
return Migration{}, false return Script{}, false
} }
// Up applies up to n pending migrations using the given Runner. // Up applies up to n pending migrations using the given Migrator.
// If n < 0, applies all pending. If n == 0, returns ErrInvalidN. // If n < 0, applies all pending. If n == 0, returns ErrInvalidN.
// Returns the names of applied migrations. // Returns the applied migrations.
func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) { func Up(ctx context.Context, r Migrator, ddls []Script, n int) ([]Migration, error) {
if n == 0 { if n == 0 {
return nil, ErrInvalidN return nil, ErrInvalidN
} }
@ -200,10 +198,10 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
return nil, err return nil, err
} }
var pending []Migration var pending []Script
for _, m := range migrations { for _, d := range ddls {
if !isApplied(m, applied) { if !isApplied(d, applied) {
pending = append(pending, m) pending = append(pending, d)
} }
} }
@ -214,12 +212,12 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
n = len(pending) n = len(pending)
} }
var ran []string var ran []Migration
for _, m := range pending[:n] { for _, d := range pending[:n] {
if err := r.ExecUp(ctx, m); err != nil { if err := r.ExecUp(ctx, d.Migration, d.Up); err != nil {
return ran, fmt.Errorf("%s (up): %w", m.Name, err) return ran, fmt.Errorf("%s (up): %w", d.Name, err)
} }
ran = append(ran, m.Name) ran = append(ran, d.Migration)
} }
return ran, nil return ran, nil
@ -227,8 +225,8 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
// Down rolls back up to n applied migrations, most recent first. // Down rolls back up to n applied migrations, most recent first.
// If n < 0, rolls back all applied. If n == 0, returns ErrInvalidN. // If n < 0, rolls back all applied. If n == 0, returns ErrInvalidN.
// Returns the names of rolled-back migrations. // Returns the rolled-back migrations.
func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) { func Down(ctx context.Context, r Migrator, ddls []Script, n int) ([]Migration, error) {
if n == 0 { if n == 0 {
return nil, ErrInvalidN return nil, ErrInvalidN
} }
@ -238,16 +236,16 @@ func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]str
return nil, err return nil, err
} }
byName := map[string]Migration{} byName := map[string]Script{}
byID := map[string]Migration{} byID := map[string]Script{}
for _, m := range migrations { for _, d := range ddls {
byName[m.Name] = m byName[d.Name] = d
if m.ID != "" { if d.ID != "" {
byID[m.ID] = m byID[d.ID] = d
} }
} }
reversed := make([]AppliedMigration, len(applied)) reversed := make([]Migration, len(applied))
copy(reversed, applied) copy(reversed, applied)
slices.Reverse(reversed) slices.Reverse(reversed)
@ -255,52 +253,47 @@ func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]str
n = len(reversed) n = len(reversed)
} }
var ran []string var ran []Migration
for _, a := range reversed[:n] { for _, a := range reversed[:n] {
m, ok := findMigration(a, byName, byID) d, ok := findScript(a, byName, byID)
if !ok { if !ok {
return ran, fmt.Errorf("%w: %s", ErrMissingDown, a.Name) return ran, fmt.Errorf("%w: %s", ErrMissingDown, a.Name)
} }
if err := r.ExecDown(ctx, m); err != nil { if err := r.ExecDown(ctx, a, d.Down); err != nil {
return ran, fmt.Errorf("%s (down): %w", a.Name, err) return ran, fmt.Errorf("%s (down): %w", a.Name, err)
} }
ran = append(ran, a.Name) ran = append(ran, a)
} }
return ran, nil return ran, nil
} }
// GetStatus returns applied and pending migration lists. // GetStatus returns applied and pending migration lists.
func GetStatus(ctx context.Context, r Migrator, migrations []Migration) (*Status, error) { func GetStatus(ctx context.Context, r Migrator, ddls []Script) (*Status, error) {
applied, err := r.Applied(ctx) applied, err := r.Applied(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
appliedNames := make([]string, len(applied)) var pending []Migration
for i, a := range applied { for _, d := range ddls {
appliedNames[i] = a.Name if !isApplied(d, applied) {
} pending = append(pending, d.Migration)
var pending []string
for _, m := range migrations {
if !isApplied(m, applied) {
pending = append(pending, m.Name)
} }
} }
return &Status{ return &Status{
Applied: appliedNames, Applied: applied,
Pending: pending, Pending: pending,
}, nil }, nil
} }
// Latest applies all pending migrations. Equivalent to Up(ctx, r, migrations, -1). // Latest applies all pending migrations. Equivalent to Up(ctx, r, ddls, -1).
func Latest(ctx context.Context, r Migrator, migrations []Migration) ([]string, error) { func Latest(ctx context.Context, r Migrator, ddls []Script) ([]Migration, error) {
return Up(ctx, r, migrations, -1) return Up(ctx, r, ddls, -1)
} }
// Drop rolls back all applied migrations. Equivalent to Down(ctx, r, migrations, -1). // Drop rolls back all applied migrations. Equivalent to Down(ctx, r, ddls, -1).
func Drop(ctx context.Context, r Migrator, migrations []Migration) ([]string, error) { func Drop(ctx context.Context, r Migrator, ddls []Script) ([]Migration, error) {
return Down(ctx, r, migrations, -1) return Down(ctx, r, ddls, -1)
} }

View File

@ -11,45 +11,54 @@ import (
"github.com/therootcompany/golib/database/sqlmigrate" "github.com/therootcompany/golib/database/sqlmigrate"
) )
// applied builds an []AppliedMigration from names (IDs empty). // migs builds []Migration from names (IDs empty).
func applied(names ...string) []sqlmigrate.AppliedMigration { func migs(names ...string) []sqlmigrate.Migration {
out := make([]sqlmigrate.AppliedMigration, len(names)) out := make([]sqlmigrate.Migration, len(names))
for i, n := range names { for i, n := range names {
out[i] = sqlmigrate.AppliedMigration{Name: n} out[i] = sqlmigrate.Migration{Name: n}
}
return out
}
// names extracts just the Name field from a slice of Migration.
func names(ms []sqlmigrate.Migration) []string {
out := make([]string, len(ms))
for i, m := range ms {
out[i] = m.Name
} }
return out return out
} }
// mockMigrator tracks applied migrations in memory. // mockMigrator tracks applied migrations in memory.
type mockMigrator struct { type mockMigrator struct {
applied []sqlmigrate.AppliedMigration applied []sqlmigrate.Migration
execErr error // if set, ExecUp/ExecDown return this on every call execErr error // if set, ExecUp/ExecDown return this on every call
upCalls []string upCalls []string
downCalls []string downCalls []string
} }
func (m *mockMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error { func (m *mockMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration, _ string) error {
m.upCalls = append(m.upCalls, mig.Name) m.upCalls = append(m.upCalls, mig.Name)
if m.execErr != nil { if m.execErr != nil {
return m.execErr return m.execErr
} }
m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID}) m.applied = append(m.applied, mig)
slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int { slices.SortFunc(m.applied, func(a, b sqlmigrate.Migration) int {
return strings.Compare(a.Name, b.Name) return strings.Compare(a.Name, b.Name)
}) })
return nil return nil
} }
func (m *mockMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error { func (m *mockMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration, _ string) error {
m.downCalls = append(m.downCalls, mig.Name) m.downCalls = append(m.downCalls, mig.Name)
if m.execErr != nil { if m.execErr != nil {
return m.execErr return m.execErr
} }
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name }) m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.Migration) bool { return a.Name == mig.Name })
return nil return nil
} }
func (m *mockMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) { func (m *mockMigrator) Applied(_ context.Context) ([]sqlmigrate.Migration, error) {
return slices.Clone(m.applied), nil return slices.Clone(m.applied), nil
} }
@ -63,24 +72,24 @@ func TestCollect(t *testing.T) {
"001_first.up.sql": {Data: []byte("CREATE TABLE a;")}, "001_first.up.sql": {Data: []byte("CREATE TABLE a;")},
"001_first.down.sql": {Data: []byte("DROP TABLE a;")}, "001_first.down.sql": {Data: []byte("DROP TABLE a;")},
} }
migrations, err := sqlmigrate.Collect(fsys, ".") ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(migrations) != 2 { if len(ddls) != 2 {
t.Fatalf("got %d migrations, want 2", len(migrations)) t.Fatalf("got %d ddls, want 2", len(ddls))
} }
if migrations[0].Name != "001_first" { if ddls[0].Name != "001_first" {
t.Errorf("first = %q, want %q", migrations[0].Name, "001_first") t.Errorf("first = %q, want %q", ddls[0].Name, "001_first")
} }
if migrations[1].Name != "002_second" { if ddls[1].Name != "002_second" {
t.Errorf("second = %q, want %q", migrations[1].Name, "002_second") t.Errorf("second = %q, want %q", ddls[1].Name, "002_second")
} }
if migrations[0].Up != "CREATE TABLE a;" { if ddls[0].Up != "CREATE TABLE a;" {
t.Errorf("first.Up = %q", migrations[0].Up) t.Errorf("first.Up = %q", ddls[0].Up)
} }
if migrations[0].Down != "DROP TABLE a;" { if ddls[0].Down != "DROP TABLE a;" {
t.Errorf("first.Down = %q", migrations[0].Down) t.Errorf("first.Down = %q", ddls[0].Down)
} }
}) })
@ -89,12 +98,12 @@ func TestCollect(t *testing.T) {
"001_init.up.sql": {Data: []byte("CREATE TABLE a;\nINSERT INTO _migrations (name, id) VALUES ('001_init', 'abcd1234');")}, "001_init.up.sql": {Data: []byte("CREATE TABLE a;\nINSERT INTO _migrations (name, id) VALUES ('001_init', 'abcd1234');")},
"001_init.down.sql": {Data: []byte("DROP TABLE a;\nDELETE FROM _migrations WHERE id = 'abcd1234';")}, "001_init.down.sql": {Data: []byte("DROP TABLE a;\nDELETE FROM _migrations WHERE id = 'abcd1234';")},
} }
migrations, err := sqlmigrate.Collect(fsys, ".") ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if migrations[0].ID != "abcd1234" { if ddls[0].ID != "abcd1234" {
t.Errorf("ID = %q, want %q", migrations[0].ID, "abcd1234") t.Errorf("ID = %q, want %q", ddls[0].ID, "abcd1234")
} }
}) })
@ -103,12 +112,12 @@ func TestCollect(t *testing.T) {
"001_init.up.sql": {Data: []byte("CREATE TABLE a;")}, "001_init.up.sql": {Data: []byte("CREATE TABLE a;")},
"001_init.down.sql": {Data: []byte("DROP TABLE a;")}, "001_init.down.sql": {Data: []byte("DROP TABLE a;")},
} }
migrations, err := sqlmigrate.Collect(fsys, ".") ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if migrations[0].ID != "" { if ddls[0].ID != "" {
t.Errorf("ID = %q, want empty", migrations[0].ID) t.Errorf("ID = %q, want empty", ddls[0].ID)
} }
}) })
@ -139,23 +148,23 @@ func TestCollect(t *testing.T) {
"README.md": {Data: []byte("# Migrations")}, "README.md": {Data: []byte("# Migrations")},
"_migrations.sql": {Data: []byte("SELECT name FROM _migrations;")}, "_migrations.sql": {Data: []byte("SELECT name FROM _migrations;")},
} }
migrations, err := sqlmigrate.Collect(fsys, ".") ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(migrations) != 1 { if len(ddls) != 1 {
t.Fatalf("got %d migrations, want 1", len(migrations)) t.Fatalf("got %d ddls, want 1", len(ddls))
} }
}) })
t.Run("empty fs", func(t *testing.T) { t.Run("empty fs", func(t *testing.T) {
fsys := fstest.MapFS{} fsys := fstest.MapFS{}
migrations, err := sqlmigrate.Collect(fsys, ".") ddls, err := sqlmigrate.Collect(fsys, ".")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(migrations) != 0 { if len(ddls) != 0 {
t.Fatalf("got %d migrations, want 0", len(migrations)) t.Fatalf("got %d ddls, want 0", len(ddls))
} }
}) })
} }
@ -163,16 +172,16 @@ func TestCollect(t *testing.T) {
// --- NamesOnly --- // --- NamesOnly ---
func TestNamesOnly(t *testing.T) { func TestNamesOnly(t *testing.T) {
names := []string{"001_init", "002_users"} ns := []string{"001_init", "002_users"}
migrations := sqlmigrate.NamesOnly(names) ddls := sqlmigrate.NamesOnly(ns)
if len(migrations) != 2 { if len(ddls) != 2 {
t.Fatalf("got %d, want 2", len(migrations)) t.Fatalf("got %d, want 2", len(ddls))
} }
for i, m := range migrations { for i, d := range ddls {
if m.Name != names[i] { if d.Name != ns[i] {
t.Errorf("[%d].Name = %q, want %q", i, m.Name, names[i]) t.Errorf("[%d].Name = %q, want %q", i, d.Name, ns[i])
} }
if m.Up != "" || m.Down != "" { if d.Up != "" || d.Down != "" {
t.Errorf("[%d] has non-empty content", i) t.Errorf("[%d] has non-empty content", i)
} }
} }
@ -182,26 +191,26 @@ func TestNamesOnly(t *testing.T) {
func TestUp(t *testing.T) { func TestUp(t *testing.T) {
ctx := t.Context() ctx := t.Context()
migrations := []sqlmigrate.Migration{ ddls := []sqlmigrate.Script{
{Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, {Migration: sqlmigrate.Migration{Name: "001_init"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, {Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
{Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"}, {Migration: sqlmigrate.Migration{Name: "003_posts"}, Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
} }
t.Run("apply all", func(t *testing.T) { t.Run("apply all", func(t *testing.T) {
m := &mockMigrator{} m := &mockMigrator{}
ran, err := sqlmigrate.Up(ctx, m, migrations, -1) ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !slices.Equal(ran, []string{"001_init", "002_users", "003_posts"}) { if !slices.Equal(names(ran), []string{"001_init", "002_users", "003_posts"}) {
t.Errorf("applied = %v", ran) t.Errorf("applied = %v", names(ran))
} }
}) })
t.Run("n=0 is error", func(t *testing.T) { t.Run("n=0 is error", func(t *testing.T) {
m := &mockMigrator{} m := &mockMigrator{}
_, err := sqlmigrate.Up(ctx, m, migrations, 0) _, err := sqlmigrate.Up(ctx, m, ddls, 0)
if !errors.Is(err, sqlmigrate.ErrInvalidN) { if !errors.Is(err, sqlmigrate.ErrInvalidN) {
t.Errorf("got %v, want ErrInvalidN", err) t.Errorf("got %v, want ErrInvalidN", err)
} }
@ -209,18 +218,18 @@ func TestUp(t *testing.T) {
t.Run("apply n", func(t *testing.T) { t.Run("apply n", func(t *testing.T) {
m := &mockMigrator{} m := &mockMigrator{}
ran, err := sqlmigrate.Up(ctx, m, migrations, 2) ran, err := sqlmigrate.Up(ctx, m, ddls, 2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !slices.Equal(ran, []string{"001_init", "002_users"}) { if !slices.Equal(names(ran), []string{"001_init", "002_users"}) {
t.Errorf("applied = %v", ran) t.Errorf("applied = %v", names(ran))
} }
}) })
t.Run("none pending", func(t *testing.T) { t.Run("none pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
ran, err := sqlmigrate.Up(ctx, m, migrations, -1) ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -230,19 +239,19 @@ func TestUp(t *testing.T) {
}) })
t.Run("partial pending", func(t *testing.T) { t.Run("partial pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")} m := &mockMigrator{applied: migs("001_init")}
ran, err := sqlmigrate.Up(ctx, m, migrations, -1) ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !slices.Equal(ran, []string{"002_users", "003_posts"}) { if !slices.Equal(names(ran), []string{"002_users", "003_posts"}) {
t.Errorf("applied = %v", ran) t.Errorf("applied = %v", names(ran))
} }
}) })
t.Run("n exceeds pending", func(t *testing.T) { t.Run("n exceeds pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")} m := &mockMigrator{applied: migs("001_init")}
ran, err := sqlmigrate.Up(ctx, m, migrations, 99) ran, err := sqlmigrate.Up(ctx, m, ddls, 99)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -253,7 +262,7 @@ func TestUp(t *testing.T) {
t.Run("exec error stops and returns partial", func(t *testing.T) { t.Run("exec error stops and returns partial", func(t *testing.T) {
m := &failOnNthMigrator{failAt: 1} m := &failOnNthMigrator{failAt: 1}
ran, err := sqlmigrate.Up(ctx, m, migrations, -1) ran, err := sqlmigrate.Up(ctx, m, ddls, -1)
if err == nil { if err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} }
@ -264,50 +273,50 @@ func TestUp(t *testing.T) {
t.Run("skips migration matched by ID", func(t *testing.T) { t.Run("skips migration matched by ID", func(t *testing.T) {
// DB has migration applied under old name, but same ID // DB has migration applied under old name, but same ID
m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{ m := &mockMigrator{applied: []sqlmigrate.Migration{
{Name: "001_old-name", ID: "aa11bb22"}, {Name: "001_old-name", ID: "aa11bb22"},
}} }}
migs := []sqlmigrate.Migration{ idDDLs := []sqlmigrate.Script{
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, {Migration: sqlmigrate.Migration{Name: "001_new-name", ID: "aa11bb22"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, {Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
} }
ran, err := sqlmigrate.Up(ctx, m, migs, -1) ran, err := sqlmigrate.Up(ctx, m, idDDLs, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Only 002_users should be applied; 001 is matched by ID // Only 002_users should be applied; 001 is matched by ID
if !slices.Equal(ran, []string{"002_users"}) { if !slices.Equal(names(ran), []string{"002_users"}) {
t.Errorf("applied = %v, want [002_users]", ran) t.Errorf("applied = %v, want [002_users]", names(ran))
} }
}) })
} }
// failOnNthMigrator fails on the Nth ExecUp call (0-indexed). // failOnNthMigrator fails on the Nth ExecUp call (0-indexed).
type failOnNthMigrator struct { type failOnNthMigrator struct {
applied []sqlmigrate.AppliedMigration applied []sqlmigrate.Migration
calls int calls int
failAt int failAt int
} }
func (m *failOnNthMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error { func (m *failOnNthMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration, _ string) error {
if m.calls == m.failAt { if m.calls == m.failAt {
m.calls++ m.calls++
return errors.New("connection lost") return errors.New("connection lost")
} }
m.calls++ m.calls++
m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID}) m.applied = append(m.applied, mig)
slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int { slices.SortFunc(m.applied, func(a, b sqlmigrate.Migration) int {
return strings.Compare(a.Name, b.Name) return strings.Compare(a.Name, b.Name)
}) })
return nil return nil
} }
func (m *failOnNthMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error { func (m *failOnNthMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration, _ string) error {
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name }) m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.Migration) bool { return a.Name == mig.Name })
return nil return nil
} }
func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) { func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.Migration, error) {
return slices.Clone(m.applied), nil return slices.Clone(m.applied), nil
} }
@ -315,51 +324,51 @@ func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigr
func TestDown(t *testing.T) { func TestDown(t *testing.T) {
ctx := t.Context() ctx := t.Context()
migrations := []sqlmigrate.Migration{ ddls := []sqlmigrate.Script{
{Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, {Migration: sqlmigrate.Migration{Name: "001_init"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, {Migration: sqlmigrate.Migration{Name: "002_users"}, Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
{Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"}, {Migration: sqlmigrate.Migration{Name: "003_posts"}, Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
} }
t.Run("rollback all", func(t *testing.T) { t.Run("rollback all", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, -1) rolled, err := sqlmigrate.Down(ctx, m, ddls, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(rolled) != 3 { if len(rolled) != 3 {
t.Fatalf("rolled %d, want 3", len(rolled)) t.Fatalf("rolled %d, want 3", len(rolled))
} }
if rolled[0] != "003_posts" { if rolled[0].Name != "003_posts" {
t.Errorf("first rollback = %q, want 003_posts", rolled[0]) t.Errorf("first rollback = %q, want 003_posts", rolled[0].Name)
} }
if rolled[2] != "001_init" { if rolled[2].Name != "001_init" {
t.Errorf("last rollback = %q, want 001_init", rolled[2]) t.Errorf("last rollback = %q, want 001_init", rolled[2].Name)
} }
}) })
t.Run("n=0 is error", func(t *testing.T) { t.Run("n=0 is error", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users")} m := &mockMigrator{applied: migs("001_init", "002_users")}
_, err := sqlmigrate.Down(ctx, m, migrations, 0) _, err := sqlmigrate.Down(ctx, m, ddls, 0)
if !errors.Is(err, sqlmigrate.ErrInvalidN) { if !errors.Is(err, sqlmigrate.ErrInvalidN) {
t.Errorf("got %v, want ErrInvalidN", err) t.Errorf("got %v, want ErrInvalidN", err)
} }
}) })
t.Run("rollback n", func(t *testing.T) { t.Run("rollback n", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 2) rolled, err := sqlmigrate.Down(ctx, m, ddls, 2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !slices.Equal(rolled, []string{"003_posts", "002_users"}) { if !slices.Equal(names(rolled), []string{"003_posts", "002_users"}) {
t.Errorf("rolled = %v", rolled) t.Errorf("rolled = %v", names(rolled))
} }
}) })
t.Run("none applied", func(t *testing.T) { t.Run("none applied", func(t *testing.T) {
m := &mockMigrator{} m := &mockMigrator{}
rolled, err := sqlmigrate.Down(ctx, m, migrations, -1) rolled, err := sqlmigrate.Down(ctx, m, ddls, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -370,10 +379,10 @@ func TestDown(t *testing.T) {
t.Run("exec error", func(t *testing.T) { t.Run("exec error", func(t *testing.T) {
m := &mockMigrator{ m := &mockMigrator{
applied: applied("001_init", "002_users"), applied: migs("001_init", "002_users"),
execErr: errors.New("permission denied"), execErr: errors.New("permission denied"),
} }
rolled, err := sqlmigrate.Down(ctx, m, migrations, 1) rolled, err := sqlmigrate.Down(ctx, m, ddls, 1)
if err == nil { if err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} }
@ -383,16 +392,16 @@ func TestDown(t *testing.T) {
}) })
t.Run("unknown migration in applied", func(t *testing.T) { t.Run("unknown migration in applied", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "999_unknown")} m := &mockMigrator{applied: migs("001_init", "999_unknown")}
_, err := sqlmigrate.Down(ctx, m, migrations, -1) _, err := sqlmigrate.Down(ctx, m, ddls, -1)
if !errors.Is(err, sqlmigrate.ErrMissingDown) { if !errors.Is(err, sqlmigrate.ErrMissingDown) {
t.Errorf("got %v, want ErrMissingDown", err) t.Errorf("got %v, want ErrMissingDown", err)
} }
}) })
t.Run("n exceeds applied", func(t *testing.T) { t.Run("n exceeds applied", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")} m := &mockMigrator{applied: migs("001_init")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 99) rolled, err := sqlmigrate.Down(ctx, m, ddls, 99)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -403,18 +412,18 @@ func TestDown(t *testing.T) {
t.Run("finds migration by ID when name changed", func(t *testing.T) { t.Run("finds migration by ID when name changed", func(t *testing.T) {
// DB has old name, file has new name, same ID // DB has old name, file has new name, same ID
m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{ m := &mockMigrator{applied: []sqlmigrate.Migration{
{Name: "001_old-name", ID: "aa11bb22"}, {Name: "001_old-name", ID: "aa11bb22"},
}} }}
migs := []sqlmigrate.Migration{ idDDLs := []sqlmigrate.Script{
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, {Migration: sqlmigrate.Migration{Name: "001_new-name", ID: "aa11bb22"}, Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
} }
rolled, err := sqlmigrate.Down(ctx, m, migs, 1) rolled, err := sqlmigrate.Down(ctx, m, idDDLs, 1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !slices.Equal(rolled, []string{"001_old-name"}) { if !slices.Equal(names(rolled), []string{"001_old-name"}) {
t.Errorf("rolled = %v, want [001_old-name]", rolled) t.Errorf("rolled = %v, want [001_old-name]", names(rolled))
} }
}) })
} }
@ -423,15 +432,15 @@ func TestDown(t *testing.T) {
func TestGetStatus(t *testing.T) { func TestGetStatus(t *testing.T) {
ctx := t.Context() ctx := t.Context()
migrations := []sqlmigrate.Migration{ ddls := []sqlmigrate.Script{
{Name: "001_init"}, {Migration: sqlmigrate.Migration{Name: "001_init"}},
{Name: "002_users"}, {Migration: sqlmigrate.Migration{Name: "002_users"}},
{Name: "003_posts"}, {Migration: sqlmigrate.Migration{Name: "003_posts"}},
} }
t.Run("all pending", func(t *testing.T) { t.Run("all pending", func(t *testing.T) {
m := &mockMigrator{} m := &mockMigrator{}
status, err := sqlmigrate.GetStatus(ctx, m, migrations) status, err := sqlmigrate.GetStatus(ctx, m, ddls)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -444,22 +453,22 @@ func TestGetStatus(t *testing.T) {
}) })
t.Run("partial", func(t *testing.T) { t.Run("partial", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")} m := &mockMigrator{applied: migs("001_init")}
status, err := sqlmigrate.GetStatus(ctx, m, migrations) status, err := sqlmigrate.GetStatus(ctx, m, ddls)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(status.Applied) != 1 || status.Applied[0] != "001_init" { if len(status.Applied) != 1 || status.Applied[0].Name != "001_init" {
t.Errorf("applied = %v", status.Applied) t.Errorf("applied = %v", status.Applied)
} }
if !slices.Equal(status.Pending, []string{"002_users", "003_posts"}) { if !slices.Equal(names(status.Pending), []string{"002_users", "003_posts"}) {
t.Errorf("pending = %v", status.Pending) t.Errorf("pending = %v", names(status.Pending))
} }
}) })
t.Run("all applied", func(t *testing.T) { t.Run("all applied", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} m := &mockMigrator{applied: migs("001_init", "002_users", "003_posts")}
status, err := sqlmigrate.GetStatus(ctx, m, migrations) status, err := sqlmigrate.GetStatus(ctx, m, ddls)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }