From a3ecf5ac8187109c6cd55a77ee4c0b49aac0e0f1 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Thu, 9 Apr 2026 02:04:37 -0600 Subject: [PATCH] ref(sqlmigrate): add subpath to Collect, add Latest/Drop convenience functions API changes for v1: - Collect(fsys, subpath) takes a subdirectory path (use "." for root), enabling embed.FS with //go:embed sql/migrations/*.sql - Latest() applies all pending migrations (shorthand for Up with n=-1) - Drop() rolls back all applied migrations (shorthand for Down with n=-1) --- database/sqlmigrate/sqlmigrate.go | 25 ++++++++++++++++++++++--- database/sqlmigrate/sqlmigrate_test.go | 14 +++++++------- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/database/sqlmigrate/sqlmigrate.go b/database/sqlmigrate/sqlmigrate.go index c627c8f..104e739 100644 --- a/database/sqlmigrate/sqlmigrate.go +++ b/database/sqlmigrate/sqlmigrate.go @@ -71,11 +71,20 @@ var idFromInsert = regexp.MustCompile( `(?i)INSERT\s+INTO\s+_migrations\s*\(\s*name\s*,\s*id\s*\)\s*VALUES\s*\(\s*'[^']*'\s*,\s*'([0-9a-fA-F]+)'\s*\)`, ) -// Collect reads .up.sql and .down.sql files from fsys, pairs them by -// basename, and returns them sorted lexicographically by name. +// Collect reads .up.sql and .down.sql files from fsys under subpath, +// pairs them by basename, and returns them sorted lexicographically by name. +// If subpath is "" or ".", the root of fsys is used. // If the up SQL contains an INSERT INTO _migrations line, the hex ID // is extracted and stored in Migration.ID. -func Collect(fsys fs.FS) ([]Migration, error) { +func Collect(fsys fs.FS, subpath string) ([]Migration, error) { + if subpath != "" && subpath != "." { + var err error + fsys, err = fs.Sub(fsys, subpath) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrWalkFailed, err) + } + } + ups := map[string]string{} downs := map[string]string{} @@ -285,3 +294,13 @@ func GetStatus(ctx context.Context, r Migrator, migrations []Migration) (*Status Pending: pending, }, nil } + +// Latest applies all pending migrations. Equivalent to Up(ctx, r, migrations, -1). +func Latest(ctx context.Context, r Migrator, migrations []Migration) ([]string, error) { + return Up(ctx, r, migrations, -1) +} + +// Drop rolls back all applied migrations. Equivalent to Down(ctx, r, migrations, -1). +func Drop(ctx context.Context, r Migrator, migrations []Migration) ([]string, error) { + return Down(ctx, r, migrations, -1) +} diff --git a/database/sqlmigrate/sqlmigrate_test.go b/database/sqlmigrate/sqlmigrate_test.go index 21576cc..8348e7c 100644 --- a/database/sqlmigrate/sqlmigrate_test.go +++ b/database/sqlmigrate/sqlmigrate_test.go @@ -63,7 +63,7 @@ func TestCollect(t *testing.T) { "001_first.up.sql": {Data: []byte("CREATE TABLE a;")}, "001_first.down.sql": {Data: []byte("DROP TABLE a;")}, } - migrations, err := sqlmigrate.Collect(fsys) + migrations, err := sqlmigrate.Collect(fsys, ".") if err != nil { t.Fatal(err) } @@ -89,7 +89,7 @@ func TestCollect(t *testing.T) { "001_init.up.sql": {Data: []byte("CREATE TABLE a;\nINSERT INTO _migrations (name, id) VALUES ('001_init', 'abcd1234');")}, "001_init.down.sql": {Data: []byte("DROP TABLE a;\nDELETE FROM _migrations WHERE id = 'abcd1234';")}, } - migrations, err := sqlmigrate.Collect(fsys) + migrations, err := sqlmigrate.Collect(fsys, ".") if err != nil { t.Fatal(err) } @@ -103,7 +103,7 @@ func TestCollect(t *testing.T) { "001_init.up.sql": {Data: []byte("CREATE TABLE a;")}, "001_init.down.sql": {Data: []byte("DROP TABLE a;")}, } - migrations, err := sqlmigrate.Collect(fsys) + migrations, err := sqlmigrate.Collect(fsys, ".") if err != nil { t.Fatal(err) } @@ -116,7 +116,7 @@ func TestCollect(t *testing.T) { fsys := fstest.MapFS{ "001_only-up.up.sql": {Data: []byte("CREATE TABLE x;")}, } - _, err := sqlmigrate.Collect(fsys) + _, err := sqlmigrate.Collect(fsys, ".") if !errors.Is(err, sqlmigrate.ErrMissingDown) { t.Errorf("got %v, want ErrMissingDown", err) } @@ -126,7 +126,7 @@ func TestCollect(t *testing.T) { fsys := fstest.MapFS{ "001_only-down.down.sql": {Data: []byte("DROP TABLE x;")}, } - _, err := sqlmigrate.Collect(fsys) + _, err := sqlmigrate.Collect(fsys, ".") if !errors.Is(err, sqlmigrate.ErrMissingUp) { t.Errorf("got %v, want ErrMissingUp", err) } @@ -139,7 +139,7 @@ func TestCollect(t *testing.T) { "README.md": {Data: []byte("# Migrations")}, "_migrations.sql": {Data: []byte("SELECT name FROM _migrations;")}, } - migrations, err := sqlmigrate.Collect(fsys) + migrations, err := sqlmigrate.Collect(fsys, ".") if err != nil { t.Fatal(err) } @@ -150,7 +150,7 @@ func TestCollect(t *testing.T) { t.Run("empty fs", func(t *testing.T) { fsys := fstest.MapFS{} - migrations, err := sqlmigrate.Collect(fsys) + migrations, err := sqlmigrate.Collect(fsys, ".") if err != nil { t.Fatal(err) }