镜像来自
https://github.com/therootcompany/golib.git
synced 2025-10-29 20:22:51 +00:00
feat(sqlmigrator): can write up and down migrations
This commit is contained in:
父節點
710438c44e
當前提交
ab9dbdb6c6
503
cmd/sql-migrate/main.go
Normal file
503
cmd/sql-migrate/main.go
Normal file
@ -0,0 +1,503 @@
|
||||
//
|
||||
// Written in 2025 by AJ ONeal <aj@therootcompany.com>
|
||||
//
|
||||
// To the extent possible under law, the author(s) have dedicated all copyright
|
||||
// and related and neighboring rights to this software to the public domain
|
||||
// worldwide. This software is distributed without any warranty.
|
||||
//
|
||||
// You should have received a copy of the CC0 Public Domain Dedication along with
|
||||
// this software. If not, see <https://creativecommons.org/publicdomain/zero/1.0/>.
|
||||
|
||||
// Package sql-migrate provides a simple SQL migrator that's easy to roll back or mix and match during development
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMigrationDir = "./sql/migrations/"
|
||||
defaultMigrationLog = "./sql/migrations.log"
|
||||
defaultCommand = `psql "$PG_URL" < %s`
|
||||
)
|
||||
|
||||
var (
|
||||
nonWordRe = regexp.MustCompile(`\W+`)
|
||||
commandStartRe = regexp.MustCompile(`^#\s*command:\s*`)
|
||||
batchStartRe = regexp.MustCompile(`^#\s*batch:\s*`)
|
||||
commentStartRe = regexp.MustCompile(`(^|\s+)#.*`)
|
||||
)
|
||||
|
||||
type State struct {
|
||||
Date time.Time
|
||||
Command string
|
||||
Current int
|
||||
Lines []string
|
||||
Migrated []string
|
||||
SqlDir string
|
||||
LogFile string
|
||||
}
|
||||
|
||||
func parseLog(text string, date time.Time) *State {
|
||||
state := &State{Date: date, Command: "", Current: 0, Lines: []string{}, Migrated: []string{}}
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
state.Command = defaultCommand
|
||||
return state
|
||||
}
|
||||
state.Lines = strings.Split(text, "\n")
|
||||
batchCount := 0
|
||||
for i := range state.Lines {
|
||||
line := strings.TrimSpace(state.Lines[i])
|
||||
if commandStartRe.MatchString(line) {
|
||||
if state.Command != "" {
|
||||
log.Printf(" ignoring duplicate '%s'", line)
|
||||
} else {
|
||||
state.Command = commandStartRe.ReplaceAllString(line, "")
|
||||
}
|
||||
}
|
||||
if batchStartRe.MatchString(line) {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||
if err != nil || n <= 0 {
|
||||
log.Printf(" invalid '%s'", line)
|
||||
n = -1
|
||||
}
|
||||
batchCount++
|
||||
if n > state.Current {
|
||||
state.Current = n
|
||||
}
|
||||
if batchCount > state.Current {
|
||||
state.Current = batchCount
|
||||
}
|
||||
}
|
||||
migration := commentStartRe.ReplaceAllString(line, "")
|
||||
migration = strings.TrimSpace(migration)
|
||||
if migration != "" {
|
||||
state.Migrated = append(state.Migrated, migration)
|
||||
}
|
||||
state.Lines[i] = line
|
||||
}
|
||||
if state.Command == "" {
|
||||
state.Command = defaultCommand
|
||||
}
|
||||
if !strings.Contains(state.Command, "%s") {
|
||||
state.Command += " %s"
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func create(state *State, desc string) error {
|
||||
dateStr := state.Date.Format("2006-01-02")
|
||||
entries, err := os.ReadDir(state.SqlDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
maxNumber := 0
|
||||
datePrefix := dateStr + "-"
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, datePrefix) {
|
||||
continue
|
||||
}
|
||||
if !strings.HasSuffix(name, ".up.sql") {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "_"+desc+".up.sql") {
|
||||
return fmt.Errorf("migration for %q already exists:\n %s", desc, state.SqlDir+"/"+name)
|
||||
}
|
||||
if strings.HasSuffix(name, ".down.sql") {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(name, "-", 4)
|
||||
if len(parts) < 4 {
|
||||
continue
|
||||
}
|
||||
numDesc := strings.SplitN(parts[3], "_", 2)
|
||||
if len(numDesc) < 2 {
|
||||
continue
|
||||
}
|
||||
num, err := strconv.Atoi(numDesc[0])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if num > maxNumber {
|
||||
maxNumber = num
|
||||
}
|
||||
}
|
||||
|
||||
number := maxNumber / 1000
|
||||
number *= 1000
|
||||
number += 1000
|
||||
if number > 9000 {
|
||||
return fmt.Errorf("it's over 9000! ")
|
||||
}
|
||||
|
||||
baseFilename := fmt.Sprintf("%s-%06d_%s", dateStr, number, desc)
|
||||
upPath := filepath.Join(state.SqlDir, baseFilename+".up.sql")
|
||||
downPath := filepath.Join(state.SqlDir, baseFilename+".down.sql")
|
||||
|
||||
// Use fmt.Appendf to build byte slice, ignoring error as it can't fail with static format
|
||||
upContent := fmt.Appendf(nil, "-- %s (up)\n", desc)
|
||||
_ = os.WriteFile(upPath, upContent, 0644)
|
||||
downContent := fmt.Appendf(nil, "-- %s (down)\n", desc)
|
||||
_ = os.WriteFile(downPath, downContent, 0644)
|
||||
|
||||
fmt.Fprintf(os.Stderr, " created pair %s\n", upPath)
|
||||
fmt.Fprintf(os.Stderr, " %s\n", downPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func listMigrations(state *State) (ups, downs []string, err error) {
|
||||
entries, err := os.ReadDir(state.SqlDir)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") {
|
||||
log.Printf(" ignoring '%s'", name)
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, ".up.sql") {
|
||||
base := strings.TrimSuffix(name, ".up.sql")
|
||||
ups = append(ups, base)
|
||||
companion := filepath.Join(state.SqlDir, base+".down.sql")
|
||||
if !fileExists(companion) {
|
||||
log.Printf(" missing '%s'", companion)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, ".down.sql") {
|
||||
base := strings.TrimSuffix(name, ".down.sql")
|
||||
downs = append(downs, base)
|
||||
companion := filepath.Join(state.SqlDir, base+".up.sql")
|
||||
if !fileExists(companion) {
|
||||
log.Printf(" missing '%s'", companion)
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf(" unknown '%s'", name)
|
||||
}
|
||||
sort.Strings(ups)
|
||||
sort.Strings(downs)
|
||||
return ups, downs, nil
|
||||
}
|
||||
|
||||
func up(state *State) error {
|
||||
ups, _, err := listMigrations(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var pending []string
|
||||
for _, mig := range ups {
|
||||
found := slices.Contains(state.Migrated, mig)
|
||||
if !found {
|
||||
pending = append(pending, mig)
|
||||
}
|
||||
}
|
||||
if len(pending) == 0 {
|
||||
log.Println(" already up-to-date")
|
||||
return nil
|
||||
}
|
||||
n := state.Current + 1
|
||||
fmt.Printf("echo '# batch: %d' >> %s\n", n, state.LogFile)
|
||||
for _, mig := range pending {
|
||||
fmt.Println("")
|
||||
fmt.Printf("# INSERT INTO \"migrations\" ('%d', '%s')\n", n, mig)
|
||||
fmt.Printf("echo '%s' >> %s\n", mig, state.LogFile)
|
||||
path := filepath.Join(state.SqlDir, mig+".up.sql")
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
if !strings.HasPrefix(path, "./") && !strings.HasPrefix(path, "../") {
|
||||
path = "./" + path
|
||||
}
|
||||
}
|
||||
cmd := strings.Replace(state.Command, "%s", path, 1)
|
||||
fmt.Println(cmd)
|
||||
}
|
||||
fmt.Println("")
|
||||
return nil
|
||||
}
|
||||
|
||||
func down(state *State) error {
|
||||
lines := make([]string, len(state.Lines))
|
||||
copy(lines, state.Lines)
|
||||
lineCount := len(lines)
|
||||
slices.Reverse(lines)
|
||||
var batchLine string
|
||||
var batch []string
|
||||
for _, line := range lines {
|
||||
lineCount--
|
||||
if batchStartRe.MatchString(line) {
|
||||
batchLine = line
|
||||
break
|
||||
}
|
||||
mig := commentStartRe.ReplaceAllString(line, "")
|
||||
mig = strings.TrimSpace(mig)
|
||||
if mig == "" {
|
||||
log.Printf(" ignoring '%s'", line)
|
||||
continue
|
||||
}
|
||||
batch = append(batch, mig)
|
||||
}
|
||||
log.Printf("ROLLBACK %s", batchLine)
|
||||
for _, mig := range batch {
|
||||
fmt.Println("")
|
||||
fmt.Printf("# DELETE FROM \"migrations\" WHERE \"name\" = '%s';\n", mig)
|
||||
sqlfile := filepath.Join(state.SqlDir, mig+".down.sql")
|
||||
if !fileExists(sqlfile) {
|
||||
log.Printf(" missing '%s'", sqlfile)
|
||||
}
|
||||
cmd := strings.Replace(state.Command, "%s", sqlfile, 1)
|
||||
fmt.Println(cmd)
|
||||
}
|
||||
fmt.Println("")
|
||||
fmt.Println("# new file as to not overwrite the file while reading")
|
||||
fmt.Printf("head -n '%d' %s > %s.new\n", lineCount, state.LogFile, state.LogFile)
|
||||
fmt.Printf("mv %s.new %s\n", state.LogFile, state.LogFile)
|
||||
fmt.Println("")
|
||||
return nil
|
||||
}
|
||||
|
||||
func status(state *State) error {
|
||||
lines := make([]string, len(state.Lines))
|
||||
copy(lines, state.Lines)
|
||||
hasCommand := commandStartRe.MatchString(lines[0])
|
||||
if hasCommand {
|
||||
lines = lines[1:]
|
||||
}
|
||||
slices.Reverse(lines)
|
||||
var previous []string
|
||||
for _, line := range lines {
|
||||
previous = append([]string{line}, previous...)
|
||||
if batchStartRe.MatchString(line) {
|
||||
break
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "sqldir: %s\n", state.SqlDir)
|
||||
fmt.Fprintf(os.Stderr, "logfile: %s\n", state.LogFile)
|
||||
fmt.Fprintf(os.Stderr, "command: %s\n", state.Command)
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
fmt.Printf("# previous: %d\n", len(previous))
|
||||
for _, mig := range previous {
|
||||
fmt.Printf(" %s\n", mig)
|
||||
}
|
||||
if len(previous) == 0 {
|
||||
fmt.Println(" # (no previous migrations)")
|
||||
}
|
||||
fmt.Println("")
|
||||
ups, _, err := listMigrations(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var pending []string
|
||||
for _, mig := range ups {
|
||||
found := slices.Contains(state.Migrated, mig)
|
||||
if !found {
|
||||
pending = append(pending, mig)
|
||||
}
|
||||
}
|
||||
fmt.Printf("# pending: %d\n", len(pending))
|
||||
for _, mig := range pending {
|
||||
fmt.Printf(" %s\n", mig)
|
||||
}
|
||||
if len(pending) == 0 {
|
||||
fmt.Println(" # (no pending migrations)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const helpText = `
|
||||
sql-migrate v0.7.0 - a feature-branch-friendly SQL migrator
|
||||
|
||||
USAGE
|
||||
sql-migrate <command> [-d sqldir] [-f logfile] [args]
|
||||
|
||||
EXAMPLE
|
||||
sql-migrate init -d ./sql/migrations/ -f ./sql/migrations.log
|
||||
sql-migrate create <kebab-case-description>
|
||||
sql-migrate status
|
||||
sql-migrate up
|
||||
sql-migrate down
|
||||
sql-migrate list
|
||||
|
||||
COMMANDS
|
||||
init - inits sql dir and migration file, adding or updating the
|
||||
default command
|
||||
create - creates a new, canonically-named up/down file pair in the
|
||||
migrations directory
|
||||
status - shows the same output as if processing a forward-migration
|
||||
for the most recent batch
|
||||
up - processes the first 'up' migration file missing from the
|
||||
migration state
|
||||
down - rolls back the latest entry of the latest migration batch
|
||||
(the whole batch if just one)
|
||||
list - lists migrations
|
||||
|
||||
OPTIONS
|
||||
-d <migrations directory> default: ./sql/migrations/
|
||||
-f <migration state file> default: ./sql/migrations.log
|
||||
|
||||
NOTES
|
||||
Migrations files are in the following format:
|
||||
<yyyy-mm-dd>-<number>_<name>.<up|down>.sql
|
||||
2020-01-01-1000_init.up.sql
|
||||
|
||||
The migration state file contains the client command template (defaults to
|
||||
'psql "$PG_URL" < %s'), followed by a list of batches identified by a batch
|
||||
number comment and a list of migration file basenames and optional user
|
||||
comments, such as:
|
||||
# command: psql "$PG_URL" < %s
|
||||
# batch: 1
|
||||
2020-01-01-1000_init.up.sql # does a lot
|
||||
2020-01-01-1100_add-customer-tables.up.sql
|
||||
# batch: 2
|
||||
# We did id! Finally!
|
||||
2020-01-01-2000_add-ALL-THE-TABLES.up.sql
|
||||
|
||||
The 'create' generates an up/down pair of files using the current date and
|
||||
the number 1000. If either file exists, the number is incremented by 1000 and
|
||||
tried again, up to 9000, or throws the error "it's over 9000!" on failure.
|
||||
`
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
//nolint
|
||||
fmt.Printf("%s\n", helpText)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
command := os.Args[1]
|
||||
if command == "help" || command == "--help" || command == "-h" || command == "version" || command == "--version" || command == "-V" {
|
||||
fmt.Printf("%s\n", helpText)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
fs := flag.NewFlagSet(command, flag.ExitOnError)
|
||||
sqlDir := fs.String("d", defaultMigrationDir, "migrations directory")
|
||||
logFile := fs.String("f", defaultMigrationLog, "migration log file")
|
||||
if err := fs.Parse(os.Args[2:]); err != nil {
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
date := time.Now()
|
||||
var state *State
|
||||
var err error
|
||||
|
||||
logText, err := os.ReadFile(*logFile)
|
||||
if os.IsNotExist(err) {
|
||||
if command != "init" {
|
||||
log.Printf(" run 'init' first: missing '%s'", *logFile)
|
||||
os.Exit(1)
|
||||
}
|
||||
text := fmt.Sprintf("# command: %s\n", defaultCommand)
|
||||
dir := filepath.Dir(*logFile)
|
||||
err = os.MkdirAll(*sqlDir, 0755)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
err = os.MkdirAll(dir, 0755)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
err = os.WriteFile(*logFile, []byte(text), 0644)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Printf(" created '%s'", *logFile)
|
||||
logText = []byte{}
|
||||
} else if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
state = parseLog(string(logText), date)
|
||||
state.SqlDir = *sqlDir
|
||||
state.LogFile = *logFile
|
||||
|
||||
switch command {
|
||||
case "init":
|
||||
if len(logText) > 0 {
|
||||
log.Printf(" found '%s'", *logFile)
|
||||
}
|
||||
case "create":
|
||||
args := fs.Args()
|
||||
if len(args) == 0 {
|
||||
log.Fatal("create requires a description")
|
||||
}
|
||||
desc := strings.Join(args, " ")
|
||||
desc = nonWordRe.ReplaceAllString(desc, " ")
|
||||
desc = strings.TrimSpace(desc)
|
||||
desc = nonWordRe.ReplaceAllString(desc, "-")
|
||||
desc = strings.ToLower(desc)
|
||||
err = create(state, desc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case "status":
|
||||
err = status(state)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case "list":
|
||||
ups, downs, err := listMigrations(state)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("Ups:")
|
||||
if len(ups) == 0 {
|
||||
fmt.Println(" (none)")
|
||||
}
|
||||
for _, u := range ups {
|
||||
fmt.Println(u)
|
||||
}
|
||||
fmt.Println("Downs:")
|
||||
if len(downs) == 0 {
|
||||
fmt.Println(" (none)")
|
||||
}
|
||||
for _, d := range downs {
|
||||
fmt.Println(d)
|
||||
}
|
||||
case "up":
|
||||
err = up(state)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case "down":
|
||||
err = down(state)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
default:
|
||||
log.Printf("unknown command %s", command)
|
||||
fmt.Printf("%s\n", helpText)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user