ref(database/sqlmigrate): extract migration library with shmigrate backend

Factor the inline migration logic from cmd/sql-migrate into reusable
packages: database/sqlmigrate (core types, matching, file collection)
and database/sqlmigrate/shmigrate (shell script generation backend).

No behavior changes — the CLI produces identical output. The shmigrate
package implements the sqlmigrate.Migrator interface so other backends
(pgmigrate, mymigrate, etc.) can follow the same pattern.
This commit is contained in:
AJ ONeal 2026-04-08 15:03:56 -06:00
parent c4964a5b65
commit 3547b7e409
No known key found for this signature in database
8 changed files with 1051 additions and 145 deletions

View File

@ -1,3 +1,13 @@
module github.com/therootcompany/golib/cmd/sql-migrate/v2 module github.com/therootcompany/golib/cmd/sql-migrate/v2
go 1.25.3 go 1.26.1
require (
github.com/therootcompany/golib/database/sqlmigrate v0.0.0
github.com/therootcompany/golib/database/sqlmigrate/shmigrate v0.0.0
)
replace (
github.com/therootcompany/golib/database/sqlmigrate => ../../database/sqlmigrate
github.com/therootcompany/golib/database/sqlmigrate/shmigrate => ../../database/sqlmigrate/shmigrate
)

View File

