aboutsummaryrefslogtreecommitdiff
path: root/file.go
diff options
context:
space:
mode:
Diffstat (limited to 'file.go')
-rw-r--r--file.go124
1 files changed, 124 insertions, 0 deletions
diff --git a/file.go b/file.go
new file mode 100644
index 0000000..ecf8932
--- /dev/null
+++ b/file.go
@@ -0,0 +1,124 @@
+package migrate
+
+import (
+ "database/sql"
+ "fmt"
+ "io/ioutil"
+ "net/url"
+ "os"
+ "path"
+ "path/filepath"
+ "regexp"
+ "strconv"
+)
+
+// Regex matches the following pattern:
+// 123_name.ext
+var validFilename = regexp.MustCompile(`^([0-9]+)_(.*)\.(.*)$`)
+
+// Just the path to the migration file
+type fileMigration string
+
+// NewFileMigrator creates a new set of migrations from a path
+// Each one is run in a transaction.
+func NewFileMigrator(db *sql.DB, path string, opts ...Option) (*Migrator, error) {
+ migrations, err := readFiles(path)
+ if err != nil {
+ return nil, err
+ }
+
+ m := Migrator{db: db, migrations: migrations}
+
+ for _, opt := range opts {
+ if err = opt(&m); err != nil {
+ return nil, err
+ }
+ }
+
+ return &m, nil
+}
+
+// The Version is extracted from the filename
+// It implements the Migration interface
+func (fm fileMigration) Version() int64 {
+ m := validFilename.FindStringSubmatch(path.Base(string(fm)))
+ if len(m) == 4 {
+ if version, err := strconv.ParseInt(m[1], 10, 64); err == nil {
+ return version
+ }
+ }
+ return -1
+}
+
+// Run executes the migration
+// It implements the Migration interface
+func (fm fileMigration) Run(tx *sql.Tx) error {
+ r, err := os.Open(string(fm))
+ if err != nil {
+ return err
+ }
+ defer r.Close()
+
+ buf, err := ioutil.ReadAll(r)
+ if err != nil {
+ return err
+ }
+ _, err = tx.Exec(string(buf[:]))
+ return err
+}
+
+// Read all files in path
+// They will be sorted by the migrator according to Version()
+func readFiles(uri string) (migrations []Migration, err error) {
+ u, err := url.Parse(uri)
+ if err != nil {
+ return nil, err
+ }
+
+ // Host might be `.`
+ p := u.Host + u.Path
+
+ if len(p) == 0 {
+ // Default to current directory
+ wd, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ p = wd
+ } else if p[0:1] == "." || p[0:1] != "/" {
+ // Ensure path is absolute
+ abs, err := filepath.Abs(p)
+ if err != nil {
+ return nil, err
+ }
+ p = abs
+ }
+
+ // Scan entire directory
+ files, err := ioutil.ReadDir(p)
+ if err != nil {
+ return nil, err
+ }
+
+ seen := make(map[int64]bool)
+
+ for _, fi := range files {
+ if !fi.IsDir() {
+ fm := fileMigration(path.Join(p, fi.Name()))
+
+ // Ignore invalid filenames
+ v := fm.Version()
+ if v == -1 {
+ fmt.Printf("invalid version %d for %s\n", v, fm)
+ continue
+ }
+ if seen[v] {
+ fmt.Printf("duplicate version %d for %s\n", v, fm)
+ continue
+ }
+ migrations = append(migrations, fm)
+ seen[v] = true
+ }
+ }
+ return migrations, nil
+}