From 3547b7e409836cfdc9dced8b937501b0a799fc58 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Wed, 8 Apr 2026 15:03:56 -0600 Subject: [PATCH] ref(database/sqlmigrate): extract migration library with shmigrate backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Factor the inline migration logic from cmd/sql-migrate into reusable packages: database/sqlmigrate (core types, matching, file collection) and database/sqlmigrate/shmigrate (shell script generation backend). No behavior changes — the CLI produces identical output. The shmigrate package implements the sqlmigrate.Migrator interface so other backends (pgmigrate, mymigrate, etc.) can follow the same pattern. --- cmd/sql-migrate/go.mod | 12 +- cmd/sql-migrate/main.go | 269 ++++++------ database/sqlmigrate/README.md | 22 + database/sqlmigrate/go.mod | 3 + database/sqlmigrate/shmigrate/go.mod | 7 + database/sqlmigrate/shmigrate/shmigrate.go | 150 +++++++ database/sqlmigrate/sqlmigrate.go | 276 +++++++++++++ database/sqlmigrate/sqlmigrate_test.go | 457 +++++++++++++++++++++ 8 files changed, 1051 insertions(+), 145 deletions(-) create mode 100644 database/sqlmigrate/README.md create mode 100644 database/sqlmigrate/go.mod create mode 100644 database/sqlmigrate/shmigrate/go.mod create mode 100644 database/sqlmigrate/shmigrate/shmigrate.go create mode 100644 database/sqlmigrate/sqlmigrate.go create mode 100644 database/sqlmigrate/sqlmigrate_test.go diff --git a/cmd/sql-migrate/go.mod b/cmd/sql-migrate/go.mod index 042d55e..dce9864 100644 --- a/cmd/sql-migrate/go.mod +++ b/cmd/sql-migrate/go.mod @@ -1,3 +1,13 @@ module github.com/therootcompany/golib/cmd/sql-migrate/v2 -go 1.25.3 +go 1.26.1 + +require ( + github.com/therootcompany/golib/database/sqlmigrate v0.0.0 + github.com/therootcompany/golib/database/sqlmigrate/shmigrate v0.0.0 +) + +replace ( + github.com/therootcompany/golib/database/sqlmigrate => ../../database/sqlmigrate + github.com/therootcompany/golib/database/sqlmigrate/shmigrate => ../../database/sqlmigrate/shmigrate +) diff --git a/cmd/sql-migrate/main.go b/cmd/sql-migrate/main.go index cc015a3..cbfb6ab 100644 --- a/cmd/sql-migrate/main.go +++ b/cmd/sql-migrate/main.go @@ -13,6 +13,7 @@ package main import ( "bufio" + "context" "crypto/rand" "encoding/hex" "flag" @@ -27,6 +28,9 @@ import ( "strconv" "strings" "time" + + "github.com/therootcompany/golib/database/sqlmigrate" + "github.com/therootcompany/golib/database/sqlmigrate/shmigrate" ) // replaced by goreleaser / ldflags @@ -67,14 +71,6 @@ DROP TABLE IF EXISTS _migrations; ` LOG_MIGRATIONS_QUERY = `-- note: CLI arguments must be passed to the sql command to keep output clean SELECT name FROM _migrations ORDER BY name; -` - shHeader = `#!/bin/sh -set -e -set -u - -if test -s ./.env; then - . ./.env -fi ` ) @@ -297,14 +293,21 @@ func main() { os.Exit(1) } + ctx := context.Background() + runner := &shmigrate.Migrator{ + Writer: os.Stdout, + SqlCommand: state.SQLCommand, + MigrationsDir: state.MigrationsDir, + LogQueryPath: filepath.Join(state.MigrationsDir, LOG_QUERY_NAME), + LogPath: state.LogPath, + } + migrations := sqlmigrate.NamesOnly(ups) + switch subcmd { case "init": break case "sync": - if err := syncLog(&state); err != nil { - log.Fatal(err) - } - return + syncLog(runner) case "create": if len(leafArgs) == 0 { log.Fatal("create requires a description") @@ -324,8 +327,7 @@ func main() { fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " ")) os.Exit(1) } - err = status(&state, ups) - if err != nil { + if err := cmdStatus(ctx, &state, runner, migrations); err != nil { log.Fatal(err) } case "list": @@ -364,12 +366,11 @@ func main() { os.Exit(1) } - err = up(&state, ups, upN) - if err != nil { + if err := cmdUp(ctx, &state, runner, migrations, upN); err != nil { log.Fatal(err) } case "down": - var downN int + downN := 1 switch len(leafArgs) { case 0: // default: roll back one @@ -384,8 +385,7 @@ func main() { os.Exit(1) } - err = down(&state, downN) - if err != nil { + if err := cmdDown(ctx, &state, runner, migrations, downN); err != nil { log.Fatal(err) } default: @@ -875,190 +875,171 @@ func fixupMigration(dir string, basename string) (up, down bool, warn error, err return up, down, nil, nil } -func syncLog(state *State) error { - getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME) - getMigsPath = filepathUnclean(getMigsPath) - getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1) - logPath := filepathUnclean(state.LogPath) +func syncLog(runner *shmigrate.Migrator) { + syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1) + logPath := filepathUnclean(runner.LogPath) - fmt.Printf(shHeader) + fmt.Printf(shmigrate.ShHeader) fmt.Println("") fmt.Println("# SYNC: reload migrations log from DB") - fmt.Printf("%s > %s || true\n", getMigs, logPath) + fmt.Printf("%s > %s || true\n", syncCmd, logPath) fmt.Printf("cat %s\n", logPath) - return nil } -func up(state *State, ups []string, n int) error { - var pending []string - for _, mig := range ups { - found := slices.Contains(state.Migrated, mig) - if !found { - pending = append(pending, mig) - } +func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error { + // fixup pending migrations before generating the script + fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations) + + status, err := sqlmigrate.GetStatus(ctx, runner, migrations) + if err != nil { + return err } - getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME) - getMigsPath = filepathUnclean(getMigsPath) - getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1) - - if len(pending) == 0 { + if len(status.Pending) == 0 { + syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1) fmt.Fprintf(os.Stderr, "# Already up-to-date\n") fmt.Fprintf(os.Stderr, "#\n") fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n") - fmt.Fprintf(os.Stderr, "# %s > %s\n", getMigs, filepathUnclean(state.LogPath)) + fmt.Fprintf(os.Stderr, "# %s > %s\n", syncCmd, filepathUnclean(runner.LogPath)) return nil } - if n == 0 { - n = len(pending) - } - fixedUp := []string{} - fixedDown := []string{} - - fmt.Printf(shHeader) + fmt.Printf(shmigrate.ShHeader) fmt.Println("") fmt.Println("# FORWARD / UP Migrations") fmt.Println("") - for i, migration := range pending { - if i >= n { - break - } - path := filepath.Join(state.MigrationsDir, migration+".up.sql") - path = filepathUnclean(path) - { - up, down, warn, err := fixupMigration(state.MigrationsDir, migration) - if warn != nil { - fmt.Fprintf(os.Stderr, "Warn: %s\n", warn) - } - if err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) - } - if up { - fixedUp = append(fixedUp, migration) - } - if down { - fixedDown = append(fixedDown, migration) - } - } - cmd := strings.Replace(state.SQLCommand, "%s", path, 1) - fmt.Printf("# +%d %s\n", i+1, migration) - fmt.Println(cmd) - fmt.Println(getMigs + " > " + filepathUnclean(state.LogPath)) - fmt.Println("") + applied, err := sqlmigrate.Up(ctx, runner, migrations, n) + if err != nil { + return err } - fmt.Println("cat", filepathUnclean(state.LogPath)) + _ = applied + + fmt.Println("cat", filepathUnclean(runner.LogPath)) showFixes(fixedUp, fixedDown) return nil } -func down(state *State, n int) error { - lines := make([]string, len(state.Lines)) - copy(lines, state.Lines) - slices.Reverse(lines) +func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error { + // fixup applied migrations before generating the script + fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations) - getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME) - getMigsPath = filepathUnclean(getMigsPath) - getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1) + status, err := sqlmigrate.GetStatus(ctx, runner, migrations) + if err != nil { + return err + } - if len(lines) == 0 { + if len(status.Applied) == 0 { + syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1) fmt.Fprintf(os.Stderr, "# No migration history\n") fmt.Fprintf(os.Stderr, "#\n") fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n") - fmt.Fprintf(os.Stderr, "# %s > %s\n", getMigs, filepathUnclean(state.LogPath)) + fmt.Fprintf(os.Stderr, "# %s > %s\n", syncCmd, filepathUnclean(runner.LogPath)) return nil } - if n == 0 { - n = 1 - } - fixedUp := []string{} - fixedDown := []string{} - - var applied []string - for _, line := range lines { - migration := commentStartRe.ReplaceAllString(line, "") - migration = strings.TrimSpace(migration) - if migration == "" { - continue - } - applied = append(applied, migration) - } - - fmt.Printf(shHeader) + fmt.Printf(shmigrate.ShHeader) fmt.Println("") fmt.Println("# ROLLBACK / DOWN Migration") fmt.Println("") - for i, migration := range applied { - if i >= n { - break - } - downPath := filepath.Join(state.MigrationsDir, migration+".down.sql") - cmd := strings.Replace(state.SQLCommand, "%s", downPath, 1) - fmt.Printf("\n# -%d %s\n", i+1, migration) + // check for missing down files before generating script + reversed := make([]string, len(status.Applied)) + copy(reversed, status.Applied) + slices.Reverse(reversed) + limit := n + if limit <= 0 { + limit = 1 + } + if limit > len(reversed) { + limit = len(reversed) + } + for _, name := range reversed[:limit] { + downPath := filepath.Join(state.MigrationsDir, name+".down.sql") if !fileExists(downPath) { fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath)) fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n") - fmt.Printf("# ERROR: MISSING FILE\n") - } else { - up, down, warn, err := fixupMigration(state.MigrationsDir, migration) - if warn != nil { - fmt.Fprintf(os.Stderr, "Warn: %s\n", warn) - } - if err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) - } - if up { - fixedUp = append(fixedUp, migration) - } - if down { - fixedDown = append(fixedDown, migration) - } } - - fmt.Println(cmd) - fmt.Println(getMigs + " > " + filepathUnclean(state.LogPath)) - fmt.Println("") } - fmt.Println("cat", filepathUnclean(state.LogPath)) + + rolled, err := sqlmigrate.Down(ctx, runner, migrations, n) + if err != nil { + return err + } + _ = rolled + + fmt.Println("cat", filepathUnclean(runner.LogPath)) showFixes(fixedUp, fixedDown) return nil } -func status(state *State, ups []string) error { - previous := make([]string, len(state.Lines)) - copy(previous, state.Lines) - slices.Reverse(previous) +func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration) error { + status, err := sqlmigrate.GetStatus(ctx, runner, migrations) + if err != nil { + return err + } fmt.Fprintf(os.Stderr, "migrations_dir: %s\n", filepathUnclean(state.MigrationsDir)) fmt.Fprintf(os.Stderr, "migrations_log: %s\n", filepathUnclean(state.LogPath)) fmt.Fprintf(os.Stderr, "sql_command: %s\n", state.SQLCommand) fmt.Fprintf(os.Stderr, "\n") - fmt.Printf("# previous: %d\n", len(previous)) - for _, mig := range previous { + + // show applied in reverse (most recent first) + applied := make([]string, len(status.Applied)) + copy(applied, status.Applied) + slices.Reverse(applied) + + fmt.Printf("# previous: %d\n", len(applied)) + for _, mig := range applied { fmt.Printf(" %s\n", mig) } - if len(previous) == 0 { + if len(applied) == 0 { fmt.Println(" # (no previous migrations)") } fmt.Println("") - var pending []string - for _, mig := range ups { - found := slices.Contains(state.Migrated, mig) - if !found { - pending = append(pending, mig) - } - } - fmt.Printf("# pending: %d\n", len(pending)) - for _, mig := range pending { + fmt.Printf("# pending: %d\n", len(status.Pending)) + for _, mig := range status.Pending { fmt.Printf(" %s\n", mig) } - if len(pending) == 0 { + if len(status.Pending) == 0 { fmt.Println(" # (no pending migrations)") } return nil } + +// fixupAll runs fixupMigration on all known migrations (applied + pending). +func fixupAll(migrationsDir string, applied []string, migrations []sqlmigrate.Migration) (fixedUp, fixedDown []string) { + seen := map[string]bool{} + var all []string + for _, name := range applied { + if !seen[name] { + all = append(all, name) + seen[name] = true + } + } + for _, m := range migrations { + if !seen[m.Name] { + all = append(all, m.Name) + seen[m.Name] = true + } + } + + for _, name := range all { + up, down, warn, err := fixupMigration(migrationsDir, name) + if warn != nil { + fmt.Fprintf(os.Stderr, "Warn: %s\n", warn) + } + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + } + if up { + fixedUp = append(fixedUp, name) + } + if down { + fixedDown = append(fixedDown, name) + } + } + return fixedUp, fixedDown +} diff --git a/database/sqlmigrate/README.md b/database/sqlmigrate/README.md new file mode 100644 index 0000000..48ac473 --- /dev/null +++ b/database/sqlmigrate/README.md @@ -0,0 +1,22 @@ +# sqlmigrate + +Database-agnostic SQL migration library for Go. + +[![Go Reference](https://pkg.go.dev/badge/github.com/therootcompany/golib/database/sqlmigrate.svg)](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate) + +## Backend packages + +Each backend is a separate Go module to avoid pulling unnecessary drivers: + +| Package | Database | Driver | +|---------|----------|--------| +| [pgmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/pgmigrate) | PostgreSQL | pgx/v5 | +| [mymigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/mymigrate) | MySQL / MariaDB | go-sql-driver/mysql | +| [litemigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/litemigrate) | SQLite | database/sql (caller imports driver) | +| [msmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/msmigrate) | SQL Server | go-mssqldb | +| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (generates POSIX sh) | + +## CLI + +The [sql-migrate](https://pkg.go.dev/github.com/therootcompany/golib/cmd/sql-migrate/v2) CLI +uses shmigrate to generate shell scripts for managing migrations without a Go dependency at runtime. diff --git a/database/sqlmigrate/go.mod b/database/sqlmigrate/go.mod new file mode 100644 index 0000000..479d822 --- /dev/null +++ b/database/sqlmigrate/go.mod @@ -0,0 +1,3 @@ +module github.com/therootcompany/golib/database/sqlmigrate + +go 1.26.1 diff --git a/database/sqlmigrate/shmigrate/go.mod b/database/sqlmigrate/shmigrate/go.mod new file mode 100644 index 0000000..0a5b0f1 --- /dev/null +++ b/database/sqlmigrate/shmigrate/go.mod @@ -0,0 +1,7 @@ +module github.com/therootcompany/golib/database/sqlmigrate/shmigrate + +go 1.26.1 + +require github.com/therootcompany/golib/database/sqlmigrate v0.0.0 + +replace github.com/therootcompany/golib/database/sqlmigrate => ../ diff --git a/database/sqlmigrate/shmigrate/shmigrate.go b/database/sqlmigrate/shmigrate/shmigrate.go new file mode 100644 index 0000000..c246aab --- /dev/null +++ b/database/sqlmigrate/shmigrate/shmigrate.go @@ -0,0 +1,150 @@ +// Package shmigrate implements sqlmigrate.Runner by generating POSIX +// shell commands that reference migration files on disk. It is used by +// the sql-migrate CLI to produce scripts that can be piped to sh. +package shmigrate + +import ( + "bufio" + "context" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + + "github.com/therootcompany/golib/database/sqlmigrate" +) + +// ShHeader is the standard header for generated shell scripts. +const ShHeader = `#!/bin/sh +set -e +set -u + +if test -s ./.env; then + . ./.env +fi +` + +// Migrator generates shell scripts for migration execution. +// It implements sqlmigrate.Migrator so it can be used with +// sqlmigrate.Up, sqlmigrate.Down, and sqlmigrate.GetStatus. +type Migrator struct { + // Writer receives the generated shell script. + Writer io.Writer + + // SqlCommand is the shell command template with %s for the file path. + // Example: `psql "$PG_URL" -v ON_ERROR_STOP=on -A -t --file %s` + SqlCommand string + + // MigrationsDir is the path to the migrations directory on disk. + MigrationsDir string + + // LogQueryPath is the path to the _migrations.sql query file. + // Used to sync the migrations log after each migration. + LogQueryPath string + + // LogPath is the path to the migrations.log file. + LogPath string + + // FS is an optional filesystem for reading the migrations log. + // When nil, the OS filesystem is used. + FS fs.FS + + counter int +} + +// verify interface compliance at compile time +var _ sqlmigrate.Migrator = (*Migrator)(nil) + +// ExecUp outputs a shell command to run the .up.sql migration file. +func (r *Migrator) ExecUp(ctx context.Context, m sqlmigrate.Migration) error { + r.counter++ + return r.exec(m.Name, ".up.sql", fmt.Sprintf("+%d", r.counter)) +} + +// ExecDown outputs a shell command to run the .down.sql migration file. +func (r *Migrator) ExecDown(ctx context.Context, m sqlmigrate.Migration) error { + r.counter++ + return r.exec(m.Name, ".down.sql", fmt.Sprintf("-%d", r.counter)) +} + +func (r *Migrator) exec(name, suffix, label string) error { + path := unclean(filepath.Join(r.MigrationsDir, name+suffix)) + cmd := strings.Replace(r.SqlCommand, "%s", path, 1) + + syncCmd := strings.Replace(r.SqlCommand, "%s", unclean(r.LogQueryPath), 1) + logPath := unclean(r.LogPath) + + fmt.Fprintf(r.Writer, "# %s %s\n", label, name) + if _, err := os.Stat(filepath.Join(r.MigrationsDir, name+suffix)); err != nil { + fmt.Fprintln(r.Writer, "# ERROR: MISSING FILE") + } + fmt.Fprintln(r.Writer, cmd) + fmt.Fprintf(r.Writer, "%s > %s\n", syncCmd, logPath) + fmt.Fprintln(r.Writer) + + return nil +} + +// Applied reads the migrations log file and returns applied migrations. +// Supports two formats: +// - New: "id\tname" (tab-separated, written by updated _migrations.sql) +// - Old: "name" (name only, for backwards compatibility) +// +// Returns an empty slice if the file does not exist. When FS is set, reads +// from that filesystem; otherwise reads from the OS filesystem. +func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration, error) { + var f io.ReadCloser + var err error + if r.FS != nil { + f, err = r.FS.Open(r.LogPath) + } else { + f, err = os.Open(r.LogPath) + } + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("reading migrations log: %w", err) + } + defer f.Close() + + var applied []sqlmigrate.AppliedMigration + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + // strip inline comments + if idx := strings.Index(line, "#"); idx >= 0 { + line = strings.TrimSpace(line[:idx]) + } + if line == "" { + continue + } + if id, name, ok := strings.Cut(line, "\t"); ok { + applied = append(applied, sqlmigrate.AppliedMigration{ID: id, Name: name}) + } else { + applied = append(applied, sqlmigrate.AppliedMigration{Name: line}) + } + } + + return applied, nil +} + +// Reset resets the migration counter. Call between Up and Down +// if generating both in the same script. +func (r *Migrator) Reset() { + r.counter = 0 +} + +// unclean ensures a relative path starts with ./ or ../ so it +// is not interpreted as a command name in shell scripts. +func unclean(path string) string { + if strings.HasPrefix(path, "/") { + return path + } + if strings.HasPrefix(path, "./") || strings.HasPrefix(path, "../") { + return path + } + return "./" + path +} diff --git a/database/sqlmigrate/sqlmigrate.go b/database/sqlmigrate/sqlmigrate.go new file mode 100644 index 0000000..c5f7a14 --- /dev/null +++ b/database/sqlmigrate/sqlmigrate.go @@ -0,0 +1,276 @@ +// Package sqlmigrate provides a database-agnostic SQL migration interface. +// +// Backend implementations (each a separate Go module): +// - pgmigrate: PostgreSQL via pgx/v5 +// - mymigrate: MySQL/MariaDB via go-sql-driver/mysql +// - litemigrate: SQLite via database/sql +// - msmigrate: SQL Server via go-mssqldb +// - shmigrate: POSIX shell script generation +package sqlmigrate + +import ( + "context" + "errors" + "fmt" + "io/fs" + "regexp" + "slices" + "strings" +) + +// Sentinel errors for migration operations. +var ( + ErrMissingUp = errors.New("missing up migration") + ErrMissingDown = errors.New("missing down migration") + ErrWalkFailed = errors.New("walking migrations") + ErrExecFailed = errors.New("migration exec failed") + ErrQueryApplied = errors.New("querying applied migrations") +) + +// Migration represents a paired up/down migration. +type Migration struct { + Name string // e.g. "2026-04-05-001000_create-todos" + ID string // 8-char hex from INSERT INTO _migrations, parsed by Collect + Up string // SQL content of the .up.sql file + Down string // SQL content of the .down.sql file +} + +// AppliedMigration represents a migration recorded in the _migrations table. +type AppliedMigration struct { + ID string + Name string +} + +// Status represents the current migration state. +type Status struct { + Applied []string + Pending []string +} + +// Migrator executes migrations. Implementations handle the +// database-specific or output-specific details. +type Migrator interface { + // ExecUp runs the up migration. For database migrators this executes + // m.Up in a transaction. For shell migrators this outputs a command + // referencing the .up.sql file. + ExecUp(ctx context.Context, m Migration) error + + // ExecDown runs the down migration. + ExecDown(ctx context.Context, m Migration) error + + // Applied returns all applied migrations from the _migrations table, + // sorted lexicographically by name. Returns an empty slice (not an + // error) if the migrations table or log does not exist yet. + Applied(ctx context.Context) ([]AppliedMigration, error) +} + +// idFromInsert extracts the hex ID from an INSERT INTO _migrations line. +// Matches: INSERT INTO _migrations (name, id) VALUES ('...', ''); +var idFromInsert = regexp.MustCompile( + `(?i)INSERT\s+INTO\s+_migrations\s*\(\s*name\s*,\s*id\s*\)\s*VALUES\s*\(\s*'[^']*'\s*,\s*'([0-9a-fA-F]+)'\s*\)`, +) + +// Collect reads .up.sql and .down.sql files from fsys, pairs them by +// basename, and returns them sorted lexicographically by name. +// If the up SQL contains an INSERT INTO _migrations line, the hex ID +// is extracted and stored in Migration.ID. +func Collect(fsys fs.FS) ([]Migration, error) { + ups := map[string]string{} + downs := map[string]string{} + + err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + + name := d.Name() + if base, ok := strings.CutSuffix(name, ".up.sql"); ok { + b, readErr := fs.ReadFile(fsys, path) + if readErr != nil { + return readErr + } + ups[base] = string(b) + return nil + } + if base, ok := strings.CutSuffix(name, ".down.sql"); ok { + b, readErr := fs.ReadFile(fsys, path) + if readErr != nil { + return readErr + } + downs[base] = string(b) + return nil + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrWalkFailed, err) + } + + var migrations []Migration + for name, upSQL := range ups { + downSQL, ok := downs[name] + if !ok { + return nil, fmt.Errorf("%w: %s", ErrMissingDown, name) + } + var id string + if m := idFromInsert.FindStringSubmatch(upSQL); m != nil { + id = m[1] + } + migrations = append(migrations, Migration{ + Name: name, + ID: id, + Up: upSQL, + Down: downSQL, + }) + } + for name := range downs { + if _, ok := ups[name]; !ok { + return nil, fmt.Errorf("%w: %s", ErrMissingUp, name) + } + } + + slices.SortFunc(migrations, func(a, b Migration) int { + return strings.Compare(a.Name, b.Name) + }) + + return migrations, nil +} + +// NamesOnly builds a Migration slice from a list of names, with empty +// Up/Down content. Useful for shell-based runners that reference files +// on disk rather than executing SQL directly. +func NamesOnly(names []string) []Migration { + migrations := make([]Migration, len(names)) + for i, name := range names { + migrations[i] = Migration{Name: name} + } + return migrations +} + +// isApplied returns true if the migration matches any applied entry by name or ID. +func isApplied(m Migration, applied []AppliedMigration) bool { + for _, a := range applied { + if a.Name == m.Name { + return true + } + if m.ID != "" && a.ID != "" && a.ID == m.ID { + return true + } + } + return false +} + +// findMigration looks up a migration by the applied entry's name or ID. +func findMigration(a AppliedMigration, byName map[string]Migration, byID map[string]Migration) (Migration, bool) { + if m, ok := byName[a.Name]; ok { + return m, true + } + if a.ID != "" { + if m, ok := byID[a.ID]; ok { + return m, true + } + } + return Migration{}, false +} + +// Up applies up to n pending migrations using the given Runner. +// If n <= 0, applies all pending. Returns the names of applied migrations. +func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) { + applied, err := r.Applied(ctx) + if err != nil { + return nil, err + } + + var pending []Migration + for _, m := range migrations { + if !isApplied(m, applied) { + pending = append(pending, m) + } + } + + if n <= 0 { + n = len(pending) + } + if n > len(pending) { + n = len(pending) + } + + var ran []string + for _, m := range pending[:n] { + if err := r.ExecUp(ctx, m); err != nil { + return ran, fmt.Errorf("%s (up): %w", m.Name, err) + } + ran = append(ran, m.Name) + } + + return ran, nil +} + +// Down rolls back up to n applied migrations, most recent first. +// If n <= 0, rolls back all applied. Returns the names of rolled-back migrations. +func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) { + applied, err := r.Applied(ctx) + if err != nil { + return nil, err + } + + byName := map[string]Migration{} + byID := map[string]Migration{} + for _, m := range migrations { + byName[m.Name] = m + if m.ID != "" { + byID[m.ID] = m + } + } + + reversed := make([]AppliedMigration, len(applied)) + copy(reversed, applied) + slices.Reverse(reversed) + + if n <= 0 || n > len(reversed) { + n = len(reversed) + } + + var ran []string + for _, a := range reversed[:n] { + m, ok := findMigration(a, byName, byID) + if !ok { + return ran, fmt.Errorf("%w: %s", ErrMissingDown, a.Name) + } + if err := r.ExecDown(ctx, m); err != nil { + return ran, fmt.Errorf("%s (down): %w", a.Name, err) + } + ran = append(ran, a.Name) + } + + return ran, nil +} + +// GetStatus returns applied and pending migration lists. +func GetStatus(ctx context.Context, r Migrator, migrations []Migration) (*Status, error) { + applied, err := r.Applied(ctx) + if err != nil { + return nil, err + } + + appliedNames := make([]string, len(applied)) + for i, a := range applied { + appliedNames[i] = a.Name + } + + var pending []string + for _, m := range migrations { + if !isApplied(m, applied) { + pending = append(pending, m.Name) + } + } + + return &Status{ + Applied: appliedNames, + Pending: pending, + }, nil +} diff --git a/database/sqlmigrate/sqlmigrate_test.go b/database/sqlmigrate/sqlmigrate_test.go new file mode 100644 index 0000000..5111cf0 --- /dev/null +++ b/database/sqlmigrate/sqlmigrate_test.go @@ -0,0 +1,457 @@ +package sqlmigrate_test + +import ( + "context" + "errors" + "slices" + "strings" + "testing" + "testing/fstest" + + "github.com/therootcompany/golib/database/sqlmigrate" +) + +// applied builds an []AppliedMigration from names (IDs empty). +func applied(names ...string) []sqlmigrate.AppliedMigration { + out := make([]sqlmigrate.AppliedMigration, len(names)) + for i, n := range names { + out[i] = sqlmigrate.AppliedMigration{Name: n} + } + return out +} + +// mockMigrator tracks applied migrations in memory. +type mockMigrator struct { + applied []sqlmigrate.AppliedMigration + execErr error // if set, ExecUp/ExecDown return this on every call + upCalls []string + downCalls []string +} + +func (m *mockMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error { + m.upCalls = append(m.upCalls, mig.Name) + if m.execErr != nil { + return m.execErr + } + m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID}) + slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int { + return strings.Compare(a.Name, b.Name) + }) + return nil +} + +func (m *mockMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error { + m.downCalls = append(m.downCalls, mig.Name) + if m.execErr != nil { + return m.execErr + } + m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name }) + return nil +} + +func (m *mockMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) { + return slices.Clone(m.applied), nil +} + +// --- Collect --- + +func TestCollect(t *testing.T) { + t.Run("pairs and sorts", func(t *testing.T) { + fsys := fstest.MapFS{ + "002_second.up.sql": {Data: []byte("CREATE TABLE b;")}, + "002_second.down.sql": {Data: []byte("DROP TABLE b;")}, + "001_first.up.sql": {Data: []byte("CREATE TABLE a;")}, + "001_first.down.sql": {Data: []byte("DROP TABLE a;")}, + } + migrations, err := sqlmigrate.Collect(fsys) + if err != nil { + t.Fatal(err) + } + if len(migrations) != 2 { + t.Fatalf("got %d migrations, want 2", len(migrations)) + } + if migrations[0].Name != "001_first" { + t.Errorf("first = %q, want %q", migrations[0].Name, "001_first") + } + if migrations[1].Name != "002_second" { + t.Errorf("second = %q, want %q", migrations[1].Name, "002_second") + } + if migrations[0].Up != "CREATE TABLE a;" { + t.Errorf("first.Up = %q", migrations[0].Up) + } + if migrations[0].Down != "DROP TABLE a;" { + t.Errorf("first.Down = %q", migrations[0].Down) + } + }) + + t.Run("parses ID from INSERT", func(t *testing.T) { + fsys := fstest.MapFS{ + "001_init.up.sql": {Data: []byte("CREATE TABLE a;\nINSERT INTO _migrations (name, id) VALUES ('001_init', 'abcd1234');")}, + "001_init.down.sql": {Data: []byte("DROP TABLE a;\nDELETE FROM _migrations WHERE id = 'abcd1234';")}, + } + migrations, err := sqlmigrate.Collect(fsys) + if err != nil { + t.Fatal(err) + } + if migrations[0].ID != "abcd1234" { + t.Errorf("ID = %q, want %q", migrations[0].ID, "abcd1234") + } + }) + + t.Run("no ID when no INSERT", func(t *testing.T) { + fsys := fstest.MapFS{ + "001_init.up.sql": {Data: []byte("CREATE TABLE a;")}, + "001_init.down.sql": {Data: []byte("DROP TABLE a;")}, + } + migrations, err := sqlmigrate.Collect(fsys) + if err != nil { + t.Fatal(err) + } + if migrations[0].ID != "" { + t.Errorf("ID = %q, want empty", migrations[0].ID) + } + }) + + t.Run("missing down", func(t *testing.T) { + fsys := fstest.MapFS{ + "001_only-up.up.sql": {Data: []byte("CREATE TABLE x;")}, + } + _, err := sqlmigrate.Collect(fsys) + if !errors.Is(err, sqlmigrate.ErrMissingDown) { + t.Errorf("got %v, want ErrMissingDown", err) + } + }) + + t.Run("missing up", func(t *testing.T) { + fsys := fstest.MapFS{ + "001_only-down.down.sql": {Data: []byte("DROP TABLE x;")}, + } + _, err := sqlmigrate.Collect(fsys) + if !errors.Is(err, sqlmigrate.ErrMissingUp) { + t.Errorf("got %v, want ErrMissingUp", err) + } + }) + + t.Run("ignores non-sql files", func(t *testing.T) { + fsys := fstest.MapFS{ + "001_init.up.sql": {Data: []byte("UP")}, + "001_init.down.sql": {Data: []byte("DOWN")}, + "README.md": {Data: []byte("# Migrations")}, + "_migrations.sql": {Data: []byte("SELECT name FROM _migrations;")}, + } + migrations, err := sqlmigrate.Collect(fsys) + if err != nil { + t.Fatal(err) + } + if len(migrations) != 1 { + t.Fatalf("got %d migrations, want 1", len(migrations)) + } + }) + + t.Run("empty fs", func(t *testing.T) { + fsys := fstest.MapFS{} + migrations, err := sqlmigrate.Collect(fsys) + if err != nil { + t.Fatal(err) + } + if len(migrations) != 0 { + t.Fatalf("got %d migrations, want 0", len(migrations)) + } + }) +} + +// --- NamesOnly --- + +func TestNamesOnly(t *testing.T) { + names := []string{"001_init", "002_users"} + migrations := sqlmigrate.NamesOnly(names) + if len(migrations) != 2 { + t.Fatalf("got %d, want 2", len(migrations)) + } + for i, m := range migrations { + if m.Name != names[i] { + t.Errorf("[%d].Name = %q, want %q", i, m.Name, names[i]) + } + if m.Up != "" || m.Down != "" { + t.Errorf("[%d] has non-empty content", i) + } + } +} + +// --- Up --- + +func TestUp(t *testing.T) { + ctx := t.Context() + migrations := []sqlmigrate.Migration{ + {Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, + {Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, + {Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"}, + } + + t.Run("apply all", func(t *testing.T) { + m := &mockMigrator{} + ran, err := sqlmigrate.Up(ctx, m, migrations, 0) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(ran, []string{"001_init", "002_users", "003_posts"}) { + t.Errorf("applied = %v", ran) + } + }) + + t.Run("apply n", func(t *testing.T) { + m := &mockMigrator{} + ran, err := sqlmigrate.Up(ctx, m, migrations, 2) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(ran, []string{"001_init", "002_users"}) { + t.Errorf("applied = %v", ran) + } + }) + + t.Run("none pending", func(t *testing.T) { + m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} + ran, err := sqlmigrate.Up(ctx, m, migrations, 0) + if err != nil { + t.Fatal(err) + } + if len(ran) != 0 { + t.Fatalf("applied %d, want 0", len(ran)) + } + }) + + t.Run("partial pending", func(t *testing.T) { + m := &mockMigrator{applied: applied("001_init")} + ran, err := sqlmigrate.Up(ctx, m, migrations, 0) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(ran, []string{"002_users", "003_posts"}) { + t.Errorf("applied = %v", ran) + } + }) + + t.Run("n exceeds pending", func(t *testing.T) { + m := &mockMigrator{applied: applied("001_init")} + ran, err := sqlmigrate.Up(ctx, m, migrations, 99) + if err != nil { + t.Fatal(err) + } + if len(ran) != 2 { + t.Fatalf("applied %d, want 2", len(ran)) + } + }) + + t.Run("exec error stops and returns partial", func(t *testing.T) { + m := &failOnNthMigrator{failAt: 1} + ran, err := sqlmigrate.Up(ctx, m, migrations, 0) + if err == nil { + t.Fatal("expected error") + } + if len(ran) != 1 { + t.Errorf("applied %d before error, want 1", len(ran)) + } + }) + + t.Run("skips migration matched by ID", func(t *testing.T) { + // DB has migration applied under old name, but same ID + m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{ + {Name: "001_old-name", ID: "aa11bb22"}, + }} + migs := []sqlmigrate.Migration{ + {Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, + {Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, + } + ran, err := sqlmigrate.Up(ctx, m, migs, 0) + if err != nil { + t.Fatal(err) + } + // Only 002_users should be applied; 001 is matched by ID + if !slices.Equal(ran, []string{"002_users"}) { + t.Errorf("applied = %v, want [002_users]", ran) + } + }) +} + +// failOnNthMigrator fails on the Nth ExecUp call (0-indexed). +type failOnNthMigrator struct { + applied []sqlmigrate.AppliedMigration + calls int + failAt int +} + +func (m *failOnNthMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error { + if m.calls == m.failAt { + m.calls++ + return errors.New("connection lost") + } + m.calls++ + m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID}) + slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int { + return strings.Compare(a.Name, b.Name) + }) + return nil +} + +func (m *failOnNthMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error { + m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name }) + return nil +} + +func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) { + return slices.Clone(m.applied), nil +} + +// --- Down --- + +func TestDown(t *testing.T) { + ctx := t.Context() + migrations := []sqlmigrate.Migration{ + {Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, + {Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"}, + {Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"}, + } + + t.Run("rollback all", func(t *testing.T) { + m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} + rolled, err := sqlmigrate.Down(ctx, m, migrations, 0) + if err != nil { + t.Fatal(err) + } + if len(rolled) != 3 { + t.Fatalf("rolled %d, want 3", len(rolled)) + } + if rolled[0] != "003_posts" { + t.Errorf("first rollback = %q, want 003_posts", rolled[0]) + } + if rolled[2] != "001_init" { + t.Errorf("last rollback = %q, want 001_init", rolled[2]) + } + }) + + t.Run("rollback n", func(t *testing.T) { + m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} + rolled, err := sqlmigrate.Down(ctx, m, migrations, 2) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(rolled, []string{"003_posts", "002_users"}) { + t.Errorf("rolled = %v", rolled) + } + }) + + t.Run("none applied", func(t *testing.T) { + m := &mockMigrator{} + rolled, err := sqlmigrate.Down(ctx, m, migrations, 0) + if err != nil { + t.Fatal(err) + } + if len(rolled) != 0 { + t.Fatalf("rolled %d, want 0", len(rolled)) + } + }) + + t.Run("exec error", func(t *testing.T) { + m := &mockMigrator{ + applied: applied("001_init", "002_users"), + execErr: errors.New("permission denied"), + } + rolled, err := sqlmigrate.Down(ctx, m, migrations, 1) + if err == nil { + t.Fatal("expected error") + } + if len(rolled) != 0 { + t.Errorf("rolled %d on error, want 0", len(rolled)) + } + }) + + t.Run("unknown migration in applied", func(t *testing.T) { + m := &mockMigrator{applied: applied("001_init", "999_unknown")} + _, err := sqlmigrate.Down(ctx, m, migrations, 0) + if !errors.Is(err, sqlmigrate.ErrMissingDown) { + t.Errorf("got %v, want ErrMissingDown", err) + } + }) + + t.Run("n exceeds applied", func(t *testing.T) { + m := &mockMigrator{applied: applied("001_init")} + rolled, err := sqlmigrate.Down(ctx, m, migrations, 99) + if err != nil { + t.Fatal(err) + } + if len(rolled) != 1 { + t.Fatalf("rolled %d, want 1", len(rolled)) + } + }) + + t.Run("finds migration by ID when name changed", func(t *testing.T) { + // DB has old name, file has new name, same ID + m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{ + {Name: "001_old-name", ID: "aa11bb22"}, + }} + migs := []sqlmigrate.Migration{ + {Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"}, + } + rolled, err := sqlmigrate.Down(ctx, m, migs, 1) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(rolled, []string{"001_old-name"}) { + t.Errorf("rolled = %v, want [001_old-name]", rolled) + } + }) +} + +// --- GetStatus --- + +func TestGetStatus(t *testing.T) { + ctx := t.Context() + migrations := []sqlmigrate.Migration{ + {Name: "001_init"}, + {Name: "002_users"}, + {Name: "003_posts"}, + } + + t.Run("all pending", func(t *testing.T) { + m := &mockMigrator{} + status, err := sqlmigrate.GetStatus(ctx, m, migrations) + if err != nil { + t.Fatal(err) + } + if len(status.Applied) != 0 { + t.Errorf("applied = %d, want 0", len(status.Applied)) + } + if len(status.Pending) != 3 { + t.Errorf("pending = %d, want 3", len(status.Pending)) + } + }) + + t.Run("partial", func(t *testing.T) { + m := &mockMigrator{applied: applied("001_init")} + status, err := sqlmigrate.GetStatus(ctx, m, migrations) + if err != nil { + t.Fatal(err) + } + if len(status.Applied) != 1 || status.Applied[0] != "001_init" { + t.Errorf("applied = %v", status.Applied) + } + if !slices.Equal(status.Pending, []string{"002_users", "003_posts"}) { + t.Errorf("pending = %v", status.Pending) + } + }) + + t.Run("all applied", func(t *testing.T) { + m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")} + status, err := sqlmigrate.GetStatus(ctx, m, migrations) + if err != nil { + t.Fatal(err) + } + if len(status.Applied) != 3 { + t.Errorf("applied = %d, want 3", len(status.Applied)) + } + if len(status.Pending) != 0 { + t.Errorf("pending = %d, want 0", len(status.Pending)) + } + }) +}