AJ ONeal 4a9c331ef9
Fix INSERT INTO _migrations ordering in create and fixup
The create subcommand generated .up.sql files with INSERT INTO
_migrations as the FIRST statement, before the actual DDL. If the
DDL fails, the migration is incorrectly marked as applied. Move the
INSERT to be the LAST statement, matching how .down.sql already puts
DELETE FROM _migrations last.

Also fix the automatic fixup logic to append (not prepend) missing
INSERT statements to existing .up.sql files.

Fixes #86
2026-03-30 15:58:21 -06:00

1041 lines
30 KiB
Go

//
// Written in 2025 by AJ ONeal <aj@therootcompany.com>
//
// To the extent possible under law, the author(s) have dedicated all copyright
// and related and neighboring rights to this software to the public domain
// worldwide. This software is distributed without any warranty.
//
// You should have received a copy of the CC0 Public Domain Dedication along with
// this software. If not, see <https://creativecommons.org/publicdomain/zero/1.0/>.
// Package sql-migrate provides a simple SQL migrator that's easy to roll back or mix and match during development
package main
import (
"bufio"
"crypto/rand"
"encoding/hex"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"regexp"
"slices"
"sort"
"strconv"
"strings"
"time"
)
const (
version = "2.0.3"
)
const (
defaultMigrationDir = "./sql/migrations/"
defaultLogPath = "../migrations.log"
sqlCommandPSQL = `psql "$PG_URL" -v ON_ERROR_STOP=on -A -t --file %s`
sqlCommandMariaDB = `mariadb --defaults-extra-file="$MY_CNF" -s -N --raw < %s`
sqlCommandMySQL = `mysql --defaults-extra-file="$MY_CNF" -s -N --raw < %s`
LOG_QUERY_NAME = "_migrations.sql"
M_MIGRATOR_NAME = "0001-01-01-001000_init-migrations"
M_MIGRATOR_UP_NAME = "0001-01-01-001000_init-migrations.up.sql"
M_MIGRATOR_DOWN_NAME = "0001-01-01-001000_init-migrations.down.sql"
defaultMigratorUpTmpl = `-- Config variables for sql-migrate (do not delete)
-- sql_command: %s
-- migrations_log: %s
--
CREATE TABLE IF NOT EXISTS _migrations (
id CHAR(8) PRIMARY KEY,
name VARCHAR(80) NULL UNIQUE,
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- note: to enable text-based tools to grep and sort we put 'name' before 'id'
-- grep -r 'INSERT INTO _migrations' ./sql/migrations/ | cut -d':' -f2 | sort
INSERT INTO _migrations (name, id) VALUES ('0001-01-01-001000_init-migrations', '00000001');
`
defaultMigratorDown = `DELETE FROM _migrations WHERE id = '00000001';
DROP TABLE IF EXISTS _migrations;
`
LOG_MIGRATIONS_QUERY = `-- note: CLI arguments must be passed to the sql command to keep output clean
SELECT name FROM _migrations ORDER BY name;
`
shHeader = `#/bin/sh
set -e
set -u
if test -s ./.env; then
. ./.env
fi
`
)
const helpText = `
sql-migrate v` + version + ` - a feature-branch-friendly SQL migrator
USAGE
sql-migrate [-d sqldir] <command> [args]
EXAMPLE
sql-migrate -d ./sql/migrations/ init --sql-command <psql|mariadb|mysql>
sql-migrate -d ./sql/migrations/ create <kebab-case-description>
sql-migrate -d ./sql/migrations/ sync
sql-migrate -d ./sql/migrations/ status
sql-migrate -d ./sql/migrations/ up 99
sql-migrate -d ./sql/migrations/ down 1
sql-migrate -d ./sql/migrations/ list
COMMANDS
init - creates migrations directory, initial migration, log file,
and query for migrations
create - creates a new, canonically-named up/down file pair in the
migrations directory, with corresponding insert
sync - create a script to reload migrations.log from the DB
status - shows the same output as if processing a forward-migration
up [n] - create a script to run pending migrations (ALL by default)
down [n] - create a script to roll back migrations (ONE by default)
list - lists migrations
OPTIONS
-d <migrations directory> default: ./sql/migrations/
--help show command-specific help
NOTES
Migrations files are in the following format:
<yyyy-mm-dd>-<number>_<name>.<up|down>.sql
2020-01-01-1000_init-app.up.sql
The initial migration file contains configuration variables:
-- migrations_log: ./sql/migrations.log
-- sql_command: psql "$PG_URL" -v ON_ERROR_STOP=on --no-align --file %s
The log is generated on each migration file contains a list of all migrations:
0001-01-01-001000_init-migrations.up.sql
2020-12-31-001000_init-app.up.sql
2020-12-31-001100_add-customer-tables.up.sql
2020-12-31-002000_add-ALL-THE-TABLES.up.sql
The 'create' generates an up/down pair of files using the current date and
the number 1000. If either file exists, the number is incremented by 1000 and
tried again.
`
var (
nonWordRe = regexp.MustCompile(`\W+`)
commentStartRe = regexp.MustCompile(`(^|\s+)#.*`)
)
type State struct {
Date time.Time
SQLCommand string
Lines []string
Migrated []string
MigrationsDir string
LogPath string
}
type MainConfig struct {
migrationsDir string
logPath string
sqlCommand string
}
func main() {
var cfg MainConfig
var date = time.Now()
if len(os.Args) < 2 {
//nolint
fmt.Printf("%s\n", helpText)
os.Exit(0)
}
switch os.Args[1] {
case "help", "--help",
"version", "--version", "-V":
fmt.Printf("%s\n", helpText)
os.Exit(0)
default:
// do nothing
}
mainArgs := os.Args[1:]
fsMain := flag.NewFlagSet("", flag.ExitOnError)
fsMain.StringVar(&cfg.migrationsDir, "d", defaultMigrationDir, "directory for migrations (where 0001-01-01_init-migrations.up.sql will be added)")
if err := fsMain.Parse(mainArgs); err != nil {
os.Exit(2)
}
var subcmd string
// note: Args() includes any flags after the first non-flag arg
// sql-migrate -d ./migs/ init --migrations-log ./migrations.log
// => init -f ./migrations.log
subArgs := fsMain.Args()
if len(subArgs) > 0 {
subcmd = subArgs[0]
subArgs = subArgs[1:]
}
var fsSub *flag.FlagSet
switch subcmd {
case "init":
fsSub = flag.NewFlagSet("init", flag.ExitOnError)
fsSub.StringVar(&cfg.logPath, "migrations-log", "", fmt.Sprintf("migration log file (default: %s) relative to and saved in %s", defaultLogPath, M_MIGRATOR_NAME))
fsSub.StringVar(&cfg.sqlCommand, "sql-command", sqlCommandPSQL, "construct scripts with this to execute SQL files: 'psql', 'mysql', 'mariadb', or custom arguments")
case "create", "sync", "up", "down", "status", "list":
fsSub = flag.NewFlagSet(subcmd, flag.ExitOnError)
default:
log.Printf("unknown command %s", subcmd)
fmt.Printf("%s\n", helpText)
os.Exit(1)
}
if err := fsSub.Parse(subArgs); err != nil {
os.Exit(2)
}
leafArgs := fsSub.Args()
switch cfg.sqlCommand {
case "", "postgres", "postgresql", "pg", "psql", "plpgsql":
cfg.sqlCommand = sqlCommandPSQL
case "mariadb":
cfg.sqlCommand = sqlCommandMariaDB
case "mysql", "my":
cfg.sqlCommand = sqlCommandMySQL
default:
// leave as provided by the user
}
if !strings.HasSuffix(cfg.migrationsDir, "/") {
cfg.migrationsDir += "/"
}
if subcmd == "init" {
mustInit(&cfg)
}
entries, err := os.ReadDir(cfg.migrationsDir)
if err != nil {
if os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Error: missing migrations directory. Run 'sql-migrate -d %q init' to create it.\n", cfg.migrationsDir)
os.Exit(1)
}
fmt.Fprintf(os.Stderr, "Error: couldn't list migrations directory: %v\n", err)
os.Exit(1)
}
ups, downs := migrationsList(cfg.migrationsDir, entries)
if !slices.Contains(ups, M_MIGRATOR_NAME) {
fmt.Fprintf(os.Stderr, "Error: missing initial migration. Run 'sql-migrate -d %q init' to create %q.\n", cfg.migrationsDir, M_MIGRATOR_UP_NAME)
os.Exit(1)
}
if !slices.Contains(downs, M_MIGRATOR_NAME) {
fmt.Fprintf(os.Stderr, "Error: missing initial migration. Run 'sql-migrate -d %q init' to create %q.\n", cfg.migrationsDir, M_MIGRATOR_DOWN_NAME)
os.Exit(1)
}
logQueryPath := filepath.Join(cfg.migrationsDir, LOG_QUERY_NAME)
if !fileExists(logQueryPath) {
fmt.Fprintf(os.Stderr, "Error: missing %q. Run 'sql-migrate -d %q init' to create it.\n", logQueryPath, cfg.migrationsDir)
os.Exit(1)
}
mMigratorUpPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_UP_NAME)
// mMigratorDownPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_DOWN_NAME)
state := State{
Date: date,
MigrationsDir: cfg.migrationsDir,
}
state.SQLCommand, state.LogPath, err = extractVars(mMigratorUpPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: couldn't read config from initial migration: %v\n", err)
os.Exit(1)
}
logText, err := os.ReadFile(state.LogPath)
if err != nil {
if !os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Error: couldn't read migrations log %q: %v\n", state.LogPath, err)
os.Exit(1)
}
if err := migrationsLogInit(&state, subcmd); err != nil {
fmt.Fprintf(os.Stderr, "Error: couldn't create log file directory: %v\n", err)
os.Exit(1)
}
}
if err := state.parseAndFixupBatches(string(logText)); err != nil {
fmt.Fprintf(os.Stderr, "Error: couldn't read migrations log or fixup batches: %v\n", err)
os.Exit(1)
}
switch subcmd {
case "init":
break
case "sync":
if err := syncLog(&state); err != nil {
log.Fatal(err)
}
return
case "create":
if len(leafArgs) == 0 {
log.Fatal("create requires a description")
}
desc := strings.Join(leafArgs, " ")
desc = nonWordRe.ReplaceAllString(desc, " ")
desc = strings.TrimSpace(desc)
desc = nonWordRe.ReplaceAllString(desc, "-")
desc = strings.ToLower(desc)
err = create(&state, desc)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: couldn't create migration: %v\n", err)
os.Exit(1)
}
case "status":
if len(leafArgs) > 0 && subcmd != "create" {
fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " "))
os.Exit(1)
}
err = status(&state, ups)
if err != nil {
log.Fatal(err)
}
case "list":
if len(leafArgs) > 0 && subcmd != "create" {
fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " "))
os.Exit(1)
}
fmt.Println("Ups:")
if len(ups) == 0 {
fmt.Println(" (none)")
}
for _, u := range ups {
fmt.Println(" ", u)
}
fmt.Println("")
fmt.Println("Downs:")
if len(downs) == 0 {
fmt.Println(" (none)")
}
for _, d := range downs {
fmt.Println(" ", d)
}
case "up":
var upN int
switch len(leafArgs) {
case 0:
// ignore
case 1:
upN, err = strconv.Atoi(leafArgs[0])
if err != nil || upN < 0 {
fmt.Fprintf(os.Stderr, "Error: %s is not a positive number\n", leafArgs[0])
os.Exit(1)
}
default:
fmt.Fprintf(os.Stderr, "Error: unrecognized arguments %q \n", strings.Join(leafArgs, "\" \""))
os.Exit(1)
}
err = up(&state, ups, upN)
if err != nil {
log.Fatal(err)
}
case "down":
var downN int
switch len(leafArgs) {
case 0:
// ignore
case 1:
downN, err = strconv.Atoi(leafArgs[0])
if err != nil || downN < 0 {
fmt.Fprintf(os.Stderr, "Error: %s is not a positive number\n", leafArgs[0])
os.Exit(1)
}
default:
fmt.Fprintf(os.Stderr, "Error: unrecognized arguments %q \n", strings.Join(leafArgs, "\" \""))
os.Exit(1)
}
err = down(&state, downN)
if err != nil {
log.Fatal(err)
}
default:
log.Printf("unknown command %s", subcmd)
fmt.Printf("%s\n", helpText)
os.Exit(1)
}
}
func migrationsList(migrationsDir string, entries []os.DirEntry) (ups, downs []string) {
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") || strings.HasPrefix(name, "+") {
if name != LOG_QUERY_NAME {
fmt.Fprintf(os.Stderr, " ignoring %s\n", filepathJoin(migrationsDir, name))
}
continue
}
if base, ok := strings.CutSuffix(name, ".up.sql"); ok {
ups = append(ups, base)
// TODO on ups add INSERT to file and to up migration if it doesn't exist
continue
}
if base, ok := strings.CutSuffix(name, ".down.sql"); ok {
downs = append(downs, base)
continue
}
fmt.Fprintf(os.Stderr, " unknown %s\n", filepathJoin(migrationsDir, name))
}
for _, down := range downs {
// TODO on downs add INSERT to file and to up migration if it doesn't exist
upName := strings.TrimSuffix(down, ".down.sql") + ".up.sql"
companion := filepath.Join(migrationsDir, upName)
if !fileExists(companion) {
fmt.Fprintf(os.Stderr, " missing '%s'\n", companion)
}
}
sort.Strings(ups)
sort.Strings(downs)
return ups, downs
}
func fileExists(path string) bool {
_, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return false
}
fmt.Fprintf(os.Stderr, "Error: can't access %q\n", path)
os.Exit(1)
return false
}
return true
}
func filepathUnclean(path string) string {
if !strings.HasPrefix(path, "/") {
if !strings.HasPrefix(path, "./") && !strings.HasPrefix(path, "../") {
path = "./" + path
}
}
return path
}
func filepathJoin(src, dst string) string {
return filepathUnclean(filepath.Join(src, dst))
}
// initializes all necessary files and directories
// - ./sql/migrations.log
//
// - ./sql/migrations
//
// - ./sql/migrations/0001-01-01-001000_init-migrations.up.sql
// - migrations_log: ./sql/migrations.log
// - sql_command: psql "$PG_URL" -v ON_ERROR_STOP=on --no-align --file %s
//
// - ./sql/migrations/0001-01-01-001000_init-migrations.down.sql
func mustInit(cfg *MainConfig) {
fmt.Fprintf(os.Stderr, "Initializing %q ...\n", cfg.migrationsDir)
var resolvedLogPath = cfg.logPath
if cfg.sqlCommand != "" && !strings.Contains(cfg.sqlCommand, "%s") {
fmt.Fprintf(os.Stderr, "Error: --sql-command must contain a literal '%%s' to accept the path to the SQL file\n")
os.Exit(1)
}
entries, err := os.ReadDir(cfg.migrationsDir)
if err != nil {
if !os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Error: init failed to create %q: %v\n", cfg.migrationsDir, err)
os.Exit(1)
}
if err = os.MkdirAll(cfg.migrationsDir, 0755); err != nil {
fmt.Fprintf(os.Stderr, "Error: init failed to read %q: %v\n", cfg.migrationsDir, err)
os.Exit(1)
}
}
ups, downs := migrationsList(cfg.migrationsDir, entries)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: init couldn't list existing migrations: %v\n", err)
os.Exit(1)
}
mMigratorUpPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_UP_NAME)
mMigratorDownPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_DOWN_NAME)
// write config
if slices.Contains(ups, M_MIGRATOR_NAME) {
fmt.Fprintf(os.Stderr, " found %s\n", filepathJoin(cfg.migrationsDir, M_MIGRATOR_UP_NAME))
} else {
if cfg.logPath == "" {
migrationsParent := filepath.Dir(cfg.migrationsDir)
resolvedLogPath = filepath.Join(migrationsParent, defaultLogPath)
// resolvedLogPath, err = filepath.Rel(cfg.migrationsDir, cfg.logPath)
// if err != nil {
// fmt.Fprintf(os.Stderr, "Error: init couldn't resolve the migrations log relative to the migrations dir: %v\n", err)
// os.Exit(1)
// }
}
migratorUpQuery := fmt.Sprintf(defaultMigratorUpTmpl, cfg.sqlCommand, resolvedLogPath)
if created, err := initFile(mMigratorUpPath, migratorUpQuery); err != nil {
fmt.Fprintf(os.Stderr, "Error: init couldn't create initial up migration: %v\n", err)
os.Exit(1)
} else if created {
fmt.Fprintf(os.Stderr, " created %s\n", filepathUnclean(mMigratorUpPath))
}
}
state := State{
MigrationsDir: cfg.migrationsDir,
}
state.SQLCommand, state.LogPath, err = extractVars(mMigratorUpPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: init couldn't read config from initial migration: %v\n", err)
os.Exit(1)
}
if cfg.logPath != "" && filepath.Clean(cfg.logPath) != filepath.Clean(state.LogPath) {
fmt.Fprintf(os.Stderr,
"--migrations-log %q does not match %q from %q\n(drop the --migrations-log flag, or update the add migrations file)\n",
cfg.logPath, state.LogPath, mMigratorUpPath,
)
os.Exit(1)
}
if cfg.sqlCommand != "" && cfg.sqlCommand != state.SQLCommand {
fmt.Fprintf(os.Stderr,
"--sql-command %q does not match %q from %q\n(drop the --sql-command flag, or update the add migrations file)\n",
cfg.sqlCommand, state.SQLCommand, mMigratorUpPath,
)
os.Exit(1)
}
if slices.Contains(downs, M_MIGRATOR_NAME) {
fmt.Fprintf(os.Stderr, " found %s\n", filepathUnclean(mMigratorDownPath))
} else {
migratorDownQuery := defaultMigratorDown
if created, err := initFile(mMigratorDownPath, migratorDownQuery); err != nil {
fmt.Fprintf(os.Stderr, "Error: init couldn't create initial up migration: %v\n", err)
os.Exit(1)
} else if created {
fmt.Fprintf(os.Stderr, " created %s\n", filepathUnclean(mMigratorDownPath))
}
}
logQueryPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME)
if created, err := initFile(logQueryPath, LOG_MIGRATIONS_QUERY); err != nil {
fmt.Fprintf(os.Stderr, "Error: init couldn't create migrations query: %v\n", err)
os.Exit(1)
} else if created {
fmt.Fprintf(os.Stderr, " created %s\n", filepathUnclean(logQueryPath))
} else {
fmt.Fprintf(os.Stderr, " found %s\n", filepathUnclean(logQueryPath))
}
if fileExists(state.LogPath) {
fmt.Fprintf(os.Stderr, " found %s\n", filepathUnclean(state.LogPath))
fmt.Fprintf(os.Stderr, "done\n")
return
}
}
func initFile(path, contents string) (bool, error) {
if fileExists(path) {
return false, nil
}
if err := os.WriteFile(path, []byte(contents), 0644); err != nil {
return false, err
}
return true, nil
}
func extractVars(curMigrationPath string) (sqlCommand string, logPath string, err error) {
f, err := os.Open(curMigrationPath)
if err != nil {
return "", "", err
}
defer f.Close()
scanner := bufio.NewScanner(f)
var logPathRel string
var logPathPrefix = "-- migrations_log:"
var commandPrefix = "-- sql_command:"
for scanner.Scan() {
txt := scanner.Text()
txt = strings.TrimSpace(txt)
if strings.HasPrefix(txt, logPathPrefix) {
logPathRel = strings.TrimSpace(txt[len(logPathPrefix):])
continue
} else if strings.HasPrefix(txt, commandPrefix) {
sqlCommand = strings.TrimSpace(txt[len(commandPrefix):])
continue
}
}
if logPathRel == "" {
return "", "", fmt.Errorf("Could not find '-- migrations_log: <relative-path>' in %q", curMigrationPath)
}
if sqlCommand == "" {
return "", "", fmt.Errorf("Could not find '-- sql_command: <args>' in %q", curMigrationPath)
}
// migrationsDir := filepath.Dir(curMigrationPath)
// logPath = filepath.Join(migrationsDir, logPathRel)
// return sqlCommand, logPath, nil
return sqlCommand, logPathRel, nil
}
func migrationsLogInit(state *State, subcmd string) error {
logDir := filepath.Dir(state.LogPath)
if err := os.MkdirAll(logDir, 0755); err != nil {
return err
}
if subcmd != "up" {
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, "Run the first migration to complete the initialization:\n")
fmt.Fprintf(os.Stderr, "(you'll need to provide DB credentials via .env or export)\n")
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, " sql-migrate -d %q up > ./up.sh && sh ./up.sh\n", state.MigrationsDir)
fmt.Fprintf(os.Stderr, "\n")
}
return nil
}
func (state *State) parseAndFixupBatches(text string) error {
text = strings.TrimSpace(text)
if text == "" {
return nil
}
fixedUp := []string{}
fixedDown := []string{}
state.Lines = strings.Split(text, "\n")
for i := range state.Lines {
line := strings.TrimSpace(state.Lines[i])
migration := commentStartRe.ReplaceAllString(line, "")
migration = strings.TrimSpace(migration)
if migration != "" {
up, down, warn, err := fixupMigration(state.MigrationsDir, migration)
if warn != nil {
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
}
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
state.Migrated = append(state.Migrated, migration)
if up {
fixedUp = append(fixedUp, migration)
}
if down {
fixedDown = append(fixedDown, migration)
}
}
state.Lines[i] = line
}
showFixes(fixedUp, fixedDown)
return nil
}
func showFixes(fixedUp, fixedDown []string) {
if len(fixedUp) > 0 {
fmt.Fprintf(os.Stderr, "Fixup: appended missing 'INSERT INTO _migrations ...' to:\n")
for _, up := range fixedUp {
fmt.Fprintf(os.Stderr, " %s\n", up)
}
fmt.Fprintf(os.Stderr, "\n")
}
if len(fixedDown) > 0 {
fmt.Fprintf(os.Stderr, "Fixup: appended missing 'DELETE FROM _migrations ...' to:\n")
for _, down := range fixedDown {
fmt.Fprintf(os.Stderr, " %s\n", down)
}
fmt.Fprintf(os.Stderr, "\n")
}
}
func create(state *State, desc string) error {
dateStr := state.Date.Format("2006-01-02")
entries, err := os.ReadDir(state.MigrationsDir)
if err != nil {
return err
}
maxNumber := 0
datePrefix := dateStr + "-"
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasPrefix(name, datePrefix) {
continue
}
if !strings.HasSuffix(name, ".up.sql") {
continue
}
if strings.HasSuffix(name, "_"+desc+".up.sql") {
return fmt.Errorf("migration for %q already exists:\n %s", desc, state.MigrationsDir+"/"+name)
}
if strings.HasSuffix(name, ".down.sql") {
continue
}
parts := strings.SplitN(name, "-", 4)
if len(parts) < 4 {
continue
}
numDesc := strings.SplitN(parts[3], "_", 2)
if len(numDesc) < 2 {
continue
}
num, err := strconv.Atoi(numDesc[0])
if err != nil {
continue
}
if num > maxNumber {
maxNumber = num
}
}
number := maxNumber / 1_000
number *= 1_000
number += 1_000
if number > 9_000 && number < 10_000 {
fmt.Fprintf(os.Stderr, "Achievement Unlocked: It's over 9000!\n")
}
if number >= 999_999 {
fmt.Fprintf(os.Stderr, "Error: cowardly refusing to generate such a suspiciously high number of migrations after running out of numbers\n")
os.Exit(1)
}
basename := fmt.Sprintf("%s-%06d_%s", dateStr, number, desc)
upPath := filepath.Join(state.MigrationsDir, basename+".up.sql")
downPath := filepath.Join(state.MigrationsDir, basename+".down.sql")
id := MustRandomHex(4)
// Little Bobby Drop Tables says:
// We trust the person running the migrations to not use malicious names.
// (we don't want to embed db-specific logic here, and SQL doesn't define escaping)
migrationInsert := fmt.Sprintf("INSERT INTO _migrations (name, id) VALUES ('%s', '%s');", basename, id)
upContent := fmt.Appendf(nil, "-- %s (up)\nSELECT 'place your UP migration here';\n\n-- leave this as the last line\n%s\n", desc, migrationInsert)
_ = os.WriteFile(upPath, upContent, 0644)
migrationDelete := fmt.Sprintf("DELETE FROM _migrations WHERE id = '%s';", id)
downContent := fmt.Appendf(nil, "-- %s (down)\nSELECT 'place your DOWN migration here';\n\n-- leave this as the last line\n%s\n", desc, migrationDelete)
_ = os.WriteFile(downPath, downContent, 0644)
fmt.Fprintf(os.Stderr, " created pair %s\n", filepathUnclean(upPath))
fmt.Fprintf(os.Stderr, " %s\n", filepathUnclean(downPath))
return nil
}
func MustRandomHex(n int) string {
s, err := RandomHex(n)
if err != nil {
panic(err)
}
return s
}
func RandomHex(n int) (string, error) {
b := make([]byte, n) // 4 bytes = 8 hex chars
_, err := rand.Read(b)
if err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
// attempts to add missing INSERT and DELETE without breaking what already works
func fixupMigration(dir string, basename string) (up, down bool, warn error, err error) {
var id string
var insertsOnUp bool
upPath := filepath.Join(dir, basename+".up.sql")
upScan, err := os.Open(upPath)
if err != nil {
return false, false, nil, fmt.Errorf("failed (up): %w", err)
}
defer upScan.Close()
scanner := bufio.NewScanner(upScan)
for scanner.Scan() {
txt := scanner.Text()
txt = strings.TrimSpace(txt)
txt = strings.ToLower(txt)
if strings.HasPrefix(txt, "insert into _migrations") {
insertsOnUp = true
break
}
}
if !insertsOnUp {
id = MustRandomHex(4)
upScan.Close()
upBytes, err := os.ReadFile(upPath)
if err != nil {
warn = fmt.Errorf("failed to add 'INSERT INTO _migrations ...' to %s: %w", upPath, err)
return false, false, warn, nil
}
migrationInsertLn := fmt.Sprintf("\n-- leave this as the last line\nINSERT INTO _migrations (name, id) VALUES ('%s', '%s');\n", basename, id)
upBytes = append(upBytes, []byte(migrationInsertLn)...)
if err = os.WriteFile(upPath, upBytes, 0644); err != nil {
warn = fmt.Errorf("failed to append 'INSERT INTO _migrations ...' to %s: %w", upPath, err)
return false, false, warn, nil
}
up = true
}
var deletesOnDown bool
downPath := filepath.Join(dir, basename+".down.sql")
downScan, err := os.Open(downPath)
if err != nil {
return false, false, fmt.Errorf("failed (down): %w", err), nil
}
defer downScan.Close()
scanner = bufio.NewScanner(downScan)
for scanner.Scan() {
txt := scanner.Text()
txt = strings.TrimSpace(txt)
txt = strings.ToLower(txt)
if strings.HasPrefix(txt, "delete from _migrations") {
deletesOnDown = true
break
}
}
if !deletesOnDown {
if id == "" {
return false, false, fmt.Errorf("must manually append \"DELETE FROM _migrations WHERE id = '<id>'\" to %s with id from %s", downPath, basename+"up.sql"), nil
}
downScan.Close()
downFile, err := os.OpenFile(downPath, os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
warn = fmt.Errorf("failed to append 'DELETE FROM _migrations ...' to %s: %v", downPath, err)
return false, false, warn, nil
}
defer downFile.Close()
migrationInsertLn := fmt.Sprintf("\nDELETE FROM _migrations WHERE id = '%s';\n", id)
_, err = downFile.Write(([]byte(migrationInsertLn)))
if err != nil {
warn = fmt.Errorf("failed to add 'DELETE FROM _migrations ...' to %s: %w", downPath, err)
return false, false, warn, nil
}
down = true
}
return up, down, nil, nil
}
func syncLog(state *State) error {
getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME)
getMigsPath = filepathUnclean(getMigsPath)
getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1)
logPath := filepathUnclean(state.LogPath)
fmt.Printf(shHeader)
fmt.Println("")
fmt.Println("# SYNC: reload migrations log from DB")
fmt.Printf("%s > %s || true\n", getMigs, logPath)
fmt.Printf("cat %s\n", logPath)
return nil
}
func up(state *State, ups []string, n int) error {
var pending []string
for _, mig := range ups {
found := slices.Contains(state.Migrated, mig)
if !found {
pending = append(pending, mig)
}
}
getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME)
getMigsPath = filepathUnclean(getMigsPath)
getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1)
if len(pending) == 0 {
fmt.Fprintf(os.Stderr, "# Already up-to-date\n")
fmt.Fprintf(os.Stderr, "#\n")
fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n")
fmt.Fprintf(os.Stderr, "# %s > %s\n", getMigs, filepathUnclean(state.LogPath))
return nil
}
if n == 0 {
n = len(pending)
}
fixedUp := []string{}
fixedDown := []string{}
fmt.Printf(shHeader)
fmt.Println("")
fmt.Println("# FORWARD / UP Migrations")
fmt.Println("")
for i, migration := range pending {
if i >= n {
break
}
path := filepath.Join(state.MigrationsDir, migration+".up.sql")
path = filepathUnclean(path)
{
up, down, warn, err := fixupMigration(state.MigrationsDir, migration)
if warn != nil {
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
}
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
if up {
fixedUp = append(fixedUp, migration)
}
if down {
fixedDown = append(fixedDown, migration)
}
}
cmd := strings.Replace(state.SQLCommand, "%s", path, 1)
fmt.Printf("# +%d %s\n", i+1, migration)
fmt.Println(cmd)
fmt.Println(getMigs + " > " + filepathUnclean(state.LogPath))
fmt.Println("")
}
fmt.Println("cat", filepathUnclean(state.LogPath))
showFixes(fixedUp, fixedDown)
return nil
}
func down(state *State, n int) error {
lines := make([]string, len(state.Lines))
copy(lines, state.Lines)
slices.Reverse(lines)
getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME)
getMigsPath = filepathUnclean(getMigsPath)
getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1)
if len(lines) == 0 {
fmt.Fprintf(os.Stderr, "# No migration history\n")
fmt.Fprintf(os.Stderr, "#\n")
fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n")
fmt.Fprintf(os.Stderr, "# %s > %s\n", getMigs, filepathUnclean(state.LogPath))
return nil
}
if n == 0 {
n = 1
}
fixedUp := []string{}
fixedDown := []string{}
var applied []string
for _, line := range lines {
migration := commentStartRe.ReplaceAllString(line, "")
migration = strings.TrimSpace(migration)
if migration == "" {
continue
}
applied = append(applied, migration)
}
fmt.Printf(shHeader)
fmt.Println("")
fmt.Println("# ROLLBACK / DOWN Migration")
fmt.Println("")
for i, migration := range applied {
if i >= n {
break
}
downPath := filepath.Join(state.MigrationsDir, migration+".down.sql")
cmd := strings.Replace(state.SQLCommand, "%s", downPath, 1)
fmt.Printf("\n# -%d %s\n", i+1, migration)
if !fileExists(downPath) {
fmt.Fprintf(os.Stderr, "# Warn: missing %s\n", filepathUnclean(downPath))
fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n")
fmt.Printf("# ERROR: MISSING FILE\n")
} else {
up, down, warn, err := fixupMigration(state.MigrationsDir, migration)
if warn != nil {
fmt.Fprintf(os.Stderr, "Warn: %s\n", warn)
}
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
if up {
fixedUp = append(fixedUp, migration)
}
if down {
fixedDown = append(fixedDown, migration)
}
}
fmt.Println(cmd)
fmt.Println(getMigs + " > " + filepathUnclean(state.LogPath))
fmt.Println("")
}
fmt.Println("cat", filepathUnclean(state.LogPath))
showFixes(fixedUp, fixedDown)
return nil
}
func status(state *State, ups []string) error {
previous := make([]string, len(state.Lines))
copy(previous, state.Lines)
slices.Reverse(previous)
fmt.Fprintf(os.Stderr, "migrations_dir: %s\n", filepathUnclean(state.MigrationsDir))
fmt.Fprintf(os.Stderr, "migrations_log: %s\n", filepathUnclean(state.LogPath))
fmt.Fprintf(os.Stderr, "sql_command: %s\n", state.SQLCommand)
fmt.Fprintf(os.Stderr, "\n")
fmt.Printf("# previous: %d\n", len(previous))
for _, mig := range previous {
fmt.Printf(" %s\n", mig)
}
if len(previous) == 0 {
fmt.Println(" # (no previous migrations)")
}
fmt.Println("")
var pending []string
for _, mig := range ups {
found := slices.Contains(state.Migrated, mig)
if !found {
pending = append(pending, mig)
}
}
fmt.Printf("# pending: %d\n", len(pending))
for _, mig := range pending {
fmt.Printf(" %s\n", mig)
}
if len(pending) == 0 {
fmt.Println(" # (no pending migrations)")
}
return nil
}