diff options
| author | felix <felix> | 2018-06-22 10:37:31 +0000 |
|---|---|---|
| committer | felix <felix> | 2018-06-22 10:37:31 +0000 |
| commit | 3d2f2f2b0a357462bbda428d0626ee4b606b4ff2 (patch) | |
| tree | 3c03836e84e6905611b8fae00d33b539d6fe6e27 | |
| parent | 456b81fd8b4bc08cfdf9075a1d0e487df6af829f (diff) | |
| download | migrate-3d2f2f2b0a357462bbda428d0626ee4b606b4ff2.tar.gz migrate-3d2f2f2b0a357462bbda428d0626ee4b606b4ff2.tar.bz2 | |
Add initial working files
| -rw-r--r-- | file.go | 124 | ||||
| -rw-r--r-- | migrate.go | 215 | ||||
| -rw-r--r-- | options.go | 24 |
3 files changed, 363 insertions, 0 deletions
@@ -0,0 +1,124 @@ +package migrate + +import ( + "database/sql" + "fmt" + "io/ioutil" + "net/url" + "os" + "path" + "path/filepath" + "regexp" + "strconv" +) + +// Regex matches the following pattern: +// 123_name.ext +var validFilename = regexp.MustCompile(`^([0-9]+)_(.*)\.(.*)$`) + +// Just the path to the migration file +type fileMigration string + +// NewFileMigrator creates a new set of migrations from a path +// Each one is run in a transaction. +func NewFileMigrator(db *sql.DB, path string, opts ...Option) (*Migrator, error) { + migrations, err := readFiles(path) + if err != nil { + return nil, err + } + + m := Migrator{db: db, migrations: migrations} + + for _, opt := range opts { + if err = opt(&m); err != nil { + return nil, err + } + } + + return &m, nil +} + +// The Version is extracted from the filename +// It implements the Migration interface +func (fm fileMigration) Version() int64 { + m := validFilename.FindStringSubmatch(path.Base(string(fm))) + if len(m) == 4 { + if version, err := strconv.ParseInt(m[1], 10, 64); err == nil { + return version + } + } + return -1 +} + +// Run executes the migration +// It implements the Migration interface +func (fm fileMigration) Run(tx *sql.Tx) error { + r, err := os.Open(string(fm)) + if err != nil { + return err + } + defer r.Close() + + buf, err := ioutil.ReadAll(r) + if err != nil { + return err + } + _, err = tx.Exec(string(buf[:])) + return err +} + +// Read all files in path +// They will be sorted by the migrator according to Version() +func readFiles(uri string) (migrations []Migration, err error) { + u, err := url.Parse(uri) + if err != nil { + return nil, err + } + + // Host might be `.` + p := u.Host + u.Path + + if len(p) == 0 { + // Default to current directory + wd, err := os.Getwd() + if err != nil { + return nil, err + } + p = wd + } else if p[0:1] == "." || p[0:1] != "/" { + // Ensure path is absolute + abs, err := filepath.Abs(p) + if err != nil { + return nil, err + } + p = abs + } + + // Scan entire directory + files, err := ioutil.ReadDir(p) + if err != nil { + return nil, err + } + + seen := make(map[int64]bool) + + for _, fi := range files { + if !fi.IsDir() { + fm := fileMigration(path.Join(p, fi.Name())) + + // Ignore invalid filenames + v := fm.Version() + if v == -1 { + fmt.Printf("invalid version %d for %s\n", v, fm) + continue + } + if seen[v] { + fmt.Printf("duplicate version %d for %s\n", v, fm) + continue + } + migrations = append(migrations, fm) + seen[v] = true + } + } + return migrations, nil +} diff --git a/migrate.go b/migrate.go new file mode 100644 index 0000000..8d3e7ab --- /dev/null +++ b/migrate.go @@ -0,0 +1,215 @@ +package migrate + +import ( + "database/sql" + "fmt" + "sort" + "time" +) + +const ( + // NilVersion is a Claytons version + // "the version you are at when you are not at a version" + NilVersion = -1 +) + +// A Migrator collates and runs migrations +type Migrator struct { + db *sql.DB + migrations []Migration + versionTable *string + stmts map[string]*sql.Stmt + prepared bool +} + +// Migration interface +type Migration interface { + // The version of this migration + Version() int64 + // Run the migration + Run(*sql.Tx) error +} + +// ResultFunc is the callback signature +type ResultFunc func(int64, int64, error) + +// Sort those migrations +type sorted []Migration + +func (s sorted) Len() int { return len(s) } +func (s sorted) Less(i, j int) bool { return s[i].Version() < s[j].Version() } +func (s sorted) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Version reports the current version of the database +func (m *Migrator) Version() (int64, error) { + err := m.prepareForMigration() + if err != nil { + return NilVersion, err + } + + rows, err := m.stmts["getVersion"].Query() + if rows.Next() { + var version int64 + err = rows.Scan(&version) + if err == sql.ErrNoRows { + return NilVersion, nil + } + if err == nil { + return version, nil + } + } + return 0, err +} + +// Migrate migrates the database to the highest possible version +func (m *Migrator) Migrate(cb ResultFunc) error { + err := m.prepareForMigration() + if err != nil { + return err + } + + // Get the last available migration + v := m.migrations[len(m.migrations)-1].Version() + return m.MigrateTo(v, cb) +} + +// MigrateTo migrates the database to the specified version +func (m *Migrator) MigrateTo(toVersion int64, cb ResultFunc) error { + err := m.prepareForMigration() + if err != nil { + return err + } + + maxVersion := m.migrations[len(m.migrations)-1].Version() + + currVersion, err := m.Version() + if err != nil { + return err + } + + if currVersion >= toVersion { + if cb != nil { + go cb(maxVersion, currVersion, nil) + } + return nil + } + + for _, mig := range m.migrations { + nextVersion := mig.Version() + + // Skip old migrations + if nextVersion <= currVersion { + continue + } + + // Ensure contiguous + if currVersion != NilVersion && nextVersion != currVersion+1 { + return fmt.Errorf("non-contiguous migration: %v -> %v", currVersion, nextVersion) + } + + if currVersion < nextVersion && nextVersion <= toVersion { + err = func() error { + fmt.Println("running migration", nextVersion) + // Start a transaction + tx, err := m.db.Begin() + if err != nil { + if cb != nil { + go cb(maxVersion, currVersion, err) + } + return err + } + defer tx.Rollback() + + // Run the migration + if err = mig.Run(tx); err != nil { + if cb != nil { + go cb(maxVersion, currVersion, err) + } + return err + } + // Update the version entry + fmt.Println("updating version") + if err = m.setVersion(tx, nextVersion); err != nil { + if cb != nil { + go cb(maxVersion, currVersion, err) + } + return err + } + // Commit the transaction + fmt.Println("committing version") + return tx.Commit() + }() + if err != nil { + if cb != nil { + go cb(maxVersion, currVersion, err) + } + return err + } + if cb != nil { + go cb(maxVersion, currVersion, nil) + } + } + currVersion = nextVersion + } + + return nil +} + +func (m *Migrator) setVersion(tx *sql.Tx, version int64) (err error) { + if version >= 0 { + _, err = tx.Stmt(m.stmts["insertVersion"]).Exec(version, time.Now().Unix()) + } + return err +} + +func (m *Migrator) prepareForMigration() error { + if m.prepared { + return nil + } + + if m.versionTable == nil { + vt := "current_schema_version" + m.versionTable = &vt + } + + if _, err := m.db.Exec(fmt.Sprintf(createTableSQL, *m.versionTable)); err != nil { + return err + } + + if err := m.prepareStmts(); err != nil { + return err + } + + sort.Sort(sorted(m.migrations)) + + m.prepared = true + return nil +} + +func (m *Migrator) prepareStmts() error { + m.stmts = make(map[string]*sql.Stmt) + s, err := m.db.Prepare(fmt.Sprintf(getVersionSQL, NilVersion, *m.versionTable)) + if err != nil { + return err + } + m.stmts["getVersion"] = s + + s, err = m.db.Prepare(fmt.Sprintf(insertVersionSQL, *m.versionTable)) + if err != nil { + return err + } + m.stmts["insertVersion"] = s + + return nil +} + +const ( + getVersionSQL = `select coalesce(max(version), %d) from %q` + insertVersionSQL = `insert into %q (version, applied) values ($1, $2)` + + // Use Unix timestamp for time so it works for SQLite and PostgreSQL + createTableSQL = `create table if not exists %q ( + version bigint not null primary key, + applied int + )` +) diff --git a/options.go b/options.go new file mode 100644 index 0000000..61175dd --- /dev/null +++ b/options.go @@ -0,0 +1,24 @@ +package migrate + +//import "context" + +// An Option configures a migrator +type Option func(*Migrator) error + +// SetVersionTable configures the table used for recording the schema version +func SetVersionTable(vt string) Option { + return func(m *Migrator) error { + m.versionTable = &vt + return nil + } +} + +// SetContext configures the context for queries +/* +func SetContext(ctx context.Context) Option { + return func(m *Migrator) error { + m.ctx = ctx + return nil + } +} +*/ |
