From ab9dbdb6c67dd8868f0f42a0f1f834f19a379079 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 27 Oct 2025 11:49:00 -0600 Subject: [PATCH] feat(sqlmigrator): can write up and down migrations --- cmd/sql-migrate/main.go | 503 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 503 insertions(+) create mode 100644 cmd/sql-migrate/main.go diff --git a/cmd/sql-migrate/main.go b/cmd/sql-migrate/main.go new file mode 100644 index 0000000..2bccc74 --- /dev/null +++ b/cmd/sql-migrate/main.go @@ -0,0 +1,503 @@ +// +// Written in 2025 by AJ ONeal +// +// To the extent possible under law, the author(s) have dedicated all copyright +// and related and neighboring rights to this software to the public domain +// worldwide. This software is distributed without any warranty. +// +// You should have received a copy of the CC0 Public Domain Dedication along with +// this software. If not, see . + +// Package sql-migrate provides a simple SQL migrator that's easy to roll back or mix and match during development +package main + +import ( + "flag" + "fmt" + "log" + "os" + "path/filepath" + "regexp" + "slices" + "sort" + "strconv" + "strings" + "time" +) + +const ( + defaultMigrationDir = "./sql/migrations/" + defaultMigrationLog = "./sql/migrations.log" + defaultCommand = `psql "$PG_URL" < %s` +) + +var ( + nonWordRe = regexp.MustCompile(`\W+`) + commandStartRe = regexp.MustCompile(`^#\s*command:\s*`) + batchStartRe = regexp.MustCompile(`^#\s*batch:\s*`) + commentStartRe = regexp.MustCompile(`(^|\s+)#.*`) +) + +type State struct { + Date time.Time + Command string + Current int + Lines []string + Migrated []string + SqlDir string + LogFile string +} + +func parseLog(text string, date time.Time) *State { + state := &State{Date: date, Command: "", Current: 0, Lines: []string{}, Migrated: []string{}} + text = strings.TrimSpace(text) + if text == "" { + state.Command = defaultCommand + return state + } + state.Lines = strings.Split(text, "\n") + batchCount := 0 + for i := range state.Lines { + line := strings.TrimSpace(state.Lines[i]) + if commandStartRe.MatchString(line) { + if state.Command != "" { + log.Printf(" ignoring duplicate '%s'", line) + } else { + state.Command = commandStartRe.ReplaceAllString(line, "") + } + } + if batchStartRe.MatchString(line) { + parts := strings.SplitN(line, ":", 2) + if len(parts) < 2 { + continue + } + n, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil || n <= 0 { + log.Printf(" invalid '%s'", line) + n = -1 + } + batchCount++ + if n > state.Current { + state.Current = n + } + if batchCount > state.Current { + state.Current = batchCount + } + } + migration := commentStartRe.ReplaceAllString(line, "") + migration = strings.TrimSpace(migration) + if migration != "" { + state.Migrated = append(state.Migrated, migration) + } + state.Lines[i] = line + } + if state.Command == "" { + state.Command = defaultCommand + } + if !strings.Contains(state.Command, "%s") { + state.Command += " %s" + } + return state +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func create(state *State, desc string) error { + dateStr := state.Date.Format("2006-01-02") + entries, err := os.ReadDir(state.SqlDir) + if err != nil { + return err + } + + maxNumber := 0 + datePrefix := dateStr + "-" + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + if !strings.HasPrefix(name, datePrefix) { + continue + } + if !strings.HasSuffix(name, ".up.sql") { + continue + } + if strings.HasSuffix(name, "_"+desc+".up.sql") { + return fmt.Errorf("migration for %q already exists:\n %s", desc, state.SqlDir+"/"+name) + } + if strings.HasSuffix(name, ".down.sql") { + continue + } + + parts := strings.SplitN(name, "-", 4) + if len(parts) < 4 { + continue + } + numDesc := strings.SplitN(parts[3], "_", 2) + if len(numDesc) < 2 { + continue + } + num, err := strconv.Atoi(numDesc[0]) + if err != nil { + continue + } + + if num > maxNumber { + maxNumber = num + } + } + + number := maxNumber / 1000 + number *= 1000 + number += 1000 + if number > 9000 { + return fmt.Errorf("it's over 9000! ") + } + + baseFilename := fmt.Sprintf("%s-%06d_%s", dateStr, number, desc) + upPath := filepath.Join(state.SqlDir, baseFilename+".up.sql") + downPath := filepath.Join(state.SqlDir, baseFilename+".down.sql") + + // Use fmt.Appendf to build byte slice, ignoring error as it can't fail with static format + upContent := fmt.Appendf(nil, "-- %s (up)\n", desc) + _ = os.WriteFile(upPath, upContent, 0644) + downContent := fmt.Appendf(nil, "-- %s (down)\n", desc) + _ = os.WriteFile(downPath, downContent, 0644) + + fmt.Fprintf(os.Stderr, " created pair %s\n", upPath) + fmt.Fprintf(os.Stderr, " %s\n", downPath) + return nil +} + +func listMigrations(state *State) (ups, downs []string, err error) { + entries, err := os.ReadDir(state.SqlDir) + if err != nil { + return nil, nil, err + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") { + log.Printf(" ignoring '%s'", name) + continue + } + if strings.HasSuffix(name, ".up.sql") { + base := strings.TrimSuffix(name, ".up.sql") + ups = append(ups, base) + companion := filepath.Join(state.SqlDir, base+".down.sql") + if !fileExists(companion) { + log.Printf(" missing '%s'", companion) + } + continue + } + if strings.HasSuffix(name, ".down.sql") { + base := strings.TrimSuffix(name, ".down.sql") + downs = append(downs, base) + companion := filepath.Join(state.SqlDir, base+".up.sql") + if !fileExists(companion) { + log.Printf(" missing '%s'", companion) + } + continue + } + log.Printf(" unknown '%s'", name) + } + sort.Strings(ups) + sort.Strings(downs) + return ups, downs, nil +} + +func up(state *State) error { + ups, _, err := listMigrations(state) + if err != nil { + return err + } + var pending []string + for _, mig := range ups { + found := slices.Contains(state.Migrated, mig) + if !found { + pending = append(pending, mig) + } + } + if len(pending) == 0 { + log.Println(" already up-to-date") + return nil + } + n := state.Current + 1 + fmt.Printf("echo '# batch: %d' >> %s\n", n, state.LogFile) + for _, mig := range pending { + fmt.Println("") + fmt.Printf("# INSERT INTO \"migrations\" ('%d', '%s')\n", n, mig) + fmt.Printf("echo '%s' >> %s\n", mig, state.LogFile) + path := filepath.Join(state.SqlDir, mig+".up.sql") + if !strings.HasPrefix(path, "/") { + if !strings.HasPrefix(path, "./") && !strings.HasPrefix(path, "../") { + path = "./" + path + } + } + cmd := strings.Replace(state.Command, "%s", path, 1) + fmt.Println(cmd) + } + fmt.Println("") + return nil +} + +func down(state *State) error { + lines := make([]string, len(state.Lines)) + copy(lines, state.Lines) + lineCount := len(lines) + slices.Reverse(lines) + var batchLine string + var batch []string + for _, line := range lines { + lineCount-- + if batchStartRe.MatchString(line) { + batchLine = line + break + } + mig := commentStartRe.ReplaceAllString(line, "") + mig = strings.TrimSpace(mig) + if mig == "" { + log.Printf(" ignoring '%s'", line) + continue + } + batch = append(batch, mig) + } + log.Printf("ROLLBACK %s", batchLine) + for _, mig := range batch { + fmt.Println("") + fmt.Printf("# DELETE FROM \"migrations\" WHERE \"name\" = '%s';\n", mig) + sqlfile := filepath.Join(state.SqlDir, mig+".down.sql") + if !fileExists(sqlfile) { + log.Printf(" missing '%s'", sqlfile) + } + cmd := strings.Replace(state.Command, "%s", sqlfile, 1) + fmt.Println(cmd) + } + fmt.Println("") + fmt.Println("# new file as to not overwrite the file while reading") + fmt.Printf("head -n '%d' %s > %s.new\n", lineCount, state.LogFile, state.LogFile) + fmt.Printf("mv %s.new %s\n", state.LogFile, state.LogFile) + fmt.Println("") + return nil +} + +func status(state *State) error { + lines := make([]string, len(state.Lines)) + copy(lines, state.Lines) + hasCommand := commandStartRe.MatchString(lines[0]) + if hasCommand { + lines = lines[1:] + } + slices.Reverse(lines) + var previous []string + for _, line := range lines { + previous = append([]string{line}, previous...) + if batchStartRe.MatchString(line) { + break + } + } + fmt.Fprintf(os.Stderr, "sqldir: %s\n", state.SqlDir) + fmt.Fprintf(os.Stderr, "logfile: %s\n", state.LogFile) + fmt.Fprintf(os.Stderr, "command: %s\n", state.Command) + fmt.Fprintf(os.Stderr, "\n") + fmt.Printf("# previous: %d\n", len(previous)) + for _, mig := range previous { + fmt.Printf(" %s\n", mig) + } + if len(previous) == 0 { + fmt.Println(" # (no previous migrations)") + } + fmt.Println("") + ups, _, err := listMigrations(state) + if err != nil { + return err + } + var pending []string + for _, mig := range ups { + found := slices.Contains(state.Migrated, mig) + if !found { + pending = append(pending, mig) + } + } + fmt.Printf("# pending: %d\n", len(pending)) + for _, mig := range pending { + fmt.Printf(" %s\n", mig) + } + if len(pending) == 0 { + fmt.Println(" # (no pending migrations)") + } + return nil +} + +const helpText = ` +sql-migrate v0.7.0 - a feature-branch-friendly SQL migrator + +USAGE + sql-migrate [-d sqldir] [-f logfile] [args] + +EXAMPLE + sql-migrate init -d ./sql/migrations/ -f ./sql/migrations.log + sql-migrate create + sql-migrate status + sql-migrate up + sql-migrate down + sql-migrate list + +COMMANDS + init - inits sql dir and migration file, adding or updating the + default command + create - creates a new, canonically-named up/down file pair in the + migrations directory + status - shows the same output as if processing a forward-migration + for the most recent batch + up - processes the first 'up' migration file missing from the + migration state + down - rolls back the latest entry of the latest migration batch + (the whole batch if just one) + list - lists migrations + +OPTIONS + -d default: ./sql/migrations/ + -f default: ./sql/migrations.log + +NOTES + Migrations files are in the following format: + -_..sql + 2020-01-01-1000_init.up.sql + + The migration state file contains the client command template (defaults to + 'psql "$PG_URL" < %s'), followed by a list of batches identified by a batch + number comment and a list of migration file basenames and optional user + comments, such as: + # command: psql "$PG_URL" < %s + # batch: 1 + 2020-01-01-1000_init.up.sql # does a lot + 2020-01-01-1100_add-customer-tables.up.sql + # batch: 2 + # We did id! Finally! + 2020-01-01-2000_add-ALL-THE-TABLES.up.sql + + The 'create' generates an up/down pair of files using the current date and + the number 1000. If either file exists, the number is incremented by 1000 and + tried again, up to 9000, or throws the error "it's over 9000!" on failure. +` + +func main() { + if len(os.Args) < 2 { + //nolint + fmt.Printf("%s\n", helpText) + os.Exit(0) + } + + command := os.Args[1] + if command == "help" || command == "--help" || command == "-h" || command == "version" || command == "--version" || command == "-V" { + fmt.Printf("%s\n", helpText) + os.Exit(0) + } + + fs := flag.NewFlagSet(command, flag.ExitOnError) + sqlDir := fs.String("d", defaultMigrationDir, "migrations directory") + logFile := fs.String("f", defaultMigrationLog, "migration log file") + if err := fs.Parse(os.Args[2:]); err != nil { + os.Exit(2) + } + + date := time.Now() + var state *State + var err error + + logText, err := os.ReadFile(*logFile) + if os.IsNotExist(err) { + if command != "init" { + log.Printf(" run 'init' first: missing '%s'", *logFile) + os.Exit(1) + } + text := fmt.Sprintf("# command: %s\n", defaultCommand) + dir := filepath.Dir(*logFile) + err = os.MkdirAll(*sqlDir, 0755) + if err != nil { + log.Fatal(err) + } + err = os.MkdirAll(dir, 0755) + if err != nil { + log.Fatal(err) + } + err = os.WriteFile(*logFile, []byte(text), 0644) + if err != nil { + log.Fatal(err) + } + log.Printf(" created '%s'", *logFile) + logText = []byte{} + } else if err != nil { + log.Fatal(err) + } + + state = parseLog(string(logText), date) + state.SqlDir = *sqlDir + state.LogFile = *logFile + + switch command { + case "init": + if len(logText) > 0 { + log.Printf(" found '%s'", *logFile) + } + case "create": + args := fs.Args() + if len(args) == 0 { + log.Fatal("create requires a description") + } + desc := strings.Join(args, " ") + desc = nonWordRe.ReplaceAllString(desc, " ") + desc = strings.TrimSpace(desc) + desc = nonWordRe.ReplaceAllString(desc, "-") + desc = strings.ToLower(desc) + err = create(state, desc) + if err != nil { + log.Fatal(err) + } + case "status": + err = status(state) + if err != nil { + log.Fatal(err) + } + case "list": + ups, downs, err := listMigrations(state) + if err != nil { + log.Fatal(err) + } + fmt.Println("Ups:") + if len(ups) == 0 { + fmt.Println(" (none)") + } + for _, u := range ups { + fmt.Println(u) + } + fmt.Println("Downs:") + if len(downs) == 0 { + fmt.Println(" (none)") + } + for _, d := range downs { + fmt.Println(d) + } + case "up": + err = up(state) + if err != nil { + log.Fatal(err) + } + case "down": + err = down(state) + if err != nil { + log.Fatal(err) + } + default: + log.Printf("unknown command %s", command) + fmt.Printf("%s\n", helpText) + os.Exit(1) + } +}