aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfelix <felix>2018-06-22 10:37:31 +0000
committerfelix <felix>2018-06-22 10:37:31 +0000
commit3d2f2f2b0a357462bbda428d0626ee4b606b4ff2 (patch)
tree3c03836e84e6905611b8fae00d33b539d6fe6e27
parent456b81fd8b4bc08cfdf9075a1d0e487df6af829f (diff)
downloadmigrate-3d2f2f2b0a357462bbda428d0626ee4b606b4ff2.tar.gz
migrate-3d2f2f2b0a357462bbda428d0626ee4b606b4ff2.tar.bz2
Add initial working files
-rw-r--r--file.go124
-rw-r--r--migrate.go215
-rw-r--r--options.go24
3 files changed, 363 insertions, 0 deletions
diff --git a/file.go b/file.go
new file mode 100644
index 0000000..ecf8932
--- /dev/null
+++ b/file.go
@@ -0,0 +1,124 @@
+package migrate
+
+import (
+ "database/sql"
+ "fmt"
+ "io/ioutil"
+ "net/url"
+ "os"
+ "path"
+ "path/filepath"
+ "regexp"
+ "strconv"
+)
+
+// Regex matches the following pattern:
+// 123_name.ext
+var validFilename = regexp.MustCompile(`^([0-9]+)_(.*)\.(.*)$`)
+
+// Just the path to the migration file
+type fileMigration string
+
+// NewFileMigrator creates a new set of migrations from a path
+// Each one is run in a transaction.
+func NewFileMigrator(db *sql.DB, path string, opts ...Option) (*Migrator, error) {
+ migrations, err := readFiles(path)
+ if err != nil {
+ return nil, err
+ }
+
+ m := Migrator{db: db, migrations: migrations}
+
+ for _, opt := range opts {
+ if err = opt(&m); err != nil {
+ return nil, err
+ }
+ }
+
+ return &m, nil
+}
+
+// The Version is extracted from the filename
+// It implements the Migration interface
+func (fm fileMigration) Version() int64 {
+ m := validFilename.FindStringSubmatch(path.Base(string(fm)))
+ if len(m) == 4 {
+ if version, err := strconv.ParseInt(m[1], 10, 64); err == nil {
+ return version
+ }
+ }
+ return -1
+}
+
+// Run executes the migration
+// It implements the Migration interface
+func (fm fileMigration) Run(tx *sql.Tx) error {
+ r, err := os.Open(string(fm))
+ if err != nil {
+ return err
+ }
+ defer r.Close()
+
+ buf, err := ioutil.ReadAll(r)
+ if err != nil {
+ return err
+ }
+ _, err = tx.Exec(string(buf[:]))
+ return err
+}
+
+// Read all files in path
+// They will be sorted by the migrator according to Version()
+func readFiles(uri string) (migrations []Migration, err error) {
+ u, err := url.Parse(uri)
+ if err != nil {
+ return nil, err
+ }
+
+ // Host might be `.`
+ p := u.Host + u.Path
+
+ if len(p) == 0 {
+ // Default to current directory
+ wd, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ p = wd
+ } else if p[0:1] == "." || p[0:1] != "/" {
+ // Ensure path is absolute
+ abs, err := filepath.Abs(p)
+ if err != nil {
+ return nil, err
+ }
+ p = abs
+ }
+
+ // Scan entire directory
+ files, err := ioutil.ReadDir(p)
+ if err != nil {
+ return nil, err
+ }
+
+ seen := make(map[int64]bool)
+
+ for _, fi := range files {
+ if !fi.IsDir() {
+ fm := fileMigration(path.Join(p, fi.Name()))
+
+ // Ignore invalid filenames
+ v := fm.Version()
+ if v == -1 {
+ fmt.Printf("invalid version %d for %s\n", v, fm)
+ continue
+ }
+ if seen[v] {
+ fmt.Printf("duplicate version %d for %s\n", v, fm)
+ continue
+ }
+ migrations = append(migrations, fm)
+ seen[v] = true
+ }
+ }
+ return migrations, nil
+}
diff --git a/migrate.go b/migrate.go
new file mode 100644
index 0000000..8d3e7ab
--- /dev/null
+++ b/migrate.go
@@ -0,0 +1,215 @@
+package migrate
+
+import (
+ "database/sql"
+ "fmt"
+ "sort"
+ "time"
+)
+
+const (
+ // NilVersion is a Claytons version
+ // "the version you are at when you are not at a version"
+ NilVersion = -1
+)
+
+// A Migrator collates and runs migrations
+type Migrator struct {
+ db *sql.DB
+ migrations []Migration
+ versionTable *string
+ stmts map[string]*sql.Stmt
+ prepared bool
+}
+
+// Migration interface
+type Migration interface {
+ // The version of this migration
+ Version() int64
+ // Run the migration
+ Run(*sql.Tx) error
+}
+
+// ResultFunc is the callback signature
+type ResultFunc func(int64, int64, error)
+
+// Sort those migrations
+type sorted []Migration
+
+func (s sorted) Len() int { return len(s) }
+func (s sorted) Less(i, j int) bool { return s[i].Version() < s[j].Version() }
+func (s sorted) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// Version reports the current version of the database
+func (m *Migrator) Version() (int64, error) {
+ err := m.prepareForMigration()
+ if err != nil {
+ return NilVersion, err
+ }
+
+ rows, err := m.stmts["getVersion"].Query()
+ if rows.Next() {
+ var version int64
+ err = rows.Scan(&version)
+ if err == sql.ErrNoRows {
+ return NilVersion, nil
+ }
+ if err == nil {
+ return version, nil
+ }
+ }
+ return 0, err
+}
+
+// Migrate migrates the database to the highest possible version
+func (m *Migrator) Migrate(cb ResultFunc) error {
+ err := m.prepareForMigration()
+ if err != nil {
+ return err
+ }
+
+ // Get the last available migration
+ v := m.migrations[len(m.migrations)-1].Version()
+ return m.MigrateTo(v, cb)
+}
+
+// MigrateTo migrates the database to the specified version
+func (m *Migrator) MigrateTo(toVersion int64, cb ResultFunc) error {
+ err := m.prepareForMigration()
+ if err != nil {
+ return err
+ }
+
+ maxVersion := m.migrations[len(m.migrations)-1].Version()
+
+ currVersion, err := m.Version()
+ if err != nil {
+ return err
+ }
+
+ if currVersion >= toVersion {
+ if cb != nil {
+ go cb(maxVersion, currVersion, nil)
+ }
+ return nil
+ }
+
+ for _, mig := range m.migrations {
+ nextVersion := mig.Version()
+
+ // Skip old migrations
+ if nextVersion <= currVersion {
+ continue
+ }
+
+ // Ensure contiguous
+ if currVersion != NilVersion && nextVersion != currVersion+1 {
+ return fmt.Errorf("non-contiguous migration: %v -> %v", currVersion, nextVersion)
+ }
+
+ if currVersion < nextVersion && nextVersion <= toVersion {
+ err = func() error {
+ fmt.Println("running migration", nextVersion)
+ // Start a transaction
+ tx, err := m.db.Begin()
+ if err != nil {
+ if cb != nil {
+ go cb(maxVersion, currVersion, err)
+ }
+ return err
+ }
+ defer tx.Rollback()
+
+ // Run the migration
+ if err = mig.Run(tx); err != nil {
+ if cb != nil {
+ go cb(maxVersion, currVersion, err)
+ }
+ return err
+ }
+ // Update the version entry
+ fmt.Println("updating version")
+ if err = m.setVersion(tx, nextVersion); err != nil {
+ if cb != nil {
+ go cb(maxVersion, currVersion, err)
+ }
+ return err
+ }
+ // Commit the transaction
+ fmt.Println("committing version")
+ return tx.Commit()
+ }()
+ if err != nil {
+ if cb != nil {
+ go cb(maxVersion, currVersion, err)
+ }
+ return err
+ }
+ if cb != nil {
+ go cb(maxVersion, currVersion, nil)
+ }
+ }
+ currVersion = nextVersion
+ }
+
+ return nil
+}
+
+func (m *Migrator) setVersion(tx *sql.Tx, version int64) (err error) {
+ if version >= 0 {
+ _, err = tx.Stmt(m.stmts["insertVersion"]).Exec(version, time.Now().Unix())
+ }
+ return err
+}
+
+func (m *Migrator) prepareForMigration() error {
+ if m.prepared {
+ return nil
+ }
+
+ if m.versionTable == nil {
+ vt := "current_schema_version"
+ m.versionTable = &vt
+ }
+
+ if _, err := m.db.Exec(fmt.Sprintf(createTableSQL, *m.versionTable)); err != nil {
+ return err
+ }
+
+ if err := m.prepareStmts(); err != nil {
+ return err
+ }
+
+ sort.Sort(sorted(m.migrations))
+
+ m.prepared = true
+ return nil
+}
+
+func (m *Migrator) prepareStmts() error {
+ m.stmts = make(map[string]*sql.Stmt)
+ s, err := m.db.Prepare(fmt.Sprintf(getVersionSQL, NilVersion, *m.versionTable))
+ if err != nil {
+ return err
+ }
+ m.stmts["getVersion"] = s
+
+ s, err = m.db.Prepare(fmt.Sprintf(insertVersionSQL, *m.versionTable))
+ if err != nil {
+ return err
+ }
+ m.stmts["insertVersion"] = s
+
+ return nil
+}
+
+const (
+ getVersionSQL = `select coalesce(max(version), %d) from %q`
+ insertVersionSQL = `insert into %q (version, applied) values ($1, $2)`
+
+ // Use Unix timestamp for time so it works for SQLite and PostgreSQL
+ createTableSQL = `create table if not exists %q (
+ version bigint not null primary key,
+ applied int
+ )`
+)
diff --git a/options.go b/options.go
new file mode 100644
index 0000000..61175dd
--- /dev/null
+++ b/options.go
@@ -0,0 +1,24 @@
+package migrate
+
+//import "context"
+
+// An Option configures a migrator
+type Option func(*Migrator) error
+
+// SetVersionTable configures the table used for recording the schema version
+func SetVersionTable(vt string) Option {
+ return func(m *Migrator) error {
+ m.versionTable = &vt
+ return nil
+ }
+}
+
+// SetContext configures the context for queries
+/*
+func SetContext(ctx context.Context) Option {
+ return func(m *Migrator) error {
+ m.ctx = ctx
+ return nil
+ }
+}
+*/