mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 04:38:02 +00:00
fix(sqlmigrate): make n=0 an error in Up/Down, use -1 for "all"
Passing 0 to Up() or Down() is an easy mistake — it silently means "all" which could be destructive. Now n=0 returns ErrInvalidN. Convention: n > 0 for a specific count, n < 0 (typically -1) for all.
This commit is contained in:
parent
3547b7e409
commit
3e51c7b67a
@ -351,10 +351,10 @@ func main() {
|
|||||||
fmt.Println(" ", d)
|
fmt.Println(" ", d)
|
||||||
}
|
}
|
||||||
case "up":
|
case "up":
|
||||||
var upN int
|
upN := -1
|
||||||
switch len(leafArgs) {
|
switch len(leafArgs) {
|
||||||
case 0:
|
case 0:
|
||||||
// no arg: upN stays 0, meaning "all pending"
|
// no arg: apply all pending
|
||||||
case 1:
|
case 1:
|
||||||
upN, err = strconv.Atoi(leafArgs[0])
|
upN, err = strconv.Atoi(leafArgs[0])
|
||||||
if err != nil || upN < 1 {
|
if err != nil || upN < 1 {
|
||||||
|
|||||||
@ -14,9 +14,9 @@ Each backend is a separate Go module to avoid pulling unnecessary drivers:
|
|||||||
| [mymigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/mymigrate) | MySQL / MariaDB | go-sql-driver/mysql |
|
| [mymigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/mymigrate) | MySQL / MariaDB | go-sql-driver/mysql |
|
||||||
| [litemigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/litemigrate) | SQLite | database/sql (caller imports driver) |
|
| [litemigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/litemigrate) | SQLite | database/sql (caller imports driver) |
|
||||||
| [msmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/msmigrate) | SQL Server | go-mssqldb |
|
| [msmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/msmigrate) | SQL Server | go-mssqldb |
|
||||||
| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (generates POSIX sh) |
|
| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (uses native CLI) |
|
||||||
|
|
||||||
## CLI
|
## CLI
|
||||||
|
|
||||||
The [sql-migrate](https://pkg.go.dev/github.com/therootcompany/golib/cmd/sql-migrate/v2) CLI
|
The [sql-migrate](https://pkg.go.dev/github.com/therootcompany/golib/cmd/sql-migrate/v2) CLI
|
||||||
uses shmigrate to generate shell scripts for managing migrations without a Go dependency at runtime.
|
uses _shmigrate_ to generate shell scripts for managing migrations without a Go dependency at runtime.
|
||||||
|
|||||||
@ -25,6 +25,7 @@ var (
|
|||||||
ErrWalkFailed = errors.New("walking migrations")
|
ErrWalkFailed = errors.New("walking migrations")
|
||||||
ErrExecFailed = errors.New("migration exec failed")
|
ErrExecFailed = errors.New("migration exec failed")
|
||||||
ErrQueryApplied = errors.New("querying applied migrations")
|
ErrQueryApplied = errors.New("querying applied migrations")
|
||||||
|
ErrInvalidN = errors.New("n must be positive or -1 for all")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Migration represents a paired up/down migration.
|
// Migration represents a paired up/down migration.
|
||||||
@ -178,8 +179,13 @@ func findMigration(a AppliedMigration, byName map[string]Migration, byID map[str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Up applies up to n pending migrations using the given Runner.
|
// Up applies up to n pending migrations using the given Runner.
|
||||||
// If n <= 0, applies all pending. Returns the names of applied migrations.
|
// If n < 0, applies all pending. If n == 0, returns ErrInvalidN.
|
||||||
|
// Returns the names of applied migrations.
|
||||||
func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
|
func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
|
||||||
|
if n == 0 {
|
||||||
|
return nil, ErrInvalidN
|
||||||
|
}
|
||||||
|
|
||||||
applied, err := r.Applied(ctx)
|
applied, err := r.Applied(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -192,7 +198,7 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if n <= 0 {
|
if n < 0 {
|
||||||
n = len(pending)
|
n = len(pending)
|
||||||
}
|
}
|
||||||
if n > len(pending) {
|
if n > len(pending) {
|
||||||
@ -211,8 +217,13 @@ func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Down rolls back up to n applied migrations, most recent first.
|
// Down rolls back up to n applied migrations, most recent first.
|
||||||
// If n <= 0, rolls back all applied. Returns the names of rolled-back migrations.
|
// If n < 0, rolls back all applied. If n == 0, returns ErrInvalidN.
|
||||||
|
// Returns the names of rolled-back migrations.
|
||||||
func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
|
func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
|
||||||
|
if n == 0 {
|
||||||
|
return nil, ErrInvalidN
|
||||||
|
}
|
||||||
|
|
||||||
applied, err := r.Applied(ctx)
|
applied, err := r.Applied(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -231,7 +242,7 @@ func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]str
|
|||||||
copy(reversed, applied)
|
copy(reversed, applied)
|
||||||
slices.Reverse(reversed)
|
slices.Reverse(reversed)
|
||||||
|
|
||||||
if n <= 0 || n > len(reversed) {
|
if n < 0 || n > len(reversed) {
|
||||||
n = len(reversed)
|
n = len(reversed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -190,7 +190,7 @@ func TestUp(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("apply all", func(t *testing.T) {
|
t.Run("apply all", func(t *testing.T) {
|
||||||
m := &mockMigrator{}
|
m := &mockMigrator{}
|
||||||
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -199,6 +199,14 @@ func TestUp(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("n=0 is error", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{}
|
||||||
|
_, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
||||||
|
if !errors.Is(err, sqlmigrate.ErrInvalidN) {
|
||||||
|
t.Errorf("got %v, want ErrInvalidN", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("apply n", func(t *testing.T) {
|
t.Run("apply n", func(t *testing.T) {
|
||||||
m := &mockMigrator{}
|
m := &mockMigrator{}
|
||||||
ran, err := sqlmigrate.Up(ctx, m, migrations, 2)
|
ran, err := sqlmigrate.Up(ctx, m, migrations, 2)
|
||||||
@ -212,7 +220,7 @@ func TestUp(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("none pending", func(t *testing.T) {
|
t.Run("none pending", func(t *testing.T) {
|
||||||
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
||||||
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -223,7 +231,7 @@ func TestUp(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("partial pending", func(t *testing.T) {
|
t.Run("partial pending", func(t *testing.T) {
|
||||||
m := &mockMigrator{applied: applied("001_init")}
|
m := &mockMigrator{applied: applied("001_init")}
|
||||||
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -245,7 +253,7 @@ func TestUp(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("exec error stops and returns partial", func(t *testing.T) {
|
t.Run("exec error stops and returns partial", func(t *testing.T) {
|
||||||
m := &failOnNthMigrator{failAt: 1}
|
m := &failOnNthMigrator{failAt: 1}
|
||||||
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
ran, err := sqlmigrate.Up(ctx, m, migrations, -1)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error")
|
t.Fatal("expected error")
|
||||||
}
|
}
|
||||||
@ -263,7 +271,7 @@ func TestUp(t *testing.T) {
|
|||||||
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
|
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
|
||||||
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
|
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
|
||||||
}
|
}
|
||||||
ran, err := sqlmigrate.Up(ctx, m, migs, 0)
|
ran, err := sqlmigrate.Up(ctx, m, migs, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -315,7 +323,7 @@ func TestDown(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("rollback all", func(t *testing.T) {
|
t.Run("rollback all", func(t *testing.T) {
|
||||||
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
||||||
rolled, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
rolled, err := sqlmigrate.Down(ctx, m, migrations, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -330,6 +338,14 @@ func TestDown(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("n=0 is error", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{applied: applied("001_init", "002_users")}
|
||||||
|
_, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
||||||
|
if !errors.Is(err, sqlmigrate.ErrInvalidN) {
|
||||||
|
t.Errorf("got %v, want ErrInvalidN", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("rollback n", func(t *testing.T) {
|
t.Run("rollback n", func(t *testing.T) {
|
||||||
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
||||||
rolled, err := sqlmigrate.Down(ctx, m, migrations, 2)
|
rolled, err := sqlmigrate.Down(ctx, m, migrations, 2)
|
||||||
@ -343,7 +359,7 @@ func TestDown(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("none applied", func(t *testing.T) {
|
t.Run("none applied", func(t *testing.T) {
|
||||||
m := &mockMigrator{}
|
m := &mockMigrator{}
|
||||||
rolled, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
rolled, err := sqlmigrate.Down(ctx, m, migrations, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -368,7 +384,7 @@ func TestDown(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("unknown migration in applied", func(t *testing.T) {
|
t.Run("unknown migration in applied", func(t *testing.T) {
|
||||||
m := &mockMigrator{applied: applied("001_init", "999_unknown")}
|
m := &mockMigrator{applied: applied("001_init", "999_unknown")}
|
||||||
_, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
_, err := sqlmigrate.Down(ctx, m, migrations, -1)
|
||||||
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
|
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
|
||||||
t.Errorf("got %v, want ErrMissingDown", err)
|
t.Errorf("got %v, want ErrMissingDown", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user