@ -13,6 +13,7 @@ package main
import ( import (
"bufio" "bufio"
"context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"flag" "flag"
@ -27,6 +28,9 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/therootcompany/golib/database/sqlmigrate"
"github.com/therootcompany/golib/database/sqlmigrate/shmigrate"
) )
// replaced by goreleaser / ldflags // replaced by goreleaser / ldflags
@ -67,14 +71,6 @@ DROP TABLE IF EXISTS _migrations;
` `
LOG_MIGRATIONS_QUERY = `-- note: CLI arguments must be passed to the sql command to keep output clean LOG_MIGRATIONS_QUERY = `-- note: CLI arguments must be passed to the sql command to keep output clean
SELECT name FROM _migrations ORDER BY name; SELECT name FROM _migrations ORDER BY name;
`
shHeader = `#!/bin/sh
set -e
set -u
if test -s ./.env; then
. ./.env
fi
` `
) )
@ -297,14 +293,21 @@ func main() {
os.Exit(1) os.Exit(1)
} }
ctx := context.Background()
runner := &shmigrate.Migrator{
Writer: os.Stdout,
SqlCommand: state.SQLCommand,
MigrationsDir: state.MigrationsDir,
LogQueryPath: filepath.Join(state.MigrationsDir, LOG_QUERY_NAME),
LogPath: state.LogPath,
}
migrations := sqlmigrate.NamesOnly(ups)
switch subcmd { switch subcmd {
case "init": case "init":
break break
case "sync": case "sync":
if err := syncLog(&state); err != nil { syncLog(runner)
log.Fatal(err)
}
return
case "create": case "create":
if len(leafArgs) == 0 { if len(leafArgs) == 0 {
log.Fatal("create requires a description") log.Fatal("create requires a description")
@ -324,8 +327,7 @@ func main() {
fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " ")) fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " "))
os.Exit(1) os.Exit(1)
} }
err = status(&state, ups) if err := cmdStatus(ctx, &state, runner, migrations); err != nil {
if err != nil {
log.Fatal(err) log.Fatal(err)
} }
case "list": case "list":
@ -364,12 +366,11 @@ func main() {
os.Exit(1) os.Exit(1)
} }
err = up(&state, ups, upN) if err := cmdUp(ctx, &state, runner, migrations, upN); err != nil {
if err != nil {
log.Fatal(err) log.Fatal(err)
} }
case "down": case "down":
var downN int downN := 1
switch len(leafArgs) { switch len(leafArgs) {
case 0: case 0:
// default: roll back one // default: roll back one
@ -384,8 +385,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
err = down(&state, downN) if err := cmdDown(ctx, &state, runner, migrations, downN); err != nil {
if err != nil {
log.Fatal(err) log.Fatal(err)
} }
default: default:
@ -875,190 +875,171 @@ func fixupMigration(dir string, basename string) (up, down bool, warn error, err
return up, down, nil, nil return up, down, nil, nil
} }
func syncLog(state *State) error { func syncLog(runner *shmigrate.Migrator) {
getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME) syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1)
getMigsPath = filepathUnclean(getMigsPath) logPath := filepathUnclean(runner.LogPath)
getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1)
logPath := filepathUnclean(state.LogPath)
fmt.Printf(shHeader) fmt.Printf(shmigrate.ShHeader)
fmt.Println("") fmt.Println("")
fmt.Println("# SYNC: reload migrations log from DB") fmt.Println("# SYNC: reload migrations log from DB")
fmt.Printf("%s > %s || true\n", getMigs, logPath) fmt.Printf("%s > %s || true\n", syncCmd, logPath)
fmt.Printf("cat %s\n", logPath) fmt.Printf("cat %s\n", logPath)
return nil
} }
func up(state *State, ups []string, n int) error { func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error {
var pending []string // fixup pending migrations before generating the script
for _, mig := range ups { fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations)
found := slices.Contains(state.Migrated, mig)
if !found { status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
pending = append(pending, mig) if err != nil {
} return err
} }
getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME) if len(status.Pending) == 0 {
getMigsPath = filepathUnclean(getMigsPath) syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1)
getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1)
if len(pending) == 0 {
fmt.Fprintf(os.Stderr, "# Already up-to-date\n") fmt.Fprintf(os.Stderr, "# Already up-to-date\n")
fmt.Fprintf(os.Stderr, "#\n") fmt.Fprintf(os.Stderr, "#\n")
fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n") fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n")
fmt.Fprintf(os.Stderr, "# %s > %s\n", getMigs, filepathUnclean(state.LogPath)) fmt.Fprintf(os.Stderr, "# %s > %s\n", syncCmd, filepathUnclean(runner.LogPath))
return nil return nil
} }
if n == 0 {
n = len(pending)
}
fixedUp := []string{} fmt.Printf(shmigrate.ShHeader)
fixedDown := []string{}
fmt.Printf(shHeader)
fmt.Println("") fmt.Println("")
fmt.Println("# FORWARD / UP Migrations") fmt.Println("# FORWARD / UP Migrations")
fmt.Println("") fmt.Println("")
for i, migration := range pending {
if i >= n {
break
}
path := filepath.Join(state.MigrationsDir, migration+".up.sql") applied, err := sqlmigrate.Up(ctx, runner, migrations, n)
path = filepathUnclean(path) if err != nil {
{ return err
up, down, warn, err := fixupMigration(state.MigrationsDir, migration)
if warn != nil {
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
}
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
if up {
fixedUp = append(fixedUp, migration)
}
if down {
fixedDown = append(fixedDown, migration)
}
}
cmd := strings.Replace(state.SQLCommand, "%s", path, 1)
fmt.Printf("# +%d %s\n", i+1, migration)
fmt.Println(cmd)
fmt.Println(getMigs + " > " + filepathUnclean(state.LogPath))
fmt.Println("")
} }
fmt.Println("cat", filepathUnclean(state.LogPath)) _ = applied
fmt.Println("cat", filepathUnclean(runner.LogPath))
showFixes(fixedUp, fixedDown) showFixes(fixedUp, fixedDown)
return nil return nil
} }
func down(state *State, n int) error { func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration, n int) error {
lines := make([]string, len(state.Lines)) // fixup applied migrations before generating the script
copy(lines, state.Lines) fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations)
slices.Reverse(lines)
getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME) status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
getMigsPath = filepathUnclean(getMigsPath) if err != nil {
getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1) return err
}
if len(lines) == 0 { if len(status.Applied) == 0 {
syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1)
fmt.Fprintf(os.Stderr, "# No migration history\n") fmt.Fprintf(os.Stderr, "# No migration history\n")
fmt.Fprintf(os.Stderr, "#\n") fmt.Fprintf(os.Stderr, "#\n")
fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n") fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n")
fmt.Fprintf(os.Stderr, "# %s > %s\n", getMigs, filepathUnclean(state.LogPath)) fmt.Fprintf(os.Stderr, "# %s > %s\n", syncCmd, filepathUnclean(runner.LogPath))
return nil return nil
} }
if n == 0 {
n = 1
}
fixedUp := []string{} fmt.Printf(shmigrate.ShHeader)
fixedDown := []string{}
var applied []string
for _, line := range lines {
migration := commentStartRe.ReplaceAllString(line, "")
migration = strings.TrimSpace(migration)
if migration == "" {
continue
}
applied = append(applied, migration)
}
fmt.Printf(shHeader)
fmt.Println("") fmt.Println("")
fmt.Println("# ROLLBACK / DOWN Migration") fmt.Println("# ROLLBACK / DOWN Migration")
fmt.Println("") fmt.Println("")
for i, migration := range applied {
if i >= n {
break
}
downPath := filepath.Join(state.MigrationsDir, migration+".down.sql") // check for missing down files before generating script
cmd := strings.Replace(state.SQLCommand, "%s", downPath, 1) reversed := make([]string, len(status.Applied))
fmt.Printf("\n# -%d %s\n", i+1, migration) copy(reversed, status.Applied)
slices.Reverse(reversed)
limit := n
if limit <= 0 {
limit = 1
}
if limit > len(reversed) {
limit = len(reversed)
}
for _, name := range reversed[:limit] {
downPath := filepath.Join(state.MigrationsDir, name+".down.sql")
if !fileExists(downPath) { if !fileExists(downPath) {
fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath)) fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath))
fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n") fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n")
fmt.Printf("# ERROR: MISSING FILE\n")
} else {
up, down, warn, err := fixupMigration(state.MigrationsDir, migration)
if warn != nil {
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
}
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
if up {
fixedUp = append(fixedUp, migration)
}
if down {
fixedDown = append(fixedDown, migration)
}
} }
fmt.Println(cmd)
fmt.Println(getMigs + " > " + filepathUnclean(state.LogPath))
fmt.Println("")
} }
fmt.Println("cat", filepathUnclean(state.LogPath))
rolled, err := sqlmigrate.Down(ctx, runner, migrations, n)
if err != nil {
return err
}
_ = rolled
fmt.Println("cat", filepathUnclean(runner.LogPath))
showFixes(fixedUp, fixedDown) showFixes(fixedUp, fixedDown)
return nil return nil
} }
func status(state *State, ups []string) error { func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Migration) error {
previous := make([]string, len(state.Lines)) status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
copy(previous, state.Lines) if err != nil {
slices.Reverse(previous) return err
}
fmt.Fprintf(os.Stderr, "migrations_dir: %s\n", filepathUnclean(state.MigrationsDir)) fmt.Fprintf(os.Stderr, "migrations_dir: %s\n", filepathUnclean(state.MigrationsDir))
fmt.Fprintf(os.Stderr, "migrations_log: %s\n", filepathUnclean(state.LogPath)) fmt.Fprintf(os.Stderr, "migrations_log: %s\n", filepathUnclean(state.LogPath))
fmt.Fprintf(os.Stderr, "sql_command: %s\n", state.SQLCommand) fmt.Fprintf(os.Stderr, "sql_command: %s\n", state.SQLCommand)
fmt.Fprintf(os.Stderr, "\n") fmt.Fprintf(os.Stderr, "\n")
fmt.Printf("# previous: %d\n", len(previous))
for _, mig := range previous { // show applied in reverse (most recent first)
applied := make([]string, len(status.Applied))
copy(applied, status.Applied)
slices.Reverse(applied)
fmt.Printf("# previous: %d\n", len(applied))
for _, mig := range applied {
fmt.Printf(" %s\n", mig) fmt.Printf(" %s\n", mig)
} }
if len(previous) == 0 { if len(applied) == 0 {
fmt.Println(" # (no previous migrations)") fmt.Println(" # (no previous migrations)")
} }
fmt.Println("") fmt.Println("")
var pending []string fmt.Printf("# pending: %d\n", len(status.Pending))
for _, mig := range ups { for _, mig := range status.Pending {
found := slices.Contains(state.Migrated, mig)
if !found {
pending = append(pending, mig)
}
}
fmt.Printf("# pending: %d\n", len(pending))
for _, mig := range pending {
fmt.Printf(" %s\n", mig) fmt.Printf(" %s\n", mig)
} }
if len(pending) == 0 { if len(status.Pending) == 0 {
fmt.Println(" # (no pending migrations)") fmt.Println(" # (no pending migrations)")
} }
return nil return nil
} }
// fixupAll runs fixupMigration on all known migrations (applied + pending).
func fixupAll(migrationsDir string, applied []string, migrations []sqlmigrate.Migration) (fixedUp, fixedDown []string) {
seen := map[string]bool{}
var all []string
for _, name := range applied {
if !seen[name] {
all = append(all, name)
seen[name] = true
}
}
for _, m := range migrations {
if !seen[m.Name] {
all = append(all, m.Name)
seen[m.Name] = true
}
}
for _, name := range all {
up, down, warn, err := fixupMigration(migrationsDir, name)
if warn != nil {
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
}
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
if up {
fixedUp = append(fixedUp, name)
}
if down {
fixedDown = append(fixedDown, name)
}
}
return fixedUp, fixedDown
}

