mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 04:38:02 +00:00
ref(database/sqlmigrate): extract migration library with shmigrate backend
Factor the inline migration logic from cmd/sql-migrate into reusable packages: database/sqlmigrate (core types, matching, file collection) and database/sqlmigrate/shmigrate (shell script generation backend). No behavior changes — the CLI produces identical output. The shmigrate package implements the sqlmigrate.Migrator interface so other backends (pgmigrate, mymigrate, etc.) can follow the same pattern.
This commit is contained in:
parent
c4964a5b65
commit
3547b7e409
@ -1,3 +1,13 @@
|
|||||||
module github.com/therootcompany/golib/cmd/sql-migrate/v2
|
module github.com/therootcompany/golib/cmd/sql-migrate/v2
|
||||||
|
|
||||||
go 1.25.3
|
go 1.26.1
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/therootcompany/golib/database/sqlmigrate v0.0.0
|
||||||
|
github.com/therootcompany/golib/database/sqlmigrate/shmigrate v0.0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
replace (
|
||||||
|
github.com/therootcompany/golib/database/sqlmigrate => ../../database/sqlmigrate
|
||||||
|
github.com/therootcompany/golib/database/sqlmigrate/shmigrate => ../../database/sqlmigrate/shmigrate
|
||||||
|
)
|
||||||
|
|||||||
@ -13,6 +13,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"flag"
|
"flag"
|
||||||
@ -27,6 +28,9 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/therootcompany/golib/database/sqlmigrate"
|
||||||
|
"github.com/therootcompany/golib/database/sqlmigrate/shmigrate"
|
||||||
)
|
)
|
||||||
|
|
||||||
// replaced by goreleaser / ldflags
|
// replaced by goreleaser / ldflags
|
||||||
@ -67,14 +71,6 @@ DROP TABLE IF EXISTS _migrations;
|
|||||||
`
|
`
|
||||||
LOG_MIGRATIONS_QUERY = `-- note: CLI arguments must be passed to the sql command to keep output clean
|
LOG_MIGRATIONS_QUERY = `-- note: CLI arguments must be passed to the sql command to keep output clean
|
||||||
SELECT name FROM _migrations ORDER BY name;
|
SELECT name FROM _migrations ORDER BY name;
|
||||||
`
|
|
||||||
shHeader = `#!/bin/sh
|
|
||||||
set -e
|
|
||||||
set -u
|
|
||||||
|
|
||||||
if test -s ./.env; then
|
|
||||||
. ./.env
|
|
||||||
fi
|
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -297,14 +293,21 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
runner := &shmigrate.Migrator{
|
||||||
|
Writer: os.Stdout,
|
||||||
|
SqlCommand: state.SQLCommand,
|
||||||
|
MigrationsDir: state.MigrationsDir,
|
||||||
|
LogQueryPath: filepath.Join(state.MigrationsDir, LOG_QUERY_NAME),
|
||||||
|
LogPath: state.LogPath,
|
||||||
|
}
|
||||||
|
migrations := sqlmigrate.NamesOnly(ups)
|
||||||
|
|
||||||
switch subcmd {
|
switch subcmd {
|
||||||
case "init":
|
case "init":
|
||||||
break
|
break
|
||||||
case "sync":
|
case "sync":
|
||||||
if err := syncLog(&state); err != nil {
|
syncLog(runner)
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
case "create":
|
case "create":
|
||||||
if len(leafArgs) == 0 {
|
if len(leafArgs) == 0 {
|
||||||
log.Fatal("create requires a description")
|
log.Fatal("create requires a description")
|
||||||
@ -324,8 +327,7 @@ func main() {
|
|||||||
fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " "))
|
fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " "))
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
err = status(&state, ups)
|
if err := cmdStatus(ctx, &state, runner, migrations); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
case "list":
|
case "list":
|
||||||
@ -364,12 +366,11 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = up(&state, ups, upN)
|
if err := cmdUp(ctx, &state, runner, migrations, upN); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
case "down":
|
case "down":
|
||||||
var downN int
|
downN := 1
|
||||||
switch len(leafArgs) {
|
switch len(leafArgs) {
|
||||||
case 0:
|
case 0:
|
||||||
// default: roll back one
|
// default: roll back one
|
||||||
@ -384,8 +385,7 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = down(&state, downN)
|
if err := cmdDown(ctx, &state, runner, migrations, downN); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -875,190 +875,171 @@ func fixupMigration(dir string, basename string) (up, down bool, warn error, err
|
|||||||
return up, down, nil, nil
|
return up, down, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func syncLog(state *State) error {
|
func syncLog(runner *shmigrate.Migrator) {
|
||||||
getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME)
|
syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1)
|
||||||
getMigsPath = filepathUnclean(getMigsPath)
|
logPath := filepathUnclean(runner.LogPath)
|
||||||
getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1)
|
|
||||||
logPath := filepathUnclean(state.LogPath)
|
|
||||||
|
|
||||||
fmt.Printf(shHeader)
|
fmt.Printf(shmigrate.ShHeader)
|
||||||
fmt.Println("")
|
fmt.Println("")
|
||||||
fmt.Println("# SYNC: reload migrations log from DB")
|
fmt.Println("# SYNC: reload migrations log from DB")
|
||||||
fmt.Printf("%s > %s || true\n", getMigs, logPath)
|
fmt.Printf("%s > %s || true\n", syncCmd, logPath)
|
||||||
fmt.Printf("cat %s\n", logPath)
|
fmt.Printf("cat %s\n", logPath)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func up(state *State, ups []string, n int) error {
|
func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error {
|
||||||
var pending []string
|
// fixup pending migrations before generating the script
|
||||||
for _, mig := range ups {
|
fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations)
|
||||||
found := slices.Contains(state.Migrated, mig)
|
|
||||||
if !found {
|
status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
|
||||||
pending = append(pending, mig)
|
if err != nil {
|
||||||
}
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME)
|
if len(status.Pending) == 0 {
|
||||||
getMigsPath = filepathUnclean(getMigsPath)
|
syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1)
|
||||||
getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1)
|
|
||||||
|
|
||||||
if len(pending) == 0 {
|
|
||||||
fmt.Fprintf(os.Stderr, "# Already up-to-date\n")
|
fmt.Fprintf(os.Stderr, "# Already up-to-date\n")
|
||||||
fmt.Fprintf(os.Stderr, "#\n")
|
fmt.Fprintf(os.Stderr, "#\n")
|
||||||
fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n")
|
fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n")
|
||||||
fmt.Fprintf(os.Stderr, "# %s > %s\n", getMigs, filepathUnclean(state.LogPath))
|
fmt.Fprintf(os.Stderr, "# %s > %s\n", syncCmd, filepathUnclean(runner.LogPath))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if n == 0 {
|
|
||||||
n = len(pending)
|
|
||||||
}
|
|
||||||
|
|
||||||
fixedUp := []string{}
|
fmt.Printf(shmigrate.ShHeader)
|
||||||
fixedDown := []string{}
|
|
||||||
|
|
||||||
fmt.Printf(shHeader)
|
|
||||||
fmt.Println("")
|
fmt.Println("")
|
||||||
fmt.Println("# FORWARD / UP Migrations")
|
fmt.Println("# FORWARD / UP Migrations")
|
||||||
fmt.Println("")
|
fmt.Println("")
|
||||||
for i, migration := range pending {
|
|
||||||
if i >= n {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
path := filepath.Join(state.MigrationsDir, migration+".up.sql")
|
applied, err := sqlmigrate.Up(ctx, runner, migrations, n)
|
||||||
path = filepathUnclean(path)
|
if err != nil {
|
||||||
{
|
return err
|
||||||
up, down, warn, err := fixupMigration(state.MigrationsDir, migration)
|
|
||||||
if warn != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
|
||||||
}
|
|
||||||
if up {
|
|
||||||
fixedUp = append(fixedUp, migration)
|
|
||||||
}
|
|
||||||
if down {
|
|
||||||
fixedDown = append(fixedDown, migration)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cmd := strings.Replace(state.SQLCommand, "%s", path, 1)
|
|
||||||
fmt.Printf("# +%d %s\n", i+1, migration)
|
|
||||||
fmt.Println(cmd)
|
|
||||||
fmt.Println(getMigs + " > " + filepathUnclean(state.LogPath))
|
|
||||||
fmt.Println("")
|
|
||||||
}
|
}
|
||||||
fmt.Println("cat", filepathUnclean(state.LogPath))
|
_ = applied
|
||||||
|
|
||||||
|
fmt.Println("cat", filepathUnclean(runner.LogPath))
|
||||||
|
|
||||||
showFixes(fixedUp, fixedDown)
|
showFixes(fixedUp, fixedDown)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func down(state *State, n int) error {
|
func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error {
|
||||||
lines := make([]string, len(state.Lines))
|
// fixup applied migrations before generating the script
|
||||||
copy(lines, state.Lines)
|
fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations)
|
||||||
slices.Reverse(lines)
|
|
||||||
|
|
||||||
getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME)
|
status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
|
||||||
getMigsPath = filepathUnclean(getMigsPath)
|
if err != nil {
|
||||||
getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1)
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if len(lines) == 0 {
|
if len(status.Applied) == 0 {
|
||||||
|
syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1)
|
||||||
fmt.Fprintf(os.Stderr, "# No migration history\n")
|
fmt.Fprintf(os.Stderr, "# No migration history\n")
|
||||||
fmt.Fprintf(os.Stderr, "#\n")
|
fmt.Fprintf(os.Stderr, "#\n")
|
||||||
fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n")
|
fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n")
|
||||||
fmt.Fprintf(os.Stderr, "# %s > %s\n", getMigs, filepathUnclean(state.LogPath))
|
fmt.Fprintf(os.Stderr, "# %s > %s\n", syncCmd, filepathUnclean(runner.LogPath))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if n == 0 {
|
|
||||||
n = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
fixedUp := []string{}
|
fmt.Printf(shmigrate.ShHeader)
|
||||||
fixedDown := []string{}
|
|
||||||
|
|
||||||
var applied []string
|
|
||||||
for _, line := range lines {
|
|
||||||
migration := commentStartRe.ReplaceAllString(line, "")
|
|
||||||
migration = strings.TrimSpace(migration)
|
|
||||||
if migration == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
applied = append(applied, migration)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf(shHeader)
|
|
||||||
fmt.Println("")
|
fmt.Println("")
|
||||||
fmt.Println("# ROLLBACK / DOWN Migration")
|
fmt.Println("# ROLLBACK / DOWN Migration")
|
||||||
fmt.Println("")
|
fmt.Println("")
|
||||||
for i, migration := range applied {
|
|
||||||
if i >= n {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
downPath := filepath.Join(state.MigrationsDir, migration+".down.sql")
|
// check for missing down files before generating script
|
||||||
cmd := strings.Replace(state.SQLCommand, "%s", downPath, 1)
|
reversed := make([]string, len(status.Applied))
|
||||||
fmt.Printf("\n# -%d %s\n", i+1, migration)
|
copy(reversed, status.Applied)
|
||||||
|
slices.Reverse(reversed)
|
||||||
|
limit := n
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 1
|
||||||
|
}
|
||||||
|
if limit > len(reversed) {
|
||||||
|
limit = len(reversed)
|
||||||
|
}
|
||||||
|
for _, name := range reversed[:limit] {
|
||||||
|
downPath := filepath.Join(state.MigrationsDir, name+".down.sql")
|
||||||
if !fileExists(downPath) {
|
if !fileExists(downPath) {
|
||||||
fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath))
|
fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath))
|
||||||
fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n")
|
fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n")
|
||||||
fmt.Printf("# ERROR: MISSING FILE\n")
|
|
||||||
} else {
|
|
||||||
up, down, warn, err := fixupMigration(state.MigrationsDir, migration)
|
|
||||||
if warn != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
|
||||||
}
|
|
||||||
if up {
|
|
||||||
fixedUp = append(fixedUp, migration)
|
|
||||||
}
|
|
||||||
if down {
|
|
||||||
fixedDown = append(fixedDown, migration)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(cmd)
|
|
||||||
fmt.Println(getMigs + " > " + filepathUnclean(state.LogPath))
|
|
||||||
fmt.Println("")
|
|
||||||
}
|
}
|
||||||
fmt.Println("cat", filepathUnclean(state.LogPath))
|
|
||||||
|
rolled, err := sqlmigrate.Down(ctx, runner, migrations, n)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = rolled
|
||||||
|
|
||||||
|
fmt.Println("cat", filepathUnclean(runner.LogPath))
|
||||||
|
|
||||||
showFixes(fixedUp, fixedDown)
|
showFixes(fixedUp, fixedDown)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func status(state *State, ups []string) error {
|
func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration) error {
|
||||||
previous := make([]string, len(state.Lines))
|
status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
|
||||||
copy(previous, state.Lines)
|
if err != nil {
|
||||||
slices.Reverse(previous)
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "migrations_dir: %s\n", filepathUnclean(state.MigrationsDir))
|
fmt.Fprintf(os.Stderr, "migrations_dir: %s\n", filepathUnclean(state.MigrationsDir))
|
||||||
fmt.Fprintf(os.Stderr, "migrations_log: %s\n", filepathUnclean(state.LogPath))
|
fmt.Fprintf(os.Stderr, "migrations_log: %s\n", filepathUnclean(state.LogPath))
|
||||||
fmt.Fprintf(os.Stderr, "sql_command: %s\n", state.SQLCommand)
|
fmt.Fprintf(os.Stderr, "sql_command: %s\n", state.SQLCommand)
|
||||||
fmt.Fprintf(os.Stderr, "\n")
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
fmt.Printf("# previous: %d\n", len(previous))
|
|
||||||
for _, mig := range previous {
|
// show applied in reverse (most recent first)
|
||||||
|
applied := make([]string, len(status.Applied))
|
||||||
|
copy(applied, status.Applied)
|
||||||
|
slices.Reverse(applied)
|
||||||
|
|
||||||
|
fmt.Printf("# previous: %d\n", len(applied))
|
||||||
|
for _, mig := range applied {
|
||||||
fmt.Printf(" %s\n", mig)
|
fmt.Printf(" %s\n", mig)
|
||||||
}
|
}
|
||||||
if len(previous) == 0 {
|
if len(applied) == 0 {
|
||||||
fmt.Println(" # (no previous migrations)")
|
fmt.Println(" # (no previous migrations)")
|
||||||
}
|
}
|
||||||
fmt.Println("")
|
fmt.Println("")
|
||||||
var pending []string
|
fmt.Printf("# pending: %d\n", len(status.Pending))
|
||||||
for _, mig := range ups {
|
for _, mig := range status.Pending {
|
||||||
found := slices.Contains(state.Migrated, mig)
|
|
||||||
if !found {
|
|
||||||
pending = append(pending, mig)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Printf("# pending: %d\n", len(pending))
|
|
||||||
for _, mig := range pending {
|
|
||||||
fmt.Printf(" %s\n", mig)
|
fmt.Printf(" %s\n", mig)
|
||||||
}
|
}
|
||||||
if len(pending) == 0 {
|
if len(status.Pending) == 0 {
|
||||||
fmt.Println(" # (no pending migrations)")
|
fmt.Println(" # (no pending migrations)")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fixupAll runs fixupMigration on all known migrations (applied + pending).
|
||||||
|
func fixupAll(migrationsDir string, applied []string, migrations []sqlmigrate.Migration) (fixedUp, fixedDown []string) {
|
||||||
|
seen := map[string]bool{}
|
||||||
|
var all []string
|
||||||
|
for _, name := range applied {
|
||||||
|
if !seen[name] {
|
||||||
|
all = append(all, name)
|
||||||
|
seen[name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range migrations {
|
||||||
|
if !seen[m.Name] {
|
||||||
|
all = append(all, m.Name)
|
||||||
|
seen[m.Name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range all {
|
||||||
|
up, down, warn, err := fixupMigration(migrationsDir, name)
|
||||||
|
if warn != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||||
|
}
|
||||||
|
if up {
|
||||||
|
fixedUp = append(fixedUp, name)
|
||||||
|
}
|
||||||
|
if down {
|
||||||
|
fixedDown = append(fixedDown, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fixedUp, fixedDown
|
||||||
|
}
|
||||||
|
|||||||
22
database/sqlmigrate/README.md
Normal file
22
database/sqlmigrate/README.md
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# sqlmigrate
|
||||||
|
|
||||||
|
Database-agnostic SQL migration library for Go.
|
||||||
|
|
||||||
|
[](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate)
|
||||||
|
|
||||||
|
## Backend packages
|
||||||
|
|
||||||
|
Each backend is a separate Go module to avoid pulling unnecessary drivers:
|
||||||
|
|
||||||
|
| Package | Database | Driver |
|
||||||
|
|---------|----------|--------|
|
||||||
|
| [pgmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/pgmigrate) | PostgreSQL | pgx/v5 |
|
||||||
|
| [mymigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/mymigrate) | MySQL / MariaDB | go-sql-driver/mysql |
|
||||||
|
| [litemigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/litemigrate) | SQLite | database/sql (caller imports driver) |
|
||||||
|
| [msmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/msmigrate) | SQL Server | go-mssqldb |
|
||||||
|
| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (generates POSIX sh) |
|
||||||
|
|
||||||
|
## CLI
|
||||||
|
|
||||||
|
The [sql-migrate](https://pkg.go.dev/github.com/therootcompany/golib/cmd/sql-migrate/v2) CLI
|
||||||
|
uses shmigrate to generate shell scripts for managing migrations without a Go dependency at runtime.
|
||||||
3
database/sqlmigrate/go.mod
Normal file
3
database/sqlmigrate/go.mod
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
module github.com/therootcompany/golib/database/sqlmigrate
|
||||||
|
|
||||||
|
go 1.26.1
|
||||||
7
database/sqlmigrate/shmigrate/go.mod
Normal file
7
database/sqlmigrate/shmigrate/go.mod
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
module github.com/therootcompany/golib/database/sqlmigrate/shmigrate
|
||||||
|
|
||||||
|
go 1.26.1
|
||||||
|
|
||||||
|
require github.com/therootcompany/golib/database/sqlmigrate v0.0.0
|
||||||
|
|
||||||
|
replace github.com/therootcompany/golib/database/sqlmigrate => ../
|
||||||
150
database/sqlmigrate/shmigrate/shmigrate.go
Normal file
150
database/sqlmigrate/shmigrate/shmigrate.go
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
// Package shmigrate implements sqlmigrate.Runner by generating POSIX
|
||||||
|
// shell commands that reference migration files on disk. It is used by
|
||||||
|
// the sql-migrate CLI to produce scripts that can be piped to sh.
|
||||||
|
package shmigrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/therootcompany/golib/database/sqlmigrate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShHeader is the standard header for generated shell scripts.
|
||||||
|
const ShHeader = `#!/bin/sh
|
||||||
|
set -e
|
||||||
|
set -u
|
||||||
|
|
||||||
|
if test -s ./.env; then
|
||||||
|
. ./.env
|
||||||
|
fi
|
||||||
|
`
|
||||||
|
|
||||||
|
// Migrator generates shell scripts for migration execution.
|
||||||
|
// It implements sqlmigrate.Migrator so it can be used with
|
||||||
|
// sqlmigrate.Up, sqlmigrate.Down, and sqlmigrate.GetStatus.
|
||||||
|
type Migrator struct {
|
||||||
|
// Writer receives the generated shell script.
|
||||||
|
Writer io.Writer
|
||||||
|
|
||||||
|
// SqlCommand is the shell command template with %s for the file path.
|
||||||
|
// Example: `psql "$PG_URL" -v ON_ERROR_STOP=on -A -t --file %s`
|
||||||
|
SqlCommand string
|
||||||
|
|
||||||
|
// MigrationsDir is the path to the migrations directory on disk.
|
||||||
|
MigrationsDir string
|
||||||
|
|
||||||
|
// LogQueryPath is the path to the _migrations.sql query file.
|
||||||
|
// Used to sync the migrations log after each migration.
|
||||||
|
LogQueryPath string
|
||||||
|
|
||||||
|
// LogPath is the path to the migrations.log file.
|
||||||
|
LogPath string
|
||||||
|
|
||||||
|
// FS is an optional filesystem for reading the migrations log.
|
||||||
|
// When nil, the OS filesystem is used.
|
||||||
|
FS fs.FS
|
||||||
|
|
||||||
|
counter int
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify interface compliance at compile time
|
||||||
|
var _ sqlmigrate.Migrator = (*Migrator)(nil)
|
||||||
|
|
||||||
|
// ExecUp outputs a shell command to run the .up.sql migration file.
|
||||||
|
func (r *Migrator) ExecUp(ctx context.Context, m sqlmigrate.Migration) error {
|
||||||
|
r.counter++
|
||||||
|
return r.exec(m.Name, ".up.sql", fmt.Sprintf("+%d", r.counter))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecDown outputs a shell command to run the .down.sql migration file.
|
||||||
|
func (r *Migrator) ExecDown(ctx context.Context, m sqlmigrate.Migration) error {
|
||||||
|
r.counter++
|
||||||
|
return r.exec(m.Name, ".down.sql", fmt.Sprintf("-%d", r.counter))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Migrator) exec(name, suffix, label string) error {
|
||||||
|
path := unclean(filepath.Join(r.MigrationsDir, name+suffix))
|
||||||
|
cmd := strings.Replace(r.SqlCommand, "%s", path, 1)
|
||||||
|
|
||||||
|
syncCmd := strings.Replace(r.SqlCommand, "%s", unclean(r.LogQueryPath), 1)
|
||||||
|
logPath := unclean(r.LogPath)
|
||||||
|
|
||||||
|
fmt.Fprintf(r.Writer, "# %s %s\n", label, name)
|
||||||
|
if _, err := os.Stat(filepath.Join(r.MigrationsDir, name+suffix)); err != nil {
|
||||||
|
fmt.Fprintln(r.Writer, "# ERROR: MISSING FILE")
|
||||||
|
}
|
||||||
|
fmt.Fprintln(r.Writer, cmd)
|
||||||
|
fmt.Fprintf(r.Writer, "%s > %s\n", syncCmd, logPath)
|
||||||
|
fmt.Fprintln(r.Writer)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Applied reads the migrations log file and returns applied migrations.
|
||||||
|
// Supports two formats:
|
||||||
|
// - New: "id\tname" (tab-separated, written by updated _migrations.sql)
|
||||||
|
// - Old: "name" (name only, for backwards compatibility)
|
||||||
|
//
|
||||||
|
// Returns an empty slice if the file does not exist. When FS is set, reads
|
||||||
|
// from that filesystem; otherwise reads from the OS filesystem.
|
||||||
|
func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration, error) {
|
||||||
|
var f io.ReadCloser
|
||||||
|
var err error
|
||||||
|
if r.FS != nil {
|
||||||
|
f, err = r.FS.Open(r.LogPath)
|
||||||
|
} else {
|
||||||
|
f, err = os.Open(r.LogPath)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("reading migrations log: %w", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
var applied []sqlmigrate.AppliedMigration
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
// strip inline comments
|
||||||
|
if idx := strings.Index(line, "#"); idx >= 0 {
|
||||||
|
line = strings.TrimSpace(line[:idx])
|
||||||
|
}
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if id, name, ok := strings.Cut(line, "\t"); ok {
|
||||||
|
applied = append(applied, sqlmigrate.AppliedMigration{ID: id, Name: name})
|
||||||
|
} else {
|
||||||
|
applied = append(applied, sqlmigrate.AppliedMigration{Name: line})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return applied, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset resets the migration counter. Call between Up and Down
|
||||||
|
// if generating both in the same script.
|
||||||
|
func (r *Migrator) Reset() {
|
||||||
|
r.counter = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// unclean ensures a relative path starts with ./ or ../ so it
|
||||||
|
// is not interpreted as a command name in shell scripts.
|
||||||
|
func unclean(path string) string {
|
||||||
|
if strings.HasPrefix(path, "/") {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(path, "./") || strings.HasPrefix(path, "../") {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
return "./" + path
|
||||||
|
}
|
||||||
276
database/sqlmigrate/sqlmigrate.go
Normal file
276
database/sqlmigrate/sqlmigrate.go
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
// Package sqlmigrate provides a database-agnostic SQL migration interface.
|
||||||
|
//
|
||||||
|
// Backend implementations (each a separate Go module):
|
||||||
|
// - pgmigrate: PostgreSQL via pgx/v5
|
||||||
|
// - mymigrate: MySQL/MariaDB via go-sql-driver/mysql
|
||||||
|
// - litemigrate: SQLite via database/sql
|
||||||
|
// - msmigrate: SQL Server via go-mssqldb
|
||||||
|
// - shmigrate: POSIX shell script generation
|
||||||
|
package sqlmigrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"regexp"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sentinel errors for migration operations.
|
||||||
|
var (
|
||||||
|
ErrMissingUp = errors.New("missing up migration")
|
||||||
|
ErrMissingDown = errors.New("missing down migration")
|
||||||
|
ErrWalkFailed = errors.New("walking migrations")
|
||||||
|
ErrExecFailed = errors.New("migration exec failed")
|
||||||
|
ErrQueryApplied = errors.New("querying applied migrations")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Migration represents a paired up/down migration.
|
||||||
|
type Migration struct {
|
||||||
|
Name string // e.g. "2026-04-05-001000_create-todos"
|
||||||
|
ID string // 8-char hex from INSERT INTO _migrations, parsed by Collect
|
||||||
|
Up string // SQL content of the .up.sql file
|
||||||
|
Down string // SQL content of the .down.sql file
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppliedMigration represents a migration recorded in the _migrations table.
|
||||||
|
type AppliedMigration struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status represents the current migration state.
|
||||||
|
type Status struct {
|
||||||
|
Applied []string
|
||||||
|
Pending []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Migrator executes migrations. Implementations handle the
|
||||||
|
// database-specific or output-specific details.
|
||||||
|
type Migrator interface {
|
||||||
|
// ExecUp runs the up migration. For database migrators this executes
|
||||||
|
// m.Up in a transaction. For shell migrators this outputs a command
|
||||||
|
// referencing the .up.sql file.
|
||||||
|
ExecUp(ctx context.Context, m Migration) error
|
||||||
|
|
||||||
|
// ExecDown runs the down migration.
|
||||||
|
ExecDown(ctx context.Context, m Migration) error
|
||||||
|
|
||||||
|
// Applied returns all applied migrations from the _migrations table,
|
||||||
|
// sorted lexicographically by name. Returns an empty slice (not an
|
||||||
|
// error) if the migrations table or log does not exist yet.
|
||||||
|
Applied(ctx context.Context) ([]AppliedMigration, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// idFromInsert extracts the hex ID from an INSERT INTO _migrations line.
|
||||||
|
// Matches: INSERT INTO _migrations (name, id) VALUES ('...', '<hex>');
|
||||||
|
var idFromInsert = regexp.MustCompile(
|
||||||
|
`(?i)INSERT\s+INTO\s+_migrations\s*\(\s*name\s*,\s*id\s*\)\s*VALUES\s*\(\s*'[^']*'\s*,\s*'([0-9a-fA-F]+)'\s*\)`,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Collect reads .up.sql and .down.sql files from fsys, pairs them by
|
||||||
|
// basename, and returns them sorted lexicographically by name.
|
||||||
|
// If the up SQL contains an INSERT INTO _migrations line, the hex ID
|
||||||
|
// is extracted and stored in Migration.ID.
|
||||||
|
func Collect(fsys fs.FS) ([]Migration, error) {
|
||||||
|
ups := map[string]string{}
|
||||||
|
downs := map[string]string{}
|
||||||
|
|
||||||
|
err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if d.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := d.Name()
|
||||||
|
if base, ok := strings.CutSuffix(name, ".up.sql"); ok {
|
||||||
|
b, readErr := fs.ReadFile(fsys, path)
|
||||||
|
if readErr != nil {
|
||||||
|
return readErr
|
||||||
|
}
|
||||||
|
ups[base] = string(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if base, ok := strings.CutSuffix(name, ".down.sql"); ok {
|
||||||
|
b, readErr := fs.ReadFile(fsys, path)
|
||||||
|
if readErr != nil {
|
||||||
|
return readErr
|
||||||
|
}
|
||||||
|
downs[base] = string(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %w", ErrWalkFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var migrations []Migration
|
||||||
|
for name, upSQL := range ups {
|
||||||
|
downSQL, ok := downs[name]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrMissingDown, name)
|
||||||
|
}
|
||||||
|
var id string
|
||||||
|
if m := idFromInsert.FindStringSubmatch(upSQL); m != nil {
|
||||||
|
id = m[1]
|
||||||
|
}
|
||||||
|
migrations = append(migrations, Migration{
|
||||||
|
Name: name,
|
||||||
|
ID: id,
|
||||||
|
Up: upSQL,
|
||||||
|
Down: downSQL,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for name := range downs {
|
||||||
|
if _, ok := ups[name]; !ok {
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrMissingUp, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(migrations, func(a, b Migration) int {
|
||||||
|
return strings.Compare(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
return migrations, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NamesOnly builds a Migration slice from a list of names, with empty
|
||||||
|
// Up/Down content. Useful for shell-based runners that reference files
|
||||||
|
// on disk rather than executing SQL directly.
|
||||||
|
func NamesOnly(names []string) []Migration {
|
||||||
|
migrations := make([]Migration, len(names))
|
||||||
|
for i, name := range names {
|
||||||
|
migrations[i] = Migration{Name: name}
|
||||||
|
}
|
||||||
|
return migrations
|
||||||
|
}
|
||||||
|
|
||||||
|
// isApplied returns true if the migration matches any applied entry by name or ID.
|
||||||
|
func isApplied(m Migration, applied []AppliedMigration) bool {
|
||||||
|
for _, a := range applied {
|
||||||
|
if a.Name == m.Name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if m.ID != "" && a.ID != "" && a.ID == m.ID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// findMigration looks up a migration by the applied entry's name or ID.
|
||||||
|
func findMigration(a AppliedMigration, byName map[string]Migration, byID map[string]Migration) (Migration, bool) {
|
||||||
|
if m, ok := byName[a.Name]; ok {
|
||||||
|
return m, true
|
||||||
|
}
|
||||||
|
if a.ID != "" {
|
||||||
|
if m, ok := byID[a.ID]; ok {
|
||||||
|
return m, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Migration{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Up applies up to n pending migrations using the given Runner.
|
||||||
|
// If n <= 0, applies all pending. Returns the names of applied migrations.
|
||||||
|
func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
|
||||||
|
applied, err := r.Applied(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var pending []Migration
|
||||||
|
for _, m := range migrations {
|
||||||
|
if !isApplied(m, applied) {
|
||||||
|
pending = append(pending, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if n <= 0 {
|
||||||
|
n = len(pending)
|
||||||
|
}
|
||||||
|
if n > len(pending) {
|
||||||
|
n = len(pending)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ran []string
|
||||||
|
for _, m := range pending[:n] {
|
||||||
|
if err := r.ExecUp(ctx, m); err != nil {
|
||||||
|
return ran, fmt.Errorf("%s (up): %w", m.Name, err)
|
||||||
|
}
|
||||||
|
ran = append(ran, m.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ran, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Down rolls back up to n applied migrations, most recent first.
|
||||||
|
// If n <= 0, rolls back all applied. Returns the names of rolled-back migrations.
|
||||||
|
func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
|
||||||
|
applied, err := r.Applied(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
byName := map[string]Migration{}
|
||||||
|
byID := map[string]Migration{}
|
||||||
|
for _, m := range migrations {
|
||||||
|
byName[m.Name] = m
|
||||||
|
if m.ID != "" {
|
||||||
|
byID[m.ID] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reversed := make([]AppliedMigration, len(applied))
|
||||||
|
copy(reversed, applied)
|
||||||
|
slices.Reverse(reversed)
|
||||||
|
|
||||||
|
if n <= 0 || n > len(reversed) {
|
||||||
|
n = len(reversed)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ran []string
|
||||||
|
for _, a := range reversed[:n] {
|
||||||
|
m, ok := findMigration(a, byName, byID)
|
||||||
|
if !ok {
|
||||||
|
return ran, fmt.Errorf("%w: %s", ErrMissingDown, a.Name)
|
||||||
|
}
|
||||||
|
if err := r.ExecDown(ctx, m); err != nil {
|
||||||
|
return ran, fmt.Errorf("%s (down): %w", a.Name, err)
|
||||||
|
}
|
||||||
|
ran = append(ran, a.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ran, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStatus returns applied and pending migration lists.
|
||||||
|
func GetStatus(ctx context.Context, r Migrator, migrations []Migration) (*Status, error) {
|
||||||
|
applied, err := r.Applied(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
appliedNames := make([]string, len(applied))
|
||||||
|
for i, a := range applied {
|
||||||
|
appliedNames[i] = a.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
var pending []string
|
||||||
|
for _, m := range migrations {
|
||||||
|
if !isApplied(m, applied) {
|
||||||
|
pending = append(pending, m.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Status{
|
||||||
|
Applied: appliedNames,
|
||||||
|
Pending: pending,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
457
database/sqlmigrate/sqlmigrate_test.go
Normal file
457
database/sqlmigrate/sqlmigrate_test.go
Normal file
@ -0,0 +1,457 @@
|
|||||||
|
package sqlmigrate_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
|
||||||
|
"github.com/therootcompany/golib/database/sqlmigrate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// applied builds an []AppliedMigration from names (IDs empty).
|
||||||
|
func applied(names ...string) []sqlmigrate.AppliedMigration {
|
||||||
|
out := make([]sqlmigrate.AppliedMigration, len(names))
|
||||||
|
for i, n := range names {
|
||||||
|
out[i] = sqlmigrate.AppliedMigration{Name: n}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockMigrator tracks applied migrations in memory.
|
||||||
|
type mockMigrator struct {
|
||||||
|
applied []sqlmigrate.AppliedMigration
|
||||||
|
execErr error // if set, ExecUp/ExecDown return this on every call
|
||||||
|
upCalls []string
|
||||||
|
downCalls []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error {
|
||||||
|
m.upCalls = append(m.upCalls, mig.Name)
|
||||||
|
if m.execErr != nil {
|
||||||
|
return m.execErr
|
||||||
|
}
|
||||||
|
m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID})
|
||||||
|
slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int {
|
||||||
|
return strings.Compare(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error {
|
||||||
|
m.downCalls = append(m.downCalls, mig.Name)
|
||||||
|
if m.execErr != nil {
|
||||||
|
return m.execErr
|
||||||
|
}
|
||||||
|
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name })
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) {
|
||||||
|
return slices.Clone(m.applied), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Collect ---
|
||||||
|
|
||||||
|
func TestCollect(t *testing.T) {
|
||||||
|
t.Run("pairs and sorts", func(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"002_second.up.sql": {Data: []byte("CREATE TABLE b;")},
|
||||||
|
"002_second.down.sql": {Data: []byte("DROP TABLE b;")},
|
||||||
|
"001_first.up.sql": {Data: []byte("CREATE TABLE a;")},
|
||||||
|
"001_first.down.sql": {Data: []byte("DROP TABLE a;")},
|
||||||
|
}
|
||||||
|
migrations, err := sqlmigrate.Collect(fsys)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(migrations) != 2 {
|
||||||
|
t.Fatalf("got %d migrations, want 2", len(migrations))
|
||||||
|
}
|
||||||
|
if migrations[0].Name != "001_first" {
|
||||||
|
t.Errorf("first = %q, want %q", migrations[0].Name, "001_first")
|
||||||
|
}
|
||||||
|
if migrations[1].Name != "002_second" {
|
||||||
|
t.Errorf("second = %q, want %q", migrations[1].Name, "002_second")
|
||||||
|
}
|
||||||
|
if migrations[0].Up != "CREATE TABLE a;" {
|
||||||
|
t.Errorf("first.Up = %q", migrations[0].Up)
|
||||||
|
}
|
||||||
|
if migrations[0].Down != "DROP TABLE a;" {
|
||||||
|
t.Errorf("first.Down = %q", migrations[0].Down)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("parses ID from INSERT", func(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"001_init.up.sql": {Data: []byte("CREATE TABLE a;\nINSERT INTO _migrations (name, id) VALUES ('001_init', 'abcd1234');")},
|
||||||
|
"001_init.down.sql": {Data: []byte("DROP TABLE a;\nDELETE FROM _migrations WHERE id = 'abcd1234';")},
|
||||||
|
}
|
||||||
|
migrations, err := sqlmigrate.Collect(fsys)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if migrations[0].ID != "abcd1234" {
|
||||||
|
t.Errorf("ID = %q, want %q", migrations[0].ID, "abcd1234")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no ID when no INSERT", func(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"001_init.up.sql": {Data: []byte("CREATE TABLE a;")},
|
||||||
|
"001_init.down.sql": {Data: []byte("DROP TABLE a;")},
|
||||||
|
}
|
||||||
|
migrations, err := sqlmigrate.Collect(fsys)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if migrations[0].ID != "" {
|
||||||
|
t.Errorf("ID = %q, want empty", migrations[0].ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing down", func(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"001_only-up.up.sql": {Data: []byte("CREATE TABLE x;")},
|
||||||
|
}
|
||||||
|
_, err := sqlmigrate.Collect(fsys)
|
||||||
|
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
|
||||||
|
t.Errorf("got %v, want ErrMissingDown", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing up", func(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"001_only-down.down.sql": {Data: []byte("DROP TABLE x;")},
|
||||||
|
}
|
||||||
|
_, err := sqlmigrate.Collect(fsys)
|
||||||
|
if !errors.Is(err, sqlmigrate.ErrMissingUp) {
|
||||||
|
t.Errorf("got %v, want ErrMissingUp", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ignores non-sql files", func(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"001_init.up.sql": {Data: []byte("UP")},
|
||||||
|
"001_init.down.sql": {Data: []byte("DOWN")},
|
||||||
|
"README.md": {Data: []byte("# Migrations")},
|
||||||
|
"_migrations.sql": {Data: []byte("SELECT name FROM _migrations;")},
|
||||||
|
}
|
||||||
|
migrations, err := sqlmigrate.Collect(fsys)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(migrations) != 1 {
|
||||||
|
t.Fatalf("got %d migrations, want 1", len(migrations))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty fs", func(t *testing.T) {
|
||||||
|
fsys := fstest.MapFS{}
|
||||||
|
migrations, err := sqlmigrate.Collect(fsys)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(migrations) != 0 {
|
||||||
|
t.Fatalf("got %d migrations, want 0", len(migrations))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- NamesOnly ---
|
||||||
|
|
||||||
|
func TestNamesOnly(t *testing.T) {
|
||||||
|
names := []string{"001_init", "002_users"}
|
||||||
|
migrations := sqlmigrate.NamesOnly(names)
|
||||||
|
if len(migrations) != 2 {
|
||||||
|
t.Fatalf("got %d, want 2", len(migrations))
|
||||||
|
}
|
||||||
|
for i, m := range migrations {
|
||||||
|
if m.Name != names[i] {
|
||||||
|
t.Errorf("[%d].Name = %q, want %q", i, m.Name, names[i])
|
||||||
|
}
|
||||||
|
if m.Up != "" || m.Down != "" {
|
||||||
|
t.Errorf("[%d] has non-empty content", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Up ---
|
||||||
|
|
||||||
|
func TestUp(t *testing.T) {
|
||||||
|
ctx := t.Context()
|
||||||
|
migrations := []sqlmigrate.Migration{
|
||||||
|
{Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
|
||||||
|
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
|
||||||
|
{Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("apply all", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{}
|
||||||
|
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !slices.Equal(ran, []string{"001_init", "002_users", "003_posts"}) {
|
||||||
|
t.Errorf("applied = %v", ran)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("apply n", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{}
|
||||||
|
ran, err := sqlmigrate.Up(ctx, m, migrations, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !slices.Equal(ran, []string{"001_init", "002_users"}) {
|
||||||
|
t.Errorf("applied = %v", ran)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("none pending", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
||||||
|
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(ran) != 0 {
|
||||||
|
t.Fatalf("applied %d, want 0", len(ran))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("partial pending", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{applied: applied("001_init")}
|
||||||
|
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !slices.Equal(ran, []string{"002_users", "003_posts"}) {
|
||||||
|
t.Errorf("applied = %v", ran)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("n exceeds pending", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{applied: applied("001_init")}
|
||||||
|
ran, err := sqlmigrate.Up(ctx, m, migrations, 99)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(ran) != 2 {
|
||||||
|
t.Fatalf("applied %d, want 2", len(ran))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exec error stops and returns partial", func(t *testing.T) {
|
||||||
|
m := &failOnNthMigrator{failAt: 1}
|
||||||
|
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if len(ran) != 1 {
|
||||||
|
t.Errorf("applied %d before error, want 1", len(ran))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("skips migration matched by ID", func(t *testing.T) {
|
||||||
|
// DB has migration applied under old name, but same ID
|
||||||
|
m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{
|
||||||
|
{Name: "001_old-name", ID: "aa11bb22"},
|
||||||
|
}}
|
||||||
|
migs := []sqlmigrate.Migration{
|
||||||
|
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
|
||||||
|
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
|
||||||
|
}
|
||||||
|
ran, err := sqlmigrate.Up(ctx, m, migs, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Only 002_users should be applied; 001 is matched by ID
|
||||||
|
if !slices.Equal(ran, []string{"002_users"}) {
|
||||||
|
t.Errorf("applied = %v, want [002_users]", ran)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// failOnNthMigrator fails on the Nth ExecUp call (0-indexed).
|
||||||
|
type failOnNthMigrator struct {
|
||||||
|
applied []sqlmigrate.AppliedMigration
|
||||||
|
calls int
|
||||||
|
failAt int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *failOnNthMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error {
|
||||||
|
if m.calls == m.failAt {
|
||||||
|
m.calls++
|
||||||
|
return errors.New("connection lost")
|
||||||
|
}
|
||||||
|
m.calls++
|
||||||
|
m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID})
|
||||||
|
slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int {
|
||||||
|
return strings.Compare(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *failOnNthMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error {
|
||||||
|
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name })
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) {
|
||||||
|
return slices.Clone(m.applied), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Down ---
|
||||||
|
|
||||||
|
func TestDown(t *testing.T) {
|
||||||
|
ctx := t.Context()
|
||||||
|
migrations := []sqlmigrate.Migration{
|
||||||
|
{Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
|
||||||
|
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
|
||||||
|
{Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("rollback all", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
||||||
|
rolled, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(rolled) != 3 {
|
||||||
|
t.Fatalf("rolled %d, want 3", len(rolled))
|
||||||
|
}
|
||||||
|
if rolled[0] != "003_posts" {
|
||||||
|
t.Errorf("first rollback = %q, want 003_posts", rolled[0])
|
||||||
|
}
|
||||||
|
if rolled[2] != "001_init" {
|
||||||
|
t.Errorf("last rollback = %q, want 001_init", rolled[2])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rollback n", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
||||||
|
rolled, err := sqlmigrate.Down(ctx, m, migrations, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !slices.Equal(rolled, []string{"003_posts", "002_users"}) {
|
||||||
|
t.Errorf("rolled = %v", rolled)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("none applied", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{}
|
||||||
|
rolled, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(rolled) != 0 {
|
||||||
|
t.Fatalf("rolled %d, want 0", len(rolled))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exec error", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{
|
||||||
|
applied: applied("001_init", "002_users"),
|
||||||
|
execErr: errors.New("permission denied"),
|
||||||
|
}
|
||||||
|
rolled, err := sqlmigrate.Down(ctx, m, migrations, 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if len(rolled) != 0 {
|
||||||
|
t.Errorf("rolled %d on error, want 0", len(rolled))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown migration in applied", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{applied: applied("001_init", "999_unknown")}
|
||||||
|
_, err := sqlmigrate.Down(ctx, m, migrations, 0)
|
||||||
|
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
|
||||||
|
t.Errorf("got %v, want ErrMissingDown", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("n exceeds applied", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{applied: applied("001_init")}
|
||||||
|
rolled, err := sqlmigrate.Down(ctx, m, migrations, 99)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(rolled) != 1 {
|
||||||
|
t.Fatalf("rolled %d, want 1", len(rolled))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("finds migration by ID when name changed", func(t *testing.T) {
|
||||||
|
// DB has old name, file has new name, same ID
|
||||||
|
m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{
|
||||||
|
{Name: "001_old-name", ID: "aa11bb22"},
|
||||||
|
}}
|
||||||
|
migs := []sqlmigrate.Migration{
|
||||||
|
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
|
||||||
|
}
|
||||||
|
rolled, err := sqlmigrate.Down(ctx, m, migs, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !slices.Equal(rolled, []string{"001_old-name"}) {
|
||||||
|
t.Errorf("rolled = %v, want [001_old-name]", rolled)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetStatus ---
|
||||||
|
|
||||||
|
func TestGetStatus(t *testing.T) {
|
||||||
|
ctx := t.Context()
|
||||||
|
migrations := []sqlmigrate.Migration{
|
||||||
|
{Name: "001_init"},
|
||||||
|
{Name: "002_users"},
|
||||||
|
{Name: "003_posts"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("all pending", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{}
|
||||||
|
status, err := sqlmigrate.GetStatus(ctx, m, migrations)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(status.Applied) != 0 {
|
||||||
|
t.Errorf("applied = %d, want 0", len(status.Applied))
|
||||||
|
}
|
||||||
|
if len(status.Pending) != 3 {
|
||||||
|
t.Errorf("pending = %d, want 3", len(status.Pending))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("partial", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{applied: applied("001_init")}
|
||||||
|
status, err := sqlmigrate.GetStatus(ctx, m, migrations)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(status.Applied) != 1 || status.Applied[0] != "001_init" {
|
||||||
|
t.Errorf("applied = %v", status.Applied)
|
||||||
|
}
|
||||||
|
if !slices.Equal(status.Pending, []string{"002_users", "003_posts"}) {
|
||||||
|
t.Errorf("pending = %v", status.Pending)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all applied", func(t *testing.T) {
|
||||||
|
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
|
||||||
|
status, err := sqlmigrate.GetStatus(ctx, m, migrations)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(status.Applied) != 3 {
|
||||||
|
t.Errorf("applied = %d, want 3", len(status.Applied))
|
||||||
|
}
|
||||||
|
if len(status.Pending) != 0 {
|
||||||
|
t.Errorf("pending = %d, want 0", len(status.Pending))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user