diff options
| -rw-r--r-- | Makefile | 14 | ||||
| -rw-r--r-- | fs.go | 110 | ||||
| -rw-r--r-- | fs_test.go | 50 |
3 files changed, 165 insertions, 9 deletions
@@ -1,13 +1,9 @@ +GO ?= go1.16beta1 -.PHONY: test test: lint - go test -short -coverprofile=coverage.txt -covermode=atomic ./... \ - && go tool cover -html=coverage.txt -o coverage.html + $(GO) test -race -short -coverprofile=coverage.txt -covermode=atomic ./... \ + && $(GO) tool cover -func=coverage.txt -.PHONY: lint -lint: - go vet ./... +lint: ; $(GO) vet ./... -.PHONY: clean -clean: - rm -rf coverage* +clean: ; rm -rf coverage* @@ -0,0 +1,110 @@ +// +build go1.16 + +package migrate + +import ( + "database/sql" + "fmt" + "io/fs" + "io/ioutil" + "path" + "strconv" + "strings" +) + +type fsMigration struct { + name string + fs fs.FS +} + +func NewFSMigrator(db *sql.DB, fs fs.ReadDirFS, opts ...Option) (*Migrator, error) { + migrations, err := readFS(fs, ".") + if err != nil { + return nil, err + } + + m := Migrator{db: db, migrations: migrations} + + for _, opt := range opts { + if err = opt(&m); err != nil { + return nil, err + } + } + + return &m, nil +} + +// The Version is extracted from the filename +// It implements the Migration interface +func (fm fsMigration) Version() int { + m := validFilename.FindStringSubmatch(path.Base(fm.name)) + if len(m) == 4 { + if version, err := strconv.ParseInt(m[1], 10, 32); err == nil { + return int(version) + } + } + return -1 +} + +// Run executes the migration +// It implements the Migration interface +func (fm fsMigration) Run(tx *sql.Tx) error { + f, err := fm.fs.Open(fm.name) + if err != nil { + return err + } + defer f.Close() + + buf, err := ioutil.ReadAll(f) + if err != nil { + return err + } + _, err = tx.Exec(string(buf[:])) + return err +} + +// Read all files in path +// They will be sorted by the migrator according to Version() +func readFS(fs fs.ReadDirFS, root string) ([]Migration, error) { + + // Scan entire directory + files, err := fs.ReadDir(root) + if err != nil { + return nil, err + } + + migrations := make([]Migration, 0) + seen := make(map[int]bool) + + for _, fi := range files { + if strings.HasPrefix(fi.Name(), ".") { + continue + } + if fi.IsDir() { + fms, err := readFS(fs, fi.Name()) + if err != nil { + return nil, err + } + migrations = append(migrations, fms...) + } else { + fm := fsMigration{name: path.Join(root, fi.Name()), fs: fs} + + // Ignore invalid filenames + v := fm.Version() + if v == -1 { + fmt.Printf("invalid version %d for %s\n", v, fm) + continue + } + if seen[v] { + fmt.Printf("duplicate version %d for %s\n", v, fm) + continue + } + migrations = append(migrations, fm) + seen[v] = true + } + } + if len(migrations) == 0 { + return nil, fmt.Errorf("no migrations found") + } + return migrations, nil +} diff --git a/fs_test.go b/fs_test.go new file mode 100644 index 0000000..3993163 --- /dev/null +++ b/fs_test.go @@ -0,0 +1,50 @@ +// +build go1.16 + +package migrate + +import ( + "database/sql" + "embed" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func TestFSMigrator(t *testing.T) { + + //go:embed testdata/*.sql + var testdata embed.FS + + db, err := sql.Open("sqlite3", testDB) + if err != nil { + t.Fatalf("DB setup failed: %v", err) + } + defer db.Close() + + migrator, err := NewFSMigrator(db, testdata) + if err != nil { + t.Fatal(err) + } + + if v, _ := migrator.Version(); v != NilVersion { + t.Errorf("expected migration version NilVersion, got %d", v) + } + + if c := len(migrator.migrations); c != 2 { + t.Errorf("expected migration count = 2, got %d", c) + } + + err = migrator.MigrateTo(1) + if err != nil { + t.Fatalf("Migrator.MigrateTo(3) failed: %v", err) + } + + v, err := migrator.Version() + if err != nil { + t.Fatalf("Migrator.Version() failed: %v", err) + } + + if int(v) != len(migrator.migrations)-1 { + t.Errorf("expected migration version %d, got %d", len(migrator.migrations)-1, v) + } +} |