View File

@ -0,0 +1,22 @@
# sqlmigrate
Database-agnostic SQL migration library for Go.
[![Go Reference](https://pkg.go.dev/badge/github.com/therootcompany/golib/database/sqlmigrate.svg)](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate)
## Backend packages
Each backend is a separate Go module to avoid pulling unnecessary drivers:
| Package | Database | Driver |
|---------|----------|--------|
| [pgmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/pgmigrate) | PostgreSQL | pgx/v5 |
| [mymigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/mymigrate) | MySQL / MariaDB | go-sql-driver/mysql |
| [litemigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/litemigrate) | SQLite | database/sql (caller imports driver) |
| [msmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/msmigrate) | SQL Server | go-mssqldb |
| [shmigrate](https://pkg.go.dev/github.com/therootcompany/golib/database/sqlmigrate/shmigrate) | Shell scripts | (generates POSIX sh) |
## CLI
The [sql-migrate](https://pkg.go.dev/github.com/therootcompany/golib/cmd/sql-migrate/v2) CLI
uses shmigrate to generate shell scripts for managing migrations without a Go dependency at runtime.

View File

@ -0,0 +1,3 @@
module github.com/therootcompany/golib/database/sqlmigrate
go 1.26.1

View File

@ -0,0 +1,7 @@
module github.com/therootcompany/golib/database/sqlmigrate/shmigrate
go 1.26.1
require github.com/therootcompany/golib/database/sqlmigrate v0.0.0
replace github.com/therootcompany/golib/database/sqlmigrate => ../

View File

@ -0,0 +1,150 @@
// Package shmigrate implements sqlmigrate.Runner by generating POSIX
// shell commands that reference migration files on disk. It is used by
// the sql-migrate CLI to produce scripts that can be piped to sh.
package shmigrate
import (
"bufio"
"context"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"github.com/therootcompany/golib/database/sqlmigrate"
)
// ShHeader is the standard header for generated shell scripts.
const ShHeader = `#!/bin/sh
set -e
set -u
if test -s ./.env; then
. ./.env
fi
`
// Migrator generates shell scripts for migration execution.
// It implements sqlmigrate.Migrator so it can be used with
// sqlmigrate.Up, sqlmigrate.Down, and sqlmigrate.GetStatus.
type Migrator struct {
// Writer receives the generated shell script.
Writer io.Writer
// SqlCommand is the shell command template with %s for the file path.
// Example: `psql "$PG_URL" -v ON_ERROR_STOP=on -A -t --file %s`
SqlCommand string
// MigrationsDir is the path to the migrations directory on disk.
MigrationsDir string
// LogQueryPath is the path to the _migrations.sql query file.
// Used to sync the migrations log after each migration.
LogQueryPath string
// LogPath is the path to the migrations.log file.
LogPath string
// FS is an optional filesystem for reading the migrations log.
// When nil, the OS filesystem is used.
FS fs.FS
counter int
}
// verify interface compliance at compile time
var _ sqlmigrate.Migrator = (*Migrator)(nil)
// ExecUp outputs a shell command to run the .up.sql migration file.
func (r *Migrator) ExecUp(ctx context.Context, m sqlmigrate.Migration) error {
r.counter++
return r.exec(m.Name, ".up.sql", fmt.Sprintf("+%d", r.counter))
}
// ExecDown outputs a shell command to run the .down.sql migration file.
func (r *Migrator) ExecDown(ctx context.Context, m sqlmigrate.Migration) error {
r.counter++
return r.exec(m.Name, ".down.sql", fmt.Sprintf("-%d", r.counter))
}
func (r *Migrator) exec(name, suffix, label string) error {
path := unclean(filepath.Join(r.MigrationsDir, name+suffix))
cmd := strings.Replace(r.SqlCommand, "%s", path, 1)
syncCmd := strings.Replace(r.SqlCommand, "%s", unclean(r.LogQueryPath), 1)
logPath := unclean(r.LogPath)
fmt.Fprintf(r.Writer, "# %s %s\n", label, name)
if _, err := os.Stat(filepath.Join(r.MigrationsDir, name+suffix)); err != nil {
fmt.Fprintln(r.Writer, "# ERROR: MISSING FILE")
}
fmt.Fprintln(r.Writer, cmd)
fmt.Fprintf(r.Writer, "%s > %s\n", syncCmd, logPath)
fmt.Fprintln(r.Writer)
return nil
}
// Applied reads the migrations log file and returns applied migrations.
// Supports two formats:
// - New: "id\tname" (tab-separated, written by updated _migrations.sql)
// - Old: "name" (name only, for backwards compatibility)
//
// Returns an empty slice if the file does not exist. When FS is set, reads
// from that filesystem; otherwise reads from the OS filesystem.
func (r *Migrator) Applied(ctx context.Context) ([]sqlmigrate.AppliedMigration, error) {
var f io.ReadCloser
var err error
if r.FS != nil {
f, err = r.FS.Open(r.LogPath)
} else {
f, err = os.Open(r.LogPath)
}
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, fmt.Errorf("reading migrations log: %w", err)
}
defer f.Close()
var applied []sqlmigrate.AppliedMigration
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// strip inline comments
if idx := strings.Index(line, "#"); idx >= 0 {
line = strings.TrimSpace(line[:idx])
}
if line == "" {
continue
}
if id, name, ok := strings.Cut(line, "\t"); ok {
applied = append(applied, sqlmigrate.AppliedMigration{ID: id, Name: name})
} else {
applied = append(applied, sqlmigrate.AppliedMigration{Name: line})
}
}
return applied, nil
}
// Reset resets the migration counter. Call between Up and Down
// if generating both in the same script.
func (r *Migrator) Reset() {
r.counter = 0
}
// unclean ensures a relative path starts with ./ or ../ so it
// is not interpreted as a command name in shell scripts.
func unclean(path string) string {
if strings.HasPrefix(path, "/") {
return path
}
if strings.HasPrefix(path, "./") || strings.HasPrefix(path, "../") {
return path
}
return "./" + path
}

View File

@ -0,0 +1,276 @@
// Package sqlmigrate provides a database-agnostic SQL migration interface.
//
// Backend implementations (each a separate Go module):
// - pgmigrate: PostgreSQL via pgx/v5
// - mymigrate: MySQL/MariaDB via go-sql-driver/mysql
// - litemigrate: SQLite via database/sql
// - msmigrate: SQL Server via go-mssqldb
// - shmigrate: POSIX shell script generation
package sqlmigrate
import (
"context"
"errors"
"fmt"
"io/fs"
"regexp"
"slices"
"strings"
)
// Sentinel errors for migration operations.
var (
ErrMissingUp = errors.New("missing up migration")
ErrMissingDown = errors.New("missing down migration")
ErrWalkFailed = errors.New("walking migrations")
ErrExecFailed = errors.New("migration exec failed")
ErrQueryApplied = errors.New("querying applied migrations")
)
// Migration represents a paired up/down migration.
type Migration struct {
Name string // e.g. "2026-04-05-001000_create-todos"
ID string // 8-char hex from INSERT INTO _migrations, parsed by Collect
Up string // SQL content of the .up.sql file
Down string // SQL content of the .down.sql file
}
// AppliedMigration represents a migration recorded in the _migrations table.
type AppliedMigration struct {
ID string
Name string
}
// Status represents the current migration state.
type Status struct {
Applied []string
Pending []string
}
// Migrator executes migrations. Implementations handle the
// database-specific or output-specific details.
type Migrator interface {
// ExecUp runs the up migration. For database migrators this executes
// m.Up in a transaction. For shell migrators this outputs a command
// referencing the .up.sql file.
ExecUp(ctx context.Context, m Migration) error
// ExecDown runs the down migration.
ExecDown(ctx context.Context, m Migration) error
// Applied returns all applied migrations from the _migrations table,
// sorted lexicographically by name. Returns an empty slice (not an
// error) if the migrations table or log does not exist yet.
Applied(ctx context.Context) ([]AppliedMigration, error)
}
// idFromInsert extracts the hex ID from an INSERT INTO _migrations line.
// Matches: INSERT INTO _migrations (name, id) VALUES ('...', '<hex>');
var idFromInsert = regexp.MustCompile(
`(?i)INSERT\s+INTO\s+_migrations\s*\(\s*name\s*,\s*id\s*\)\s*VALUES\s*\(\s*'[^']*'\s*,\s*'([0-9a-fA-F]+)'\s*\)`,
)
// Collect reads .up.sql and .down.sql files from fsys, pairs them by
// basename, and returns them sorted lexicographically by name.
// If the up SQL contains an INSERT INTO _migrations line, the hex ID
// is extracted and stored in Migration.ID.
func Collect(fsys fs.FS) ([]Migration, error) {
ups := map[string]string{}
downs := map[string]string{}
err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
name := d.Name()
if base, ok := strings.CutSuffix(name, ".up.sql"); ok {
b, readErr := fs.ReadFile(fsys, path)
if readErr != nil {
return readErr
}
ups[base] = string(b)
return nil
}
if base, ok := strings.CutSuffix(name, ".down.sql"); ok {
b, readErr := fs.ReadFile(fsys, path)
if readErr != nil {
return readErr
}
downs[base] = string(b)
return nil
}
return nil
})
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrWalkFailed, err)
}
var migrations []Migration
for name, upSQL := range ups {
downSQL, ok := downs[name]
if !ok {
return nil, fmt.Errorf("%w: %s", ErrMissingDown, name)
}
var id string
if m := idFromInsert.FindStringSubmatch(upSQL); m != nil {
id = m[1]
}
migrations = append(migrations, Migration{
Name: name,
ID: id,
Up: upSQL,
Down: downSQL,
})
}
for name := range downs {
if _, ok := ups[name]; !ok {
return nil, fmt.Errorf("%w: %s", ErrMissingUp, name)
}
}
slices.SortFunc(migrations, func(a, b Migration) int {
return strings.Compare(a.Name, b.Name)
})
return migrations, nil
}
// NamesOnly builds a Migration slice from a list of names, with empty
// Up/Down content. Useful for shell-based runners that reference files
// on disk rather than executing SQL directly.
func NamesOnly(names []string) []Migration {
migrations := make([]Migration, len(names))
for i, name := range names {
migrations[i] = Migration{Name: name}
}
return migrations
}
// isApplied returns true if the migration matches any applied entry by name or ID.
func isApplied(m Migration, applied []AppliedMigration) bool {
for _, a := range applied {
if a.Name == m.Name {
return true
}
if m.ID != "" && a.ID != "" && a.ID == m.ID {
return true
}
}
return false
}
// findMigration looks up a migration by the applied entry's name or ID.
func findMigration(a AppliedMigration, byName map[string]Migration, byID map[string]Migration) (Migration, bool) {
if m, ok := byName[a.Name]; ok {
return m, true
}
if a.ID != "" {
if m, ok := byID[a.ID]; ok {
return m, true
}
}
return Migration{}, false
}
// Up applies up to n pending migrations using the given Runner.
// If n <= 0, applies all pending. Returns the names of applied migrations.
func Up(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
applied, err := r.Applied(ctx)
if err != nil {
return nil, err
}
var pending []Migration
for _, m := range migrations {
if !isApplied(m, applied) {
pending = append(pending, m)
}
}
if n <= 0 {
n = len(pending)
}
if n > len(pending) {
n = len(pending)
}
var ran []string
for _, m := range pending[:n] {
if err := r.ExecUp(ctx, m); err != nil {
return ran, fmt.Errorf("%s (up): %w", m.Name, err)
}
ran = append(ran, m.Name)
}
return ran, nil
}
// Down rolls back up to n applied migrations, most recent first.
// If n <= 0, rolls back all applied. Returns the names of rolled-back migrations.
func Down(ctx context.Context, r Migrator, migrations []Migration, n int) ([]string, error) {
applied, err := r.Applied(ctx)
if err != nil {
return nil, err
}
byName := map[string]Migration{}
byID := map[string]Migration{}
for _, m := range migrations {
byName[m.Name] = m
if m.ID != "" {
byID[m.ID] = m
}
}
reversed := make([]AppliedMigration, len(applied))
copy(reversed, applied)
slices.Reverse(reversed)
if n <= 0 || n > len(reversed) {
n = len(reversed)
}
var ran []string
for _, a := range reversed[:n] {
m, ok := findMigration(a, byName, byID)
if !ok {
return ran, fmt.Errorf("%w: %s", ErrMissingDown, a.Name)
}
if err := r.ExecDown(ctx, m); err != nil {
return ran, fmt.Errorf("%s (down): %w", a.Name, err)
}
ran = append(ran, a.Name)
}
return ran, nil
}
// GetStatus returns applied and pending migration lists.
func GetStatus(ctx context.Context, r Migrator, migrations []Migration) (*Status, error) {
applied, err := r.Applied(ctx)
if err != nil {
return nil, err
}
appliedNames := make([]string, len(applied))
for i, a := range applied {
appliedNames[i] = a.Name
}
var pending []string
for _, m := range migrations {
if !isApplied(m, applied) {
pending = append(pending, m.Name)
}
}
return &Status{
Applied: appliedNames,
Pending: pending,
}, nil
}

