diff options
Diffstat (limited to 'file.go')
| -rw-r--r-- | file.go | 124 |
1 files changed, 124 insertions, 0 deletions
@@ -0,0 +1,124 @@ +package migrate + +import ( + "database/sql" + "fmt" + "io/ioutil" + "net/url" + "os" + "path" + "path/filepath" + "regexp" + "strconv" +) + +// Regex matches the following pattern: +// 123_name.ext +var validFilename = regexp.MustCompile(`^([0-9]+)_(.*)\.(.*)$`) + +// Just the path to the migration file +type fileMigration string + +// NewFileMigrator creates a new set of migrations from a path +// Each one is run in a transaction. +func NewFileMigrator(db *sql.DB, path string, opts ...Option) (*Migrator, error) { + migrations, err := readFiles(path) + if err != nil { + return nil, err + } + + m := Migrator{db: db, migrations: migrations} + + for _, opt := range opts { + if err = opt(&m); err != nil { + return nil, err + } + } + + return &m, nil +} + +// The Version is extracted from the filename +// It implements the Migration interface +func (fm fileMigration) Version() int64 { + m := validFilename.FindStringSubmatch(path.Base(string(fm))) + if len(m) == 4 { + if version, err := strconv.ParseInt(m[1], 10, 64); err == nil { + return version + } + } + return -1 +} + +// Run executes the migration +// It implements the Migration interface +func (fm fileMigration) Run(tx *sql.Tx) error { + r, err := os.Open(string(fm)) + if err != nil { + return err + } + defer r.Close() + + buf, err := ioutil.ReadAll(r) + if err != nil { + return err + } + _, err = tx.Exec(string(buf[:])) + return err +} + +// Read all files in path +// They will be sorted by the migrator according to Version() +func readFiles(uri string) (migrations []Migration, err error) { + u, err := url.Parse(uri) + if err != nil { + return nil, err + } + + // Host might be `.` + p := u.Host + u.Path + + if len(p) == 0 { + // Default to current directory + wd, err := os.Getwd() + if err != nil { + return nil, err + } + p = wd + } else if p[0:1] == "." || p[0:1] != "/" { + // Ensure path is absolute + abs, err := filepath.Abs(p) + if err != nil { + return nil, err + } + p = abs + } + + // Scan entire directory + files, err := ioutil.ReadDir(p) + if err != nil { + return nil, err + } + + seen := make(map[int64]bool) + + for _, fi := range files { + if !fi.IsDir() { + fm := fileMigration(path.Join(p, fi.Name())) + + // Ignore invalid filenames + v := fm.Version() + if v == -1 { + fmt.Printf("invalid version %d for %s\n", v, fm) + continue + } + if seen[v] { + fmt.Printf("duplicate version %d for %s\n", v, fm) + continue + } + migrations = append(migrations, fm) + seen[v] = true + } + } + return migrations, nil +} |
