AJ ONeal dda4491c94
chore: golint fixes across all modules
Run scripts/golint: goimports, go fix, go vet, go mod tidy.
2026-04-09 14:23:33 -06:00

1139 lines
34 KiB
Go

//
// Written in 2025 by AJ ONeal <aj@therootcompany.com>
//
// To the extent possible under law, the author(s) have dedicated all copyright
// and related and neighboring rights to this software to the public domain
// worldwide. This software is distributed without any warranty.
//
// You should have received a copy of the CC0 Public Domain Dedication along with
// this software. If not, see <https://creativecommons.org/publicdomain/zero/1.0/>.
// Package sql-migrate provides a simple SQL migrator that's easy to roll back or mix and match during development
package main
import (
"bufio"
"context"
"crypto/rand"
"encoding/hex"
"flag"
"fmt"
"io"
"log"
"os"
"path/filepath"
"regexp"
"slices"
"sort"
"strconv"
"strings"
"time"
"github.com/therootcompany/golib/database/sqlmigrate"
"github.com/therootcompany/golib/database/sqlmigrate/shmigrate"
)
// replaced by goreleaser / ldflags
var (
version = "0.0.0-dev"
commit = "0000000"
date = "0001-01-01"
)
const (
defaultMigrationDir = "./sql/migrations/"
defaultLogPath = "../migrations.log"
sqlCommandPSQL = `psql "$PG_URL" -v ON_ERROR_STOP=on --no-align --tuples-only --file %s`
sqlCommandMariaDB = `mariadb --defaults-extra-file="$MY_CNF" --silent --skip-column-names --raw < %s`
sqlCommandMySQL = `mysql --defaults-extra-file="$MY_CNF" --silent --skip-column-names --raw < %s`
sqlCommandSQLite = `sqlite3 "$SQLITE_PATH" < %s`
sqlCommandSQLCmd = `sqlcmd --exit-on-error --headers -1 --trim-spaces --encrypt-connection strict --input-file %s`
LOG_QUERY_NAME = "_migrations.sql"
M_MIGRATOR_NAME = "0001-01-01-001000_init-migrations"
M_MIGRATOR_UP_NAME = "0001-01-01-001000_init-migrations.up.sql"
M_MIGRATOR_DOWN_NAME = "0001-01-01-001000_init-migrations.down.sql"
defaultMigratorUpTmpl = `-- Config variables for sql-migrate (do not delete)
-- sql_command: %s
-- migrations_log: %s
--
CREATE TABLE IF NOT EXISTS _migrations (
id CHAR(8) PRIMARY KEY,
name VARCHAR(80) NULL UNIQUE,
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- note: to enable text-based tools to grep and sort we put 'name' before 'id'
-- grep -r 'INSERT INTO _migrations' ./sql/migrations/ | cut -d':' -f2 | sort
INSERT INTO _migrations (name, id) VALUES ('0001-01-01-001000_init-migrations', '00000001');
`
defaultMigratorDown = `DELETE FROM _migrations WHERE id = '00000001';
DROP TABLE IF EXISTS _migrations;
`
// Used for detection during auto-upgrade.
logMigrationsQueryPrev2_2_0 = `SELECT name FROM _migrations ORDER BY name;`
logMigrationsQueryNote = "-- note: CLI arguments must be passed to the sql command to keep output machine-readable\n"
logMigrationsQuerySQLCmdNote = "-- connection: set SQLCMDSERVER, SQLCMDDATABASE, SQLCMDUSER, SQLCMDPASSWORD in .env\n"
)
// printVersion displays the version, commit, and build date.
func printVersion(w io.Writer) {
_, _ = fmt.Fprintf(w, "sql-migrate v%s %s (%s)\n", version, commit[:7], date)
}
var helpText = `
sql-migrate - a feature-branch-friendly SQL migrator
USAGE
sql-migrate [-d sqldir] <command> [args]
EXAMPLE
sql-migrate -d ./sql/migrations/ init --sql-command <psql|mariadb|mysql|sqlite|sqlcmd>
sql-migrate -d ./sql/migrations/ create <kebab-case-description>
sql-migrate -d ./sql/migrations/ sync
sql-migrate -d ./sql/migrations/ status
sql-migrate -d ./sql/migrations/ up 99
sql-migrate -d ./sql/migrations/ down 1
sql-migrate -d ./sql/migrations/ list
COMMANDS
init - creates migrations directory, initial migration, log file,
and query for migrations
create - creates a new, canonically-named up/down file pair in the
migrations directory, with corresponding insert
sync - create a script to reload migrations.log from the DB
(run after upgrading sql-migrate)
status - shows the same output as if processing a forward-migration
up [n] - create a script to run pending migrations (ALL by default)
down [n] - create a script to roll back migrations (ONE by default)
list - lists migrations
OPTIONS
-d <migrations directory> default: ./sql/migrations/
--help show command-specific help
NOTES
Migrations files are in the following format:
<yyyy-mm-dd>-<number>_<name>.<up|down>.sql
2020-01-01-1000_init-app.up.sql
The initial migration file contains configuration variables:
-- migrations_log: ./sql/migrations.log
-- sql_command: psql "$PG_URL" -v ON_ERROR_STOP=on --no-align --tuples-only --file %s
The log is generated on each migration file contains a list of all migrations:
0001-01-01-001000_init-migrations.up.sql
2020-12-31-001000_init-app.up.sql
2020-12-31-001100_add-customer-tables.up.sql
2020-12-31-002000_add-ALL-THE-TABLES.up.sql
The 'create' generates an up/down pair of files using the current date and
the number 1000. If either file exists, the number is incremented by 1000 and
tried again.
NOTE: POSTGRES SCHEMAS
Set PGOPTIONS to target a specific PostgreSQL schema:
PGOPTIONS="-c search_path=tenant123" sql-migrate up | sh
Each schema gets its own _migrations table, so tenants are migrated
independently. PGOPTIONS is supported by psql and all libpq clients.
NOTE: SQL SERVER (go-sqlcmd)
Requires the modern sqlcmd (go-mssqldb), not the legacy ODBC version.
Install: brew install sqlcmd (macOS), winget install sqlcmd (Windows)
The default uses --encrypt-connection strict (TDS 8.0), which provides
TLS-first on TCP with ALPN 'tds/8.0' and SNI — required for proper TLS
termination at load balancers and reverse proxies.
Set these SQLCMD environment variables in your .env file:
SQLCMDSERVER='host\instance' # or host,port (e.g. localhost,1433)
SQLCMDDATABASE=myapp
SQLCMDUSER=sa
SQLCMDPASSWORD=secret
SQLCMDSERVER is the instance, not just the host. Common formats:
SQLCMDSERVER=localhost # default instance
SQLCMDSERVER='localhost\SQLEXPRESS' # named instance (quote the backslash)
SQLCMDSERVER='localhost,1433' # host and port
sqlcmd reads these automatically — no credentials in the command template.
For local development without TLS:
--sql-command 'sqlcmd --exit-on-error --headers -1 --trim-spaces --encrypt-connection disable --input-file %s'
UPGRADING
After upgrading sql-migrate, run sync to refresh the log format:
sql-migrate -d ./sql/migrations/ sync | sh
`
var (
nonWordRe = regexp.MustCompile(`\W+`)
commentStartRe = regexp.MustCompile(`(^|\s+)#.*`)
)
// logMigrationsSelect returns the DB-specific SELECT line for _migrations.
// The output format is "id<tab>name" per row, compatible with shmigrate.Applied().
func logMigrationsSelect(sqlCommand string) string {
var selectExpr string
switch {
case strings.Contains(sqlCommand, "psql"):
selectExpr = "id || CHR(9) || name"
case strings.Contains(sqlCommand, "mysql") || strings.Contains(sqlCommand, "mariadb"):
selectExpr = "CONCAT(id, CHAR(9), name)"
case strings.Contains(sqlCommand, "sqlite"):
selectExpr = "id || CHAR(9) || name"
case strings.Contains(sqlCommand, "sqlcmd"):
selectExpr = "id + CHAR(9) + name"
default:
fmt.Fprintf(os.Stderr, "Error: unrecognized --sql-command %q; cannot generate _migrations.sql\n", sqlCommand)
os.Exit(1)
}
return fmt.Sprintf("SELECT %s FROM _migrations ORDER BY name;", selectExpr)
}
type State struct {
Date time.Time
SQLCommand string
Lines []string
Migrated []string
MigrationsDir string
LogPath string
}
type MainConfig struct {
migrationsDir string
logPath string
sqlCommand string
}
func main() {
var cfg MainConfig
var today = time.Now()
if len(os.Args) < 2 {
printVersion(os.Stdout)
fmt.Println("")
fmt.Printf("%s\n", helpText)
os.Exit(0)
}
switch os.Args[1] {
case "-V", "-version", "--version", "version":
printVersion(os.Stdout)
os.Exit(0)
case "help", "-help", "--help":
printVersion(os.Stdout)
fmt.Println("")
fmt.Printf("%s\n", helpText)
os.Exit(0)
default:
// do nothing
}
mainArgs := os.Args[1:]
fsMain := flag.NewFlagSet("", flag.ExitOnError)
fsMain.StringVar(&cfg.migrationsDir, "d", defaultMigrationDir, "directory for migrations (where 0001-01-01_init-migrations.up.sql will be added)")
if err := fsMain.Parse(mainArgs); err != nil {
os.Exit(2)
}
var subcmd string
// note: Args() includes any flags after the first non-flag arg
// sql-migrate -d ./migs/ init --migrations-log ./migrations.log
// => init -f ./migrations.log
subArgs := fsMain.Args()
if len(subArgs) > 0 {
subcmd = subArgs[0]
subArgs = subArgs[1:]
}
var fsSub *flag.FlagSet
switch subcmd {
case "init":
fsSub = flag.NewFlagSet("init", flag.ExitOnError)
fsSub.StringVar(&cfg.logPath, "migrations-log", "", fmt.Sprintf("migration log file (default: %s) relative to and saved in %s", defaultLogPath, M_MIGRATOR_NAME))
fsSub.StringVar(&cfg.sqlCommand, "sql-command", sqlCommandPSQL, "construct scripts with this to execute SQL files: 'psql', 'mysql', 'mariadb', 'sqlite', 'sqlcmd', or custom arguments")
case "create", "sync", "up", "down", "status", "list":
fsSub = flag.NewFlagSet(subcmd, flag.ExitOnError)
default:
log.Printf("unknown command %s", subcmd)
printVersion(os.Stderr)
fmt.Fprintf(os.Stderr, "%s\n", helpText)
os.Exit(1)
}
if err := fsSub.Parse(subArgs); err != nil {
os.Exit(2)
}
leafArgs := fsSub.Args()
switch cfg.sqlCommand {
case "", "postgres", "postgresql", "pg", "psql", "plpgsql":
cfg.sqlCommand = sqlCommandPSQL
case "mariadb":
cfg.sqlCommand = sqlCommandMariaDB
case "mysql", "my":
cfg.sqlCommand = sqlCommandMySQL
case "sqlite", "sqlite3", "lite":
cfg.sqlCommand = sqlCommandSQLite
case "sqlcmd", "mssql", "sqlserver":
cfg.sqlCommand = sqlCommandSQLCmd
default:
// leave as provided by the user
}
if !strings.HasSuffix(cfg.migrationsDir, "/") {
cfg.migrationsDir += "/"
}
if subcmd == "init" {
mustInit(&cfg)
}
entries, err := os.ReadDir(cfg.migrationsDir)
if err != nil {
if os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Error: missing migrations directory. Run 'sql-migrate -d %q init' to create it.\n", cfg.migrationsDir)
os.Exit(1)
}
fmt.Fprintf(os.Stderr, "Error: couldn't list migrations directory: %v\n", err)
os.Exit(1)
}
ups, downs := migrationsList(cfg.migrationsDir, entries)
if !slices.Contains(ups, M_MIGRATOR_NAME) {
fmt.Fprintf(os.Stderr, "Error: missing initial migration. Run 'sql-migrate -d %q init' to create %q.\n", cfg.migrationsDir, M_MIGRATOR_UP_NAME)
os.Exit(1)
}
if !slices.Contains(downs, M_MIGRATOR_NAME) {
fmt.Fprintf(os.Stderr, "Error: missing initial migration. Run 'sql-migrate -d %q init' to create %q.\n", cfg.migrationsDir, M_MIGRATOR_DOWN_NAME)
os.Exit(1)
}
logQueryPath := filepath.Join(cfg.migrationsDir, LOG_QUERY_NAME)
if !fileExists(logQueryPath) {
fmt.Fprintf(os.Stderr, "Error: missing %q. Run 'sql-migrate -d %q init' to create it.\n", logQueryPath, cfg.migrationsDir)
os.Exit(1)
}
mMigratorUpPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_UP_NAME)
// mMigratorDownPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_DOWN_NAME)
state := State{
Date: today,
MigrationsDir: cfg.migrationsDir,
}
state.SQLCommand, state.LogPath, err = extractVars(mMigratorUpPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: couldn't read config from initial migration: %v\n", err)
os.Exit(1)
}
// auto-upgrade _migrations.sql to include id in output
maybeUpgradeLogQuery(logQueryPath, state.SQLCommand)
logText, err := os.ReadFile(state.LogPath)
if err != nil {
if !os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Error: couldn't read migrations log %q: %v\n", state.LogPath, err)
os.Exit(1)
}
if err := migrationsLogInit(&state, subcmd); err != nil {
fmt.Fprintf(os.Stderr, "Error: couldn't create log file directory: %v\n", err)
os.Exit(1)
}
}
if err := state.parseAndFixupBatches(string(logText)); err != nil {
fmt.Fprintf(os.Stderr, "Error: couldn't read migrations log or fixup batches: %v\n", err)
os.Exit(1)
}
ctx := context.Background()
runner := &shmigrate.Migrator{
Writer: os.Stdout,
SqlCommand: state.SQLCommand,
MigrationsDir: state.MigrationsDir,
LogQueryPath: filepath.Join(state.MigrationsDir, LOG_QUERY_NAME),
LogPath: state.LogPath,
}
migrations := sqlmigrate.NamesOnly(ups)
switch subcmd {
case "init":
break
case "sync":
syncLog(runner)
case "create":
if len(leafArgs) == 0 {
log.Fatal("create requires a description")
}
desc := strings.Join(leafArgs, " ")
desc = nonWordRe.ReplaceAllString(desc, " ")
desc = strings.TrimSpace(desc)
desc = nonWordRe.ReplaceAllString(desc, "-")
desc = strings.ToLower(desc)
err = create(&state, desc)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: couldn't create migration: %v\n", err)
os.Exit(1)
}
case "status":
if len(leafArgs) > 0 {
fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " "))
os.Exit(1)
}
if err := cmdStatus(ctx, &state, runner, migrations); err != nil {
log.Fatal(err)
}
case "list":
if len(leafArgs) > 0 {
fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " "))
os.Exit(1)
}
fmt.Println("Ups:")
if len(ups) == 0 {
fmt.Println(" (none)")
}
for _, u := range ups {
fmt.Println(" ", u)
}
fmt.Println("")
fmt.Println("Downs:")
if len(downs) == 0 {
fmt.Println(" (none)")
}
for _, d := range downs {
fmt.Println(" ", d)
}
case "up":
upN := -1
switch len(leafArgs) {
case 0:
// no arg: apply all pending
case 1:
upN, err = strconv.Atoi(leafArgs[0])
if err != nil || upN < 1 {
fmt.Fprintf(os.Stderr, "Error: %s is not a positive number\n", leafArgs[0])
os.Exit(1)
}
default:
fmt.Fprintf(os.Stderr, "Error: unrecognized arguments %q \n", strings.Join(leafArgs, "\" \""))
os.Exit(1)
}
if err := cmdUp(ctx, &state, runner, migrations, upN); err != nil {
log.Fatal(err)
}
case "down":
downN := 1
switch len(leafArgs) {
case 0:
// default: roll back one
case 1:
downN, err = strconv.Atoi(leafArgs[0])
if err != nil || downN < 1 {
fmt.Fprintf(os.Stderr, "Error: %s is not a positive number\n", leafArgs[0])
os.Exit(1)
}
default:
fmt.Fprintf(os.Stderr, "Error: unrecognized arguments %q \n", strings.Join(leafArgs, "\" \""))
os.Exit(1)
}
if err := cmdDown(ctx, &state, runner, migrations, downN); err != nil {
log.Fatal(err)
}
default:
log.Printf("unknown command %s", subcmd)
printVersion(os.Stderr)
fmt.Fprintf(os.Stderr, "%s\n", helpText)
os.Exit(1)
}
}
func migrationsList(migrationsDir string, entries []os.DirEntry) (ups, downs []string) {
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") || strings.HasPrefix(name, "+") {
if name != LOG_QUERY_NAME {
fmt.Fprintf(os.Stderr, " ignoring %s\n", filepathJoin(migrationsDir, name))
}
continue
}
if base, ok := strings.CutSuffix(name, ".up.sql"); ok {
ups = append(ups, base)
// TODO on ups add INSERT to file and to up migration if it doesn't exist
continue
}
if base, ok := strings.CutSuffix(name, ".down.sql"); ok {
downs = append(downs, base)
continue
}
fmt.Fprintf(os.Stderr, " unknown %s\n", filepathJoin(migrationsDir, name))
}
for _, down := range downs {
// TODO on downs add INSERT to file and to up migration if it doesn't exist
upName := strings.TrimSuffix(down, ".down.sql") + ".up.sql"
companion := filepath.Join(migrationsDir, upName)
if !fileExists(companion) {
fmt.Fprintf(os.Stderr, " missing '%s'\n", companion)
}
}
sort.Strings(ups)
sort.Strings(downs)
return ups, downs
}
func fileExists(path string) bool {
_, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return false
}
fmt.Fprintf(os.Stderr, "Error: can't access %q\n", path)
os.Exit(1)
return false
}
return true
}
func filepathUnclean(path string) string {
if !strings.HasPrefix(path, "/") {
if !strings.HasPrefix(path, "./") && !strings.HasPrefix(path, "../") {
path = "./" + path
}
}
return path
}
func filepathJoin(src, dst string) string {
return filepathUnclean(filepath.Join(src, dst))
}
// initializes all necessary files and directories
// - ./sql/migrations.log
//
// - ./sql/migrations
//
// - ./sql/migrations/0001-01-01-001000_init-migrations.up.sql
// - migrations_log: ./sql/migrations.log
// - sql_command: psql "$PG_URL" -v ON_ERROR_STOP=on --no-align --tuples-only --file %s
//
// - ./sql/migrations/0001-01-01-001000_init-migrations.down.sql
func mustInit(cfg *MainConfig) {
fmt.Fprintf(os.Stderr, "Initializing %q ...\n", cfg.migrationsDir)
var resolvedLogPath = cfg.logPath
if cfg.sqlCommand != "" && !strings.Contains(cfg.sqlCommand, "%s") {
fmt.Fprintf(os.Stderr, "Error: --sql-command must contain a literal '%%s' to accept the path to the SQL file\n")
os.Exit(1)
}
entries, err := os.ReadDir(cfg.migrationsDir)
if err != nil {
if !os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Error: init failed to create %q: %v\n", cfg.migrationsDir, err)
os.Exit(1)
}
if err = os.MkdirAll(cfg.migrationsDir, 0755); err != nil {
fmt.Fprintf(os.Stderr, "Error: init failed to read %q: %v\n", cfg.migrationsDir, err)
os.Exit(1)
}
}
ups, downs := migrationsList(cfg.migrationsDir, entries)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: init couldn't list existing migrations: %v\n", err)
os.Exit(1)
}
mMigratorUpPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_UP_NAME)
mMigratorDownPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_DOWN_NAME)
// write config
if slices.Contains(ups, M_MIGRATOR_NAME) {
fmt.Fprintf(os.Stderr, " found %s\n", filepathJoin(cfg.migrationsDir, M_MIGRATOR_UP_NAME))
} else {
if cfg.logPath == "" {
migrationsParent := filepath.Dir(cfg.migrationsDir)
resolvedLogPath = filepath.Join(migrationsParent, defaultLogPath)
// resolvedLogPath, err = filepath.Rel(cfg.migrationsDir, cfg.logPath)
// if err != nil {
// fmt.Fprintf(os.Stderr, "Error: init couldn't resolve the migrations log relative to the migrations dir: %v\n", err)
// os.Exit(1)
// }
}
migratorUpQuery := fmt.Sprintf(defaultMigratorUpTmpl, cfg.sqlCommand, resolvedLogPath)
if created, err := initFile(mMigratorUpPath, migratorUpQuery); err != nil {
fmt.Fprintf(os.Stderr, "Error: init couldn't create initial up migration: %v\n", err)
os.Exit(1)
} else if created {
fmt.Fprintf(os.Stderr, " created %s\n", filepathUnclean(mMigratorUpPath))
}
}
state := State{
MigrationsDir: cfg.migrationsDir,
}
state.SQLCommand, state.LogPath, err = extractVars(mMigratorUpPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: init couldn't read config from initial migration: %v\n", err)
os.Exit(1)
}
if cfg.logPath != "" && filepath.Clean(cfg.logPath) != filepath.Clean(state.LogPath) {
fmt.Fprintf(os.Stderr,
"--migrations-log %q does not match %q from %q\n(drop the --migrations-log flag, or update the add migrations file)\n",
cfg.logPath, state.LogPath, mMigratorUpPath,
)
os.Exit(1)
}
if cfg.sqlCommand != "" && cfg.sqlCommand != state.SQLCommand {
fmt.Fprintf(os.Stderr,
"--sql-command %q does not match %q from %q\n(drop the --sql-command flag, or update the add migrations file)\n",
cfg.sqlCommand, state.SQLCommand, mMigratorUpPath,
)
os.Exit(1)
}
if slices.Contains(downs, M_MIGRATOR_NAME) {
fmt.Fprintf(os.Stderr, " found %s\n", filepathUnclean(mMigratorDownPath))
} else {
migratorDownQuery := defaultMigratorDown
if created, err := initFile(mMigratorDownPath, migratorDownQuery); err != nil {
fmt.Fprintf(os.Stderr, "Error: init couldn't create initial up migration: %v\n", err)
os.Exit(1)
} else if created {
fmt.Fprintf(os.Stderr, " created %s\n", filepathUnclean(mMigratorDownPath))
}
}
logQueryPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME)
queryHeader := logMigrationsQueryNote
if strings.Contains(state.SQLCommand, "sqlcmd") {
queryHeader += logMigrationsQuerySQLCmdNote
}
if created, err := initFile(logQueryPath, queryHeader+logMigrationsSelect(state.SQLCommand)+"\n"); err != nil {
fmt.Fprintf(os.Stderr, "Error: init couldn't create migrations query: %v\n", err)
os.Exit(1)
} else if created {
fmt.Fprintf(os.Stderr, " created %s\n", filepathUnclean(logQueryPath))
} else {
fmt.Fprintf(os.Stderr, " found %s\n", filepathUnclean(logQueryPath))
}
if fileExists(state.LogPath) {
fmt.Fprintf(os.Stderr, " found %s\n", filepathUnclean(state.LogPath))
fmt.Fprintf(os.Stderr, "done\n")
return
}
}
// maybeUpgradeLogQuery replaces the old name-only SELECT in _migrations.sql
// with the new id+name SELECT. Only the matching line is replaced; comments
// and other customizations are preserved.
func maybeUpgradeLogQuery(path, sqlCommand string) {
b, err := os.ReadFile(path)
if err != nil {
return
}
var replaced bool
var lines []string
for line := range strings.SplitSeq(string(b), "\n") {
if !replaced && strings.TrimSpace(line) == logMigrationsQueryPrev2_2_0 {
line = logMigrationsSelect(sqlCommand)
replaced = true
}
lines = append(lines, line)
}
if !replaced {
return
}
if err := os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644); err != nil {
fmt.Fprintf(os.Stderr, "Warn: couldn't upgrade %s: %v\n", path, err)
return
}
fmt.Fprintf(os.Stderr, " upgraded %s (added id to output)\n", filepathUnclean(path))
}
func initFile(path, contents string) (bool, error) {
if fileExists(path) {
return false, nil
}
if err := os.WriteFile(path, []byte(contents), 0644); err != nil {
return false, err
}
return true, nil
}
func extractVars(curMigrationPath string) (sqlCommand string, logPath string, err error) {
f, err := os.Open(curMigrationPath)
if err != nil {
return "", "", err
}
defer f.Close()
scanner := bufio.NewScanner(f)
var logPathRel string
var logPathPrefix = "-- migrations_log:"
var commandPrefix = "-- sql_command:"
for scanner.Scan() {
txt := scanner.Text()
txt = strings.TrimSpace(txt)
if strings.HasPrefix(txt, logPathPrefix) {
logPathRel = strings.TrimSpace(txt[len(logPathPrefix):])
continue
} else if strings.HasPrefix(txt, commandPrefix) {
sqlCommand = strings.TrimSpace(txt[len(commandPrefix):])
continue
}
}
if logPathRel == "" {
return "", "", fmt.Errorf("Could not find '-- migrations_log: <relative-path>' in %q", curMigrationPath)
}
if sqlCommand == "" {
return "", "", fmt.Errorf("Could not find '-- sql_command: <args>' in %q", curMigrationPath)
}
// migrationsDir := filepath.Dir(curMigrationPath)
// logPath = filepath.Join(migrationsDir, logPathRel)
// return sqlCommand, logPath, nil
return sqlCommand, logPathRel, nil
}
func migrationsLogInit(state *State, subcmd string) error {
logDir := filepath.Dir(state.LogPath)
if err := os.MkdirAll(logDir, 0755); err != nil {
return err
}
if subcmd != "up" {
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, "Run the first migration to complete the initialization:\n")
fmt.Fprintf(os.Stderr, "(you'll need to provide DB credentials via .env or export)\n")
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, " sql-migrate -d %q up > ./up.sh && sh ./up.sh\n", state.MigrationsDir)
fmt.Fprintf(os.Stderr, "\n")
}
return nil
}
func (state *State) parseAndFixupBatches(text string) error {
text = strings.TrimSpace(text)
if text == "" {
return nil
}
fixedUp := []string{}
fixedDown := []string{}
state.Lines = strings.Split(text, "\n")
for i := range state.Lines {
line := strings.TrimSpace(state.Lines[i])
migration := commentStartRe.ReplaceAllString(line, "")
migration = strings.TrimSpace(migration)
if migration != "" {
up, down, warn, err := fixupMigration(state.MigrationsDir, migration)
if warn != nil {
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
}
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
state.Migrated = append(state.Migrated, migration)
if up {
fixedUp = append(fixedUp, migration)
}
if down {
fixedDown = append(fixedDown, migration)
}
}
state.Lines[i] = line
}
showFixes(fixedUp, fixedDown)
return nil
}
func showFixes(fixedUp, fixedDown []string) {
if len(fixedUp) > 0 {
fmt.Fprintf(os.Stderr, "Fixup: appended missing 'INSERT INTO _migrations ...' to:\n")
for _, up := range fixedUp {
fmt.Fprintf(os.Stderr, " %s\n", up)
}
fmt.Fprintf(os.Stderr, "\n")
}
if len(fixedDown) > 0 {
fmt.Fprintf(os.Stderr, "Fixup: appended missing 'DELETE FROM _migrations ...' to:\n")
for _, down := range fixedDown {
fmt.Fprintf(os.Stderr, " %s\n", down)
}
fmt.Fprintf(os.Stderr, "\n")
}
}
func create(state *State, desc string) error {
dateStr := state.Date.Format("2006-01-02")
entries, err := os.ReadDir(state.MigrationsDir)
if err != nil {
return err
}
maxNumber := 0
datePrefix := dateStr + "-"
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasPrefix(name, datePrefix) {
continue
}
if !strings.HasSuffix(name, ".up.sql") {
continue
}
if strings.HasSuffix(name, "_"+desc+".up.sql") {
return fmt.Errorf("migration for %q already exists:\n %s", desc, state.MigrationsDir+"/"+name)
}
if strings.HasSuffix(name, ".down.sql") {
continue
}
parts := strings.SplitN(name, "-", 4)
if len(parts) < 4 {
continue
}
numDesc := strings.SplitN(parts[3], "_", 2)
if len(numDesc) < 2 {
continue
}
num, err := strconv.Atoi(numDesc[0])
if err != nil {
continue
}
if num > maxNumber {
maxNumber = num
}
}
number := maxNumber / 1_000
number *= 1_000
number += 1_000
if number > 9_000 && number < 10_000 {
fmt.Fprintf(os.Stderr, "Achievement Unlocked: It's over 9000!\n")
}
if number >= 999_999 {
fmt.Fprintf(os.Stderr, "Error: cowardly refusing to generate such a suspiciously high number of migrations after running out of numbers\n")
os.Exit(1)
}
basename := fmt.Sprintf("%s-%06d_%s", dateStr, number, desc)
upPath := filepath.Join(state.MigrationsDir, basename+".up.sql")
downPath := filepath.Join(state.MigrationsDir, basename+".down.sql")
id := MustRandomHex(4)
// Little Bobby Drop Tables says:
// We trust the person running the migrations to not use malicious names.
// (we don't want to embed db-specific logic here, and SQL doesn't define escaping)
migrationInsert := fmt.Sprintf("INSERT INTO _migrations (name, id) VALUES ('%s', '%s');", basename, id)
upContent := fmt.Appendf(nil, "-- %s (up)\nSELECT 'place your UP migration here';\n\n-- leave this as the last line\n%s\n", desc, migrationInsert)
_ = os.WriteFile(upPath, upContent, 0644)
migrationDelete := fmt.Sprintf("DELETE FROM _migrations WHERE id = '%s';", id)
downContent := fmt.Appendf(nil, "-- %s (down)\nSELECT 'place your DOWN migration here';\n\n-- leave this as the last line\n%s\n", desc, migrationDelete)
_ = os.WriteFile(downPath, downContent, 0644)
fmt.Fprintf(os.Stderr, " created pair %s\n", filepathUnclean(upPath))
fmt.Fprintf(os.Stderr, " %s\n", filepathUnclean(downPath))
return nil
}
func MustRandomHex(n int) string {
s, err := RandomHex(n)
if err != nil {
panic(err)
}
return s
}
func RandomHex(n int) (string, error) {
b := make([]byte, n) // 4 bytes = 8 hex chars
_, err := rand.Read(b)
if err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
// attempts to add missing INSERT and DELETE without breaking what already works
func fixupMigration(dir string, basename string) (up, down bool, warn error, err error) {
var id string
var insertsOnUp bool
upPath := filepath.Join(dir, basename+".up.sql")
upScan, err := os.Open(upPath)
if err != nil {
return false, false, nil, fmt.Errorf("failed (up): %w", err)
}
defer upScan.Close()
scanner := bufio.NewScanner(upScan)
for scanner.Scan() {
txt := scanner.Text()
txt = strings.TrimSpace(txt)
txt = strings.ToLower(txt)
if strings.HasPrefix(txt, "insert into _migrations") {
insertsOnUp = true
break
}
}
if !insertsOnUp {
id = MustRandomHex(4)
upScan.Close()
upBytes, err := os.ReadFile(upPath)
if err != nil {
warn = fmt.Errorf("failed to add 'INSERT INTO _migrations ...' to %s: %w", upPath, err)
return false, false, warn, nil
}
migrationInsertLn := fmt.Sprintf("\n-- leave this as the last line\nINSERT INTO _migrations (name, id) VALUES ('%s', '%s');\n", basename, id)
upBytes = append(upBytes, []byte(migrationInsertLn)...)
if err = os.WriteFile(upPath, upBytes, 0644); err != nil {
warn = fmt.Errorf("failed to append 'INSERT INTO _migrations ...' to %s: %w", upPath, err)
return false, false, warn, nil
}
up = true
}
var deletesOnDown bool
downPath := filepath.Join(dir, basename+".down.sql")
downScan, err := os.Open(downPath)
if err != nil {
return false, false, fmt.Errorf("failed (down): %w", err), nil
}
defer downScan.Close()
scanner = bufio.NewScanner(downScan)
for scanner.Scan() {
txt := scanner.Text()
txt = strings.TrimSpace(txt)
txt = strings.ToLower(txt)
if strings.HasPrefix(txt, "delete from _migrations") {
deletesOnDown = true
break
}
}
if !deletesOnDown {
if id == "" {
return false, false, fmt.Errorf("must manually append \"DELETE FROM _migrations WHERE id = '<id>'\" to %s with id from %s", downPath, basename+"up.sql"), nil
}
downScan.Close()
downFile, err := os.OpenFile(downPath, os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
warn = fmt.Errorf("failed to append 'DELETE FROM _migrations ...' to %s: %v", downPath, err)
return false, false, warn, nil
}
defer downFile.Close()
migrationInsertLn := fmt.Sprintf("\nDELETE FROM _migrations WHERE id = '%s';\n", id)
_, err = downFile.Write(([]byte(migrationInsertLn)))
if err != nil {
warn = fmt.Errorf("failed to add 'DELETE FROM _migrations ...' to %s: %w", downPath, err)
return false, false, warn, nil
}
down = true
}
return up, down, nil, nil
}
func syncLog(runner *shmigrate.Migrator) {
syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1)
logPath := filepathUnclean(runner.LogPath)
fmt.Printf(shmigrate.ShHeader)
fmt.Println("")
fmt.Println("# SYNC: reload migrations log from DB")
fmt.Printf("%s > %s || true\n", syncCmd, logPath)
fmt.Printf("cat %s\n", logPath)
}
func cmdUp(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script, n int) error {
// fixup pending migrations before generating the script
fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations)
status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
if err != nil {
return err
}
if len(status.Pending) == 0 {
syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1)
fmt.Fprintf(os.Stderr, "# Already up-to-date\n")
fmt.Fprintf(os.Stderr, "#\n")
fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n")
fmt.Fprintf(os.Stderr, "# %s > %s\n", syncCmd, filepathUnclean(runner.LogPath))
return nil
}
fmt.Printf(shmigrate.ShHeader)
fmt.Println("")
fmt.Println("# FORWARD / UP Migrations")
fmt.Println("")
applied, err := sqlmigrate.Up(ctx, runner, migrations, n)
if err != nil {
return err
}
_ = applied
fmt.Println("cat", filepathUnclean(runner.LogPath))
showFixes(fixedUp, fixedDown)
return nil
}
func cmdDown(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script, n int) error {
// fixup applied migrations before generating the script
fixedUp, fixedDown := fixupAll(state.MigrationsDir, state.Migrated, migrations)
status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
if err != nil {
return err
}
if len(status.Applied) == 0 {
syncCmd := strings.Replace(runner.SqlCommand, "%s", filepathUnclean(runner.LogQueryPath), 1)
fmt.Fprintf(os.Stderr, "# No migration history\n")
fmt.Fprintf(os.Stderr, "#\n")
fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n")
fmt.Fprintf(os.Stderr, "# %s > %s\n", syncCmd, filepathUnclean(runner.LogPath))
return nil
}
fmt.Printf(shmigrate.ShHeader)
fmt.Println("")
fmt.Println("# ROLLBACK / DOWN Migration")
fmt.Println("")
// check for missing down files before generating script
reversed := make([]sqlmigrate.Migration, len(status.Applied))
copy(reversed, status.Applied)
slices.Reverse(reversed)
limit := n
if limit <= 0 {
limit = 1
}
if limit > len(reversed) {
limit = len(reversed)
}
for _, a := range reversed[:limit] {
downPath := filepath.Join(state.MigrationsDir, a.Name+".down.sql")
if !fileExists(downPath) {
fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath))
fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n")
}
}
rolled, err := sqlmigrate.Down(ctx, runner, migrations, n)
if err != nil {
return err
}
_ = rolled
fmt.Println("cat", filepathUnclean(runner.LogPath))
showFixes(fixedUp, fixedDown)
return nil
}
func cmdStatus(ctx context.Context, state *State, runner *shmigrate.Migrator, migrations []sqlmigrate.Script) error {
status, err := sqlmigrate.GetStatus(ctx, runner, migrations)
if err != nil {
return err
}
fmt.Fprintf(os.Stderr, "migrations_dir: %s\n", filepathUnclean(state.MigrationsDir))
fmt.Fprintf(os.Stderr, "migrations_log: %s\n", filepathUnclean(state.LogPath))
fmt.Fprintf(os.Stderr, "sql_command: %s\n", state.SQLCommand)
fmt.Fprintf(os.Stderr, "\n")
// show applied in reverse (most recent first)
appliedList := make([]sqlmigrate.Migration, len(status.Applied))
copy(appliedList, status.Applied)
slices.Reverse(appliedList)
fmt.Printf("# previous: %d\n", len(appliedList))
for _, mig := range appliedList {
fmt.Printf(" %s\n", mig.Name)
}
if len(appliedList) == 0 {
fmt.Println(" # (no previous migrations)")
}
fmt.Println("")
fmt.Printf("# pending: %d\n", len(status.Pending))
for _, mig := range status.Pending {
fmt.Printf(" %s\n", mig.Name)
}
if len(status.Pending) == 0 {
fmt.Println(" # (no pending migrations)")
}
return nil
}
// fixupAll runs fixupMigration on all known migrations (applied + pending).
func fixupAll(migrationsDir string, applied []string, migrations []sqlmigrate.Script) (fixedUp, fixedDown []string) {
seen := map[string]bool{}
var all []string
for _, name := range applied {
if !seen[name] {
all = append(all, name)
seen[name] = true
}
}
for _, m := range migrations {
if !seen[m.Name] {
all = append(all, m.Name)
seen[m.Name] = true
}
}
for _, name := range all {
up, down, warn, err := fixupMigration(migrationsDir, name)
if warn != nil {
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
}
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
if up {
fixedUp = append(fixedUp, name)
}
if down {
fixedDown = append(fixedDown, name)
}
}
return fixedUp, fixedDown
}