mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 04:38:02 +00:00
fix(sqlmigrate): make n=0 an error in Up/Down, use -1 for "all"
Passing 0 to Up() or Down() is an easy mistake — it silently means "all" which could be destructive. Now n=0 returns ErrInvalidN. Convention: n > 0 for a specific count, n < 0 (typically -1) for all.
This commit is contained in:
parent
3547b7e409
commit
3e51c7b67a
@ -351,10 +351,10 @@ func main() {
|
||||
fmt.Println(" ", d)
|
||||
}
|
||||
case "up":
|
||||
var upN int
|
||||
upN := -1
|
||||
switch len(leafArgs) {
|
||||
case 0:
|
||||
// no arg: upN stays 0, meaning "all pending"
|
||||
// no arg: apply all pending
|
||||
case 1:
|
||||
upN, err = strconv.Atoi(leafArgs[0])
|
||||
if err != nil || upN < 1 {
|
||||
|
||||
@ -14,9 +14,9 @@ Each backend is a separate Go module to avoid pulling unnecessary drivers:
|
||||
| [mymigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/mymigrate) | MySQL / MariaDB | go-sql-driver/mysql |
|
||||
| [litemigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/litemigrate) | SQLite | database/sql (caller imports driver) |
|
||||
| [msmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/msmigrate) | SQL Server | go-mssqldb |
|
||||
| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (generates POSIX sh) |
|
||||
| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (uses native CLI) |
|
||||
|
||||
## CLI
|
||||
|
||||
The [sql-migrate](https://pkg.go.dev/github.com/therootcompany/golib/cmd/sql-migrate/v2) CLI
|
||||
uses shmigrate to generate shell scripts for managing migrations without a Go dependency at runtime.
|
||||
uses _shmigrate_ to generate shell scripts for managing migrations without a Go dependency at runtime.
|
||||
|
||||
@ -25,6 +25,7 @@ var (
|
||||
ErrWalkFailed = errors.New("walking migrations")
|
||||
ErrExecFailed = errors.New("migration exec failed")
|
||||
ErrQueryApplied = errors.New("querying applied migrations")
|
||||
ErrInvalidN = errors.New("n must be positive or -1 for all")
|
||||
)
|
||||
|
||||
// Migration represents a paired up/down migration.
|
||||
@ -178,8 +179,13 @@ func findMigration(a AppliedMigration, byName map[string]Migration, byID map[str
|
||||
}
|
||||
|
||||
// Up applies up to n pending migrations using the given Runner.
|
||||
// If n <= 0, applies all pending. Returns the names of applied migrations.
|
||||
// If n < 0, applies all pending. If n == 0, returns ErrInvalidN.
|
||||
// Returns the names of applied migrations.
|
||||
func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
|
||||
if n == 0 {
|
||||
return nil, ErrInvalidN
|
||||
}
|
||||
|
||||
applied, err := r.Applied(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -192,7 +198,7 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
|
||||
}
|
||||
}
|
||||
|
||||
if n <= 0 {
|
||||
if n < 0 {
|
||||
n = len(pending)
|
||||
}
|
||||
if n > len(pending) {
|
||||
@ -211,8 +217,13 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
|
||||
}
|
||||
|
||||
// Down rolls back up to n applied migrations, most recent first.
|
||||
// If n <= 0, rolls back all applied. Returns the names of rolled-back migrations.
|
||||
// If n < 0, rolls back all applied. If n == 0, returns ErrInvalidN.
|
||||
// Returns the names of rolled-back migrations.
|
||||
func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
|
||||
if n == 0 {
|
||||
return nil, ErrInvalidN
|
||||
}
|
||||
|
||||
applied, err := r.Applied(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -231,7 +242,7 @@ func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]str
|
||||
copy(reversed, applied)
|
||||
slices.Reverse(reversed)
|
||||
|
||||
if n <= 0 || n > len(reversed) {
|
||||
if n < 0 || n > len(reversed) {
|
||||
n = len(reversed)
|
||||
}
|
||||
|
||||
|
||||
@ -190,7 +190,7 @@ func TestUp(t *testing.T) {
|
||||
|
||||
t.Run("apply all", func(t *testing.T) {
|
||||
m := &mockMigrator{}
|
||||
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
||||
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -199,6 +199,14 @@ func TestUp(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("n=0 is error", func(t *testing.T) {
|
||||
m := &mockMigrator{}
|
||||
_, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
||||
if !errors.Is(err, sqlmigrate.ErrInvalidN) {
|
||||
t.Errorf("got %v, want ErrInvalidN", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("apply n", func(t *testing.T) {
|
||||
m := &mockMigrator{}
|
||||
ran, err := sqlmigrate.Up(ctx, m, migrations, 2)
|
||||
@ -212,7 +220,7 @@ func TestUp(t *testing.T) {
|
||||
|
||||
t.Run("none pending", func(t *testing.T) {
|
||||
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
||||
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
||||
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -223,7 +231,7 @@ func TestUp(t *testing.T) {
|
||||
|
||||
t.Run("partial pending", func(t *testing.T) {
|
||||
m := &mockMigrator{applied: applied("001_init")}
|
||||
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
||||
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -245,7 +253,7 @@ func TestUp(t *testing.T) {
|
||||
|
||||
t.Run("exec error stops and returns partial", func(t *testing.T) {
|
||||
m := &failOnNthMigrator{failAt: 1}
|
||||
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
||||
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
@ -263,7 +271,7 @@ func TestUp(t *testing.T) {
|
||||
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
|
||||
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
|
||||
}
|
||||
ran, err := sqlmigrate.Up(ctx, m, migs, 0)
|
||||
ran, err := sqlmigrate.Up(ctx, m, migs, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -315,7 +323,7 @@ func TestDown(t *testing.T) {
|
||||
|
||||
t.Run("rollback all", func(t *testing.T) {
|
||||
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
||||
rolled, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
||||
rolled, err := sqlmigrate.Down(ctx, m, migrations, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -330,6 +338,14 @@ func TestDown(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("n=0 is error", func(t *testing.T) {
|
||||
m := &mockMigrator{applied: applied("001_init", "002_users")}
|
||||
_, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
||||
if !errors.Is(err, sqlmigrate.ErrInvalidN) {
|
||||
t.Errorf("got %v, want ErrInvalidN", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rollback n", func(t *testing.T) {
|
||||
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
||||
rolled, err := sqlmigrate.Down(ctx, m, migrations, 2)
|
||||
@ -343,7 +359,7 @@ func TestDown(t *testing.T) {
|
||||
|
||||
t.Run("none applied", func(t *testing.T) {
|
||||
m := &mockMigrator{}
|
||||
rolled, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
||||
rolled, err := sqlmigrate.Down(ctx, m, migrations, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -368,7 +384,7 @@ func TestDown(t *testing.T) {
|
||||
|
||||
t.Run("unknown migration in applied", func(t *testing.T) {
|
||||
m := &mockMigrator{applied: applied("001_init", "999_unknown")}
|
||||
_, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
||||
_, err := sqlmigrate.Down(ctx, m, migrations, -1)
|
||||
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
|
||||
t.Errorf("got %v, want ErrMissingDown", err)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user