fix(sqlmigrate): make n=0 an error in Up/Down, use -1 for "all"

Passing 0 to Up() or Down() is an easy mistake — it silently means
"all" which could be destructive. Now n=0 returns ErrInvalidN.
Convention: n > 0 for a specific count, n < 0 (typically -1) for all.
This commit is contained in:
AJ ONeal 2026-04-08 15:33:17 -06:00
parent 3547b7e409
commit 3e51c7b67a
No known key found for this signature in database
4 changed files with 43 additions and 16 deletions

View File

@ -351,10 +351,10 @@ func main() {
fmt.Println(" ", d)
}
case "up":
var upN int
upN := -1
switch len(leafArgs) {
case 0:
// no arg: upN stays 0, meaning "all pending"
// no arg: apply all pending
case 1:
upN, err = strconv.Atoi(leafArgs[0])
if err != nil || upN < 1 {

View File

@ -14,9 +14,9 @@ Each backend is a separate Go module to avoid pulling unnecessary drivers:
| [mymigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/mymigrate) | MySQL / MariaDB | go-sql-driver/mysql |
| [litemigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/litemigrate) | SQLite | database/sql (caller imports driver) |
| [msmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/msmigrate) | SQL Server | go-mssqldb |
| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (generates POSIX sh) |
| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (uses native CLI) |
## CLI
The [sql-migrate](https://pkg.go.dev/github.com/therootcompany/golib/cmd/sql-migrate/v2) CLI
uses shmigrate to generate shell scripts for managing migrations without a Go dependency at runtime.
uses _shmigrate_ to generate shell scripts for managing migrations without a Go dependency at runtime.

View File

@ -25,6 +25,7 @@ var (
ErrWalkFailed = errors.New("walking migrations")
ErrExecFailed = errors.New("migration exec failed")
ErrQueryApplied = errors.New("querying applied migrations")
ErrInvalidN = errors.New("n must be positive or -1 for all")
)
// Migration represents a paired up/down migration.
@ -178,8 +179,13 @@ func findMigration(a AppliedMigration, byName map[string]Migration, byID map[str
}
// Up applies up to n pending migrations using the given Runner.
// If n <= 0, applies all pending. Returns the names of applied migrations.
// If n < 0, applies all pending. If n == 0, returns ErrInvalidN.
// Returns the names of applied migrations.
func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
if n == 0 {
return nil, ErrInvalidN
}
applied, err := r.Applied(ctx)
if err != nil {
return nil, err
@ -192,7 +198,7 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
}
}
if n <= 0 {
if n < 0 {
n = len(pending)
}
if n > len(pending) {
@ -211,8 +217,13 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
}
// Down rolls back up to n applied migrations, most recent first.
// If n <= 0, rolls back all applied. Returns the names of rolled-back migrations.
// If n < 0, rolls back all applied. If n == 0, returns ErrInvalidN.
// Returns the names of rolled-back migrations.
func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
if n == 0 {
return nil, ErrInvalidN
}
applied, err := r.Applied(ctx)
if err != nil {
return nil, err
@ -231,7 +242,7 @@ func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]str
copy(reversed, applied)
slices.Reverse(reversed)
if n <= 0 || n > len(reversed) {
if n < 0 || n > len(reversed) {
n = len(reversed)
}

View File

@ -190,7 +190,7 @@ func TestUp(t *testing.T) {
t.Run("apply all", func(t *testing.T) {
m := &mockMigrator{}
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
if err != nil {
t.Fatal(err)
}
@ -199,6 +199,14 @@ func TestUp(t *testing.T) {
}
})
t.Run("n=0 is error", func(t *testing.T) {
m := &mockMigrator{}
_, err := sqlmigrate.Up(ctx, m, migrations, 0)
if !errors.Is(err, sqlmigrate.ErrInvalidN) {
t.Errorf("got %v, want ErrInvalidN", err)
}
})
t.Run("apply n", func(t *testing.T) {
m := &mockMigrator{}
ran, err := sqlmigrate.Up(ctx, m, migrations, 2)
@ -212,7 +220,7 @@ func TestUp(t *testing.T) {
t.Run("none pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
if err != nil {
t.Fatal(err)
}
@ -223,7 +231,7 @@ func TestUp(t *testing.T) {
t.Run("partial pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")}
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
if err != nil {
t.Fatal(err)
}
@ -245,7 +253,7 @@ func TestUp(t *testing.T) {
t.Run("exec error stops and returns partial", func(t *testing.T) {
m := &failOnNthMigrator{failAt: 1}
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
if err == nil {
t.Fatal("expected error")
}
@ -263,7 +271,7 @@ func TestUp(t *testing.T) {
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
}
ran, err := sqlmigrate.Up(ctx, m, migs, 0)
ran, err := sqlmigrate.Up(ctx, m, migs, -1)
if err != nil {
t.Fatal(err)
}
@ -315,7 +323,7 @@ func TestDown(t *testing.T) {
t.Run("rollback all", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 0)
rolled, err := sqlmigrate.Down(ctx, m, migrations, -1)
if err != nil {
t.Fatal(err)
}
@ -330,6 +338,14 @@ func TestDown(t *testing.T) {
}
})
t.Run("n=0 is error", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users")}
_, err := sqlmigrate.Down(ctx, m, migrations, 0)
if !errors.Is(err, sqlmigrate.ErrInvalidN) {
t.Errorf("got %v, want ErrInvalidN", err)
}
})
t.Run("rollback n", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 2)
@ -343,7 +359,7 @@ func TestDown(t *testing.T) {
t.Run("none applied", func(t *testing.T) {
m := &mockMigrator{}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 0)
rolled, err := sqlmigrate.Down(ctx, m, migrations, -1)
if err != nil {
t.Fatal(err)
}
@ -368,7 +384,7 @@ func TestDown(t *testing.T) {
t.Run("unknown migration in applied", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "999_unknown")}
_, err := sqlmigrate.Down(ctx, m, migrations, 0)
_, err := sqlmigrate.Down(ctx, m, migrations, -1)
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
t.Errorf("got %v, want ErrMissingDown", err)
}