diff --git a/cmd/sql-migrate/main.go b/cmd/sql-migrate/main.go index cbfb6ab..3eab340 100644 --- a/cmd/sql-migrate/main.go +++ b/cmd/sql-migrate/main.go @@ -351,10 +351,10 @@ func main() { fmt.Println(" ", d) } case "up": - var upN int + upN := -1 switch len(leafArgs) { case 0: - // no arg: upN stays 0, meaning "all pending" + // no arg: apply all pending case 1: upN, err = strconv.Atoi(leafArgs[0]) if err != nil || upN < 1 { diff --git a/database/sqlmigrate/README.md b/database/sqlmigrate/README.md index 48ac473..e0f307a 100644 --- a/database/sqlmigrate/README.md +++ b/database/sqlmigrate/README.md @@ -14,9 +14,9 @@ Each backend is a separate Go module to avoid pulling unnecessary drivers: | [mymigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/mymigrate) | MySQL / MariaDB | go-sql-driver/mysql | | [litemigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/litemigrate) | SQLite | database/sql (caller imports driver) | | [msmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/msmigrate) | SQL Server | go-mssqldb | -| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (generates POSIX sh) | +| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (uses native CLI) | ## CLI The [sql-migrate](https://pkg.go.dev/github.com/therootcompany/golib/cmd/sql-migrate/v2) CLI -uses shmigrate to generate shell scripts for managing migrations without a Go dependency at runtime. +uses _shmigrate_ to generate shell scripts for managing migrations without a Go dependency at runtime. diff --git a/database/sqlmigrate/sqlmigrate.go b/database/sqlmigrate/sqlmigrate.go index c5f7a14..c627c8f 100644 --- a/database/sqlmigrate/sqlmigrate.go +++ b/database/sqlmigrate/sqlmigrate.go @@ -25,6 +25,7 @@ var ( ErrWalkFailed = errors.New("walking migrations") ErrExecFailed = errors.New("migration exec failed") ErrQueryApplied = errors.New("querying applied migrations") + ErrInvalidN = errors.New("n must be positive or -1 for all") ) // Migration represents a paired up/down migration. @@ -178,8 +179,13 @@ func findMigration(a AppliedMigration, byName map[string]Migration, byID map[str } // Up applies up to n pending migrations using the given Runner. -// If n <= 0, applies all pending. Returns the names of applied migrations. +// If n < 0, applies all pending. If n == 0, returns ErrInvalidN. +// Returns the names of applied migrations. func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) { + if n == 0 { + return nil, ErrInvalidN + } + applied, err := r.Applied(ctx) if err != nil { return nil, err @@ -192,7 +198,7 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin } } - if n <= 0 { + if n < 0 { n = len(pending) } if n > len(pending) { @@ -211,8 +217,13 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin } // Down rolls back up to n applied migrations, most recent first. -// If n <= 0, rolls back all applied. Returns the names of rolled-back migrations. +// If n < 0, rolls back all applied. If n == 0, returns ErrInvalidN. +// Returns the names of rolled-back migrations. func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) { + if n == 0 { + return nil, ErrInvalidN + } + applied, err := r.Applied(ctx) if err != nil { return nil, err @@ -231,7 +242,7 @@ func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]str copy(reversed, applied) slices.Reverse(reversed) - if n <= 0 || n > len(reversed) { + if n < 0 || n > len(reversed) { n = len(reversed) } diff --git a/database/sqlmigrate/sqlmigrate_test.go b/database/sqlmigrate/sqlmigrate_test.go index 5111cf0..21576cc 100644 --- a/database/sqlmigrate/sqlmigrate_test.go +++ b/database/sqlmigrate/sqlmigrate_test.go @@ -190,7 +190,7 @@ func TestUp(t *testing.T) { t.Run("apply all", func(t *testing.T) { m := &mockMigrator{} - ran, err := sqlmigrate.Up(ctx, m, migrations, 0) + ran, err := sqlmigrate.Up(ctx, m, migrations, -1) if err != nil { t.Fatal(err) } @@ -199,6 +199,14 @@ func TestUp(t *testing.T) { } }) + t.Run("n=0 is error", func(t *testing.T) { + m := &mockMigrator{} + _, err := sqlmigrate.Up(ctx, m, migrations, 0) + if !errors.Is(err, sqlmigrate.ErrInvalidN) { + t.Errorf("got %v, want ErrInvalidN", err) + } + }) + t.Run("apply n", func(t *testing.T) { m := &mockMigrator{} ran, err := sqlmigrate.Up(ctx, m, migrations, 2) @@ -212,7 +220,7 @@ func TestUp(t *testing.T) { t.Run("none pending", func(t *testing.T) { m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} - ran, err := sqlmigrate.Up(ctx, m, migrations, 0) + ran, err := sqlmigrate.Up(ctx, m, migrations, -1) if err != nil { t.Fatal(err) } @@ -223,7 +231,7 @@ func TestUp(t *testing.T) { t.Run("partial pending", func(t *testing.T) { m := &mockMigrator{applied: applied("001_init")} - ran, err := sqlmigrate.Up(ctx, m, migrations, 0) + ran, err := sqlmigrate.Up(ctx, m, migrations, -1) if err != nil { t.Fatal(err) } @@ -245,7 +253,7 @@ func TestUp(t *testing.T) { t.Run("exec error stops and returns partial", func(t *testing.T) { m := &failOnNthMigrator{failAt: 1} - ran, err := sqlmigrate.Up(ctx, m, migrations, 0) + ran, err := sqlmigrate.Up(ctx, m, migrations, -1) if err == nil { t.Fatal("expected error") } @@ -263,7 +271,7 @@ func TestUp(t *testing.T) { {Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, {Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, } - ran, err := sqlmigrate.Up(ctx, m, migs, 0) + ran, err := sqlmigrate.Up(ctx, m, migs, -1) if err != nil { t.Fatal(err) } @@ -315,7 +323,7 @@ func TestDown(t *testing.T) { t.Run("rollback all", func(t *testing.T) { m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} - rolled, err := sqlmigrate.Down(ctx, m, migrations, 0) + rolled, err := sqlmigrate.Down(ctx, m, migrations, -1) if err != nil { t.Fatal(err) } @@ -330,6 +338,14 @@ func TestDown(t *testing.T) { } }) + t.Run("n=0 is error", func(t *testing.T) { + m := &mockMigrator{applied: applied("001_init", "002_users")} + _, err := sqlmigrate.Down(ctx, m, migrations, 0) + if !errors.Is(err, sqlmigrate.ErrInvalidN) { + t.Errorf("got %v, want ErrInvalidN", err) + } + }) + t.Run("rollback n", func(t *testing.T) { m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} rolled, err := sqlmigrate.Down(ctx, m, migrations, 2) @@ -343,7 +359,7 @@ func TestDown(t *testing.T) { t.Run("none applied", func(t *testing.T) { m := &mockMigrator{} - rolled, err := sqlmigrate.Down(ctx, m, migrations, 0) + rolled, err := sqlmigrate.Down(ctx, m, migrations, -1) if err != nil { t.Fatal(err) } @@ -368,7 +384,7 @@ func TestDown(t *testing.T) { t.Run("unknown migration in applied", func(t *testing.T) { m := &mockMigrator{applied: applied("001_init", "999_unknown")} - _, err := sqlmigrate.Down(ctx, m, migrations, 0) + _, err := sqlmigrate.Down(ctx, m, migrations, -1) if !errors.Is(err, sqlmigrate.ErrMissingDown) { t.Errorf("got %v, want ErrMissingDown", err) }