View File

@ -0,0 +1,457 @@
package sqlmigrate_test
import (
"context"
"errors"
"slices"
"strings"
"testing"
"testing/fstest"
"github.com/therootcompany/golib/database/sqlmigrate"
)
// applied builds an []AppliedMigration from names (IDs empty).
func applied(names ...string) []sqlmigrate.AppliedMigration {
out := make([]sqlmigrate.AppliedMigration, len(names))
for i, n := range names {
out[i] = sqlmigrate.AppliedMigration{Name: n}
}
return out
}
// mockMigrator tracks applied migrations in memory.
type mockMigrator struct {
applied []sqlmigrate.AppliedMigration
execErr error // if set, ExecUp/ExecDown return this on every call
upCalls []string
downCalls []string
}
func (m *mockMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error {
m.upCalls = append(m.upCalls, mig.Name)
if m.execErr != nil {
return m.execErr
}
m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID})
slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int {
return strings.Compare(a.Name, b.Name)
})
return nil
}
func (m *mockMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error {
m.downCalls = append(m.downCalls, mig.Name)
if m.execErr != nil {
return m.execErr
}
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name })
return nil
}
func (m *mockMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) {
return slices.Clone(m.applied), nil
}
// --- Collect ---
func TestCollect(t *testing.T) {
t.Run("pairs and sorts", func(t *testing.T) {
fsys := fstest.MapFS{
"002_second.up.sql": {Data: []byte("CREATE TABLE b;")},
"002_second.down.sql": {Data: []byte("DROP TABLE b;")},
"001_first.up.sql": {Data: []byte("CREATE TABLE a;")},
"001_first.down.sql": {Data: []byte("DROP TABLE a;")},
}
migrations, err := sqlmigrate.Collect(fsys)
if err != nil {
t.Fatal(err)
}
if len(migrations) != 2 {
t.Fatalf("got %d migrations, want 2", len(migrations))
}
if migrations[0].Name != "001_first" {
t.Errorf("first = %q, want %q", migrations[0].Name, "001_first")
}
if migrations[1].Name != "002_second" {
t.Errorf("second = %q, want %q", migrations[1].Name, "002_second")
}
if migrations[0].Up != "CREATE TABLE a;" {
t.Errorf("first.Up = %q", migrations[0].Up)
}
if migrations[0].Down != "DROP TABLE a;" {
t.Errorf("first.Down = %q", migrations[0].Down)
}
})
t.Run("parses ID from INSERT", func(t *testing.T) {
fsys := fstest.MapFS{
"001_init.up.sql": {Data: []byte("CREATE TABLE a;\nINSERT INTO _migrations (name, id) VALUES ('001_init', 'abcd1234');")},
"001_init.down.sql": {Data: []byte("DROP TABLE a;\nDELETE FROM _migrations WHERE id = 'abcd1234';")},
}
migrations, err := sqlmigrate.Collect(fsys)
if err != nil {
t.Fatal(err)
}
if migrations[0].ID != "abcd1234" {
t.Errorf("ID = %q, want %q", migrations[0].ID, "abcd1234")
}
})
t.Run("no ID when no INSERT", func(t *testing.T) {
fsys := fstest.MapFS{
"001_init.up.sql": {Data: []byte("CREATE TABLE a;")},
"001_init.down.sql": {Data: []byte("DROP TABLE a;")},
}
migrations, err := sqlmigrate.Collect(fsys)
if err != nil {
t.Fatal(err)
}
if migrations[0].ID != "" {
t.Errorf("ID = %q, want empty", migrations[0].ID)
}
})
t.Run("missing down", func(t *testing.T) {
fsys := fstest.MapFS{
"001_only-up.up.sql": {Data: []byte("CREATE TABLE x;")},
}
_, err := sqlmigrate.Collect(fsys)
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
t.Errorf("got %v, want ErrMissingDown", err)
}
})
t.Run("missing up", func(t *testing.T) {
fsys := fstest.MapFS{
"001_only-down.down.sql": {Data: []byte("DROP TABLE x;")},
}
_, err := sqlmigrate.Collect(fsys)
if !errors.Is(err, sqlmigrate.ErrMissingUp) {
t.Errorf("got %v, want ErrMissingUp", err)
}
})
t.Run("ignores non-sql files", func(t *testing.T) {
fsys := fstest.MapFS{
"001_init.up.sql": {Data: []byte("UP")},
"001_init.down.sql": {Data: []byte("DOWN")},
"README.md": {Data: []byte("# Migrations")},
"_migrations.sql": {Data: []byte("SELECT name FROM _migrations;")},
}
migrations, err := sqlmigrate.Collect(fsys)
if err != nil {
t.Fatal(err)
}
if len(migrations) != 1 {
t.Fatalf("got %d migrations, want 1", len(migrations))
}
})
t.Run("empty fs", func(t *testing.T) {
fsys := fstest.MapFS{}
migrations, err := sqlmigrate.Collect(fsys)
if err != nil {
t.Fatal(err)
}
if len(migrations) != 0 {
t.Fatalf("got %d migrations, want 0", len(migrations))
}
})
}
// --- NamesOnly ---
func TestNamesOnly(t *testing.T) {
names := []string{"001_init", "002_users"}
migrations := sqlmigrate.NamesOnly(names)
if len(migrations) != 2 {
t.Fatalf("got %d, want 2", len(migrations))
}
for i, m := range migrations {
if m.Name != names[i] {
t.Errorf("[%d].Name = %q, want %q", i, m.Name, names[i])
}
if m.Up != "" || m.Down != "" {
t.Errorf("[%d] has non-empty content", i)
}
}
}
// --- Up ---
func TestUp(t *testing.T) {
ctx := t.Context()
migrations := []sqlmigrate.Migration{
{Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
{Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
}
t.Run("apply all", func(t *testing.T) {
m := &mockMigrator{}
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ran, []string{"001_init", "002_users", "003_posts"}) {
t.Errorf("applied = %v", ran)
}
})
t.Run("apply n", func(t *testing.T) {
m := &mockMigrator{}
ran, err := sqlmigrate.Up(ctx, m, migrations, 2)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ran, []string{"001_init", "002_users"}) {
t.Errorf("applied = %v", ran)
}
})
t.Run("none pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
if err != nil {
t.Fatal(err)
}
if len(ran) != 0 {
t.Fatalf("applied %d, want 0", len(ran))
}
})
t.Run("partial pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")}
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ran, []string{"002_users", "003_posts"}) {
t.Errorf("applied = %v", ran)
}
})
t.Run("n exceeds pending", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")}
ran, err := sqlmigrate.Up(ctx, m, migrations, 99)
if err != nil {
t.Fatal(err)
}
if len(ran) != 2 {
t.Fatalf("applied %d, want 2", len(ran))
}
})
t.Run("exec error stops and returns partial", func(t *testing.T) {
m := &failOnNthMigrator{failAt: 1}
ran, err := sqlmigrate.Up(ctx, m, migrations, 0)
if err == nil {
t.Fatal("expected error")
}
if len(ran) != 1 {
t.Errorf("applied %d before error, want 1", len(ran))
}
})
t.Run("skips migration matched by ID", func(t *testing.T) {
// DB has migration applied under old name, but same ID
m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{
{Name: "001_old-name", ID: "aa11bb22"},
}}
migs := []sqlmigrate.Migration{
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
}
ran, err := sqlmigrate.Up(ctx, m, migs, 0)
if err != nil {
t.Fatal(err)
}
// Only 002_users should be applied; 001 is matched by ID
if !slices.Equal(ran, []string{"002_users"}) {
t.Errorf("applied = %v, want [002_users]", ran)
}
})
}
// failOnNthMigrator fails on the Nth ExecUp call (0-indexed).
type failOnNthMigrator struct {
applied []sqlmigrate.AppliedMigration
calls int
failAt int
}
func (m *failOnNthMigrator) ExecUp(_ context.Context, mig sqlmigrate.Migration) error {
if m.calls == m.failAt {
m.calls++
return errors.New("connection lost")
}
m.calls++
m.applied = append(m.applied, sqlmigrate.AppliedMigration{Name: mig.Name, ID: mig.ID})
slices.SortFunc(m.applied, func(a, b sqlmigrate.AppliedMigration) int {
return strings.Compare(a.Name, b.Name)
})
return nil
}
func (m *failOnNthMigrator) ExecDown(_ context.Context, mig sqlmigrate.Migration) error {
m.applied = slices.DeleteFunc(m.applied, func(a sqlmigrate.AppliedMigration) bool { return a.Name == mig.Name })
return nil
}
func (m *failOnNthMigrator) Applied(_ context.Context) ([]sqlmigrate.AppliedMigration, error) {
return slices.Clone(m.applied), nil
}
// --- Down ---
func TestDown(t *testing.T) {
ctx := t.Context()
migrations := []sqlmigrate.Migration{
{Name: "001_init", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
{Name: "002_users", Up: "CREATE TABLE b;", Down: "DROP TABLE b;"},
{Name: "003_posts", Up: "CREATE TABLE c;", Down: "DROP TABLE c;"},
}
t.Run("rollback all", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 0)
if err != nil {
t.Fatal(err)
}
if len(rolled) != 3 {
t.Fatalf("rolled %d, want 3", len(rolled))
}
if rolled[0] != "003_posts" {
t.Errorf("first rollback = %q, want 003_posts", rolled[0])
}
if rolled[2] != "001_init" {
t.Errorf("last rollback = %q, want 001_init", rolled[2])
}
})
t.Run("rollback n", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 2)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(rolled, []string{"003_posts", "002_users"}) {
t.Errorf("rolled = %v", rolled)
}
})
t.Run("none applied", func(t *testing.T) {
m := &mockMigrator{}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 0)
if err != nil {
t.Fatal(err)
}
if len(rolled) != 0 {
t.Fatalf("rolled %d, want 0", len(rolled))
}
})
t.Run("exec error", func(t *testing.T) {
m := &mockMigrator{
applied: applied("001_init", "002_users"),
execErr: errors.New("permission denied"),
}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 1)
if err == nil {
t.Fatal("expected error")
}
if len(rolled) != 0 {
t.Errorf("rolled %d on error, want 0", len(rolled))
}
})
t.Run("unknown migration in applied", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "999_unknown")}
_, err := sqlmigrate.Down(ctx, m, migrations, 0)
if !errors.Is(err, sqlmigrate.ErrMissingDown) {
t.Errorf("got %v, want ErrMissingDown", err)
}
})
t.Run("n exceeds applied", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")}
rolled, err := sqlmigrate.Down(ctx, m, migrations, 99)
if err != nil {
t.Fatal(err)
}
if len(rolled) != 1 {
t.Fatalf("rolled %d, want 1", len(rolled))
}
})
t.Run("finds migration by ID when name changed", func(t *testing.T) {
// DB has old name, file has new name, same ID
m := &mockMigrator{applied: []sqlmigrate.AppliedMigration{
{Name: "001_old-name", ID: "aa11bb22"},
}}
migs := []sqlmigrate.Migration{
{Name: "001_new-name", ID: "aa11bb22", Up: "CREATE TABLE a;", Down: "DROP TABLE a;"},
}
rolled, err := sqlmigrate.Down(ctx, m, migs, 1)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(rolled, []string{"001_old-name"}) {
t.Errorf("rolled = %v, want [001_old-name]", rolled)
}
})
}
// --- GetStatus ---
func TestGetStatus(t *testing.T) {
ctx := t.Context()
migrations := []sqlmigrate.Migration{
{Name: "001_init"},
{Name: "002_users"},
{Name: "003_posts"},
}
t.Run("all pending", func(t *testing.T) {
m := &mockMigrator{}
status, err := sqlmigrate.GetStatus(ctx, m, migrations)
if err != nil {
t.Fatal(err)
}
if len(status.Applied) != 0 {
t.Errorf("applied = %d, want 0", len(status.Applied))
}
if len(status.Pending) != 3 {
t.Errorf("pending = %d, want 3", len(status.Pending))
}
})
t.Run("partial", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init")}
status, err := sqlmigrate.GetStatus(ctx, m, migrations)
if err != nil {
t.Fatal(err)
}
if len(status.Applied) != 1 || status.Applied[0] != "001_init" {
t.Errorf("applied = %v", status.Applied)
}
if !slices.Equal(status.Pending, []string{"002_users", "003_posts"}) {
t.Errorf("pending = %v", status.Pending)
}
})
t.Run("all applied", func(t *testing.T) {
m := &mockMigrator{applied: applied("001_init", "002_users", "003_posts")}
status, err := sqlmigrate.GetStatus(ctx, m, migrations)
if err != nil {
t.Fatal(err)
}
if len(status.Applied) != 3 {
t.Errorf("applied = %d, want 3", len(status.Applied))
}
if len(status.Pending) != 0 {
t.Errorf("pending = %d, want 0", len(status.Pending))
}
})
}