aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfelix <felix>2018-06-22 14:51:30 +0000
committerfelix <felix>2018-06-22 14:51:30 +0000
commit68b150b104c9f832a624f2440b20dfe329a8a551 (patch)
treec064530f258e470e623d1e090a883a8f7eb6efc9
parent3d2f2f2b0a357462bbda428d0626ee4b606b4ff2 (diff)
downloadmigrate-68b150b104c9f832a624f2440b20dfe329a8a551.tar.gz
migrate-68b150b104c9f832a624f2440b20dfe329a8a551.tar.bz2
Add tests
-rw-r--r--migrate.go87
-rw-r--r--migrate_test.go84
-rw-r--r--options.go8
3 files changed, 127 insertions, 52 deletions
diff --git a/migrate.go b/migrate.go
index 8d3e7ab..a4c493d 100644
--- a/migrate.go
+++ b/migrate.go
@@ -13,15 +13,6 @@ const (
NilVersion = -1
)
-// A Migrator collates and runs migrations
-type Migrator struct {
- db *sql.DB
- migrations []Migration
- versionTable *string
- stmts map[string]*sql.Stmt
- prepared bool
-}
-
// Migration interface
type Migration interface {
// The version of this migration
@@ -33,6 +24,16 @@ type Migration interface {
// ResultFunc is the callback signature
type ResultFunc func(int64, int64, error)
+// A Migrator collates and runs migrations
+type Migrator struct {
+ db *sql.DB
+ migrations []Migration
+ versionTable *string
+ stmts map[string]*sql.Stmt
+ prepared bool
+ callback ResultFunc
+}
+
// Sort those migrations
type sorted []Migration
@@ -47,22 +48,19 @@ func (m *Migrator) Version() (int64, error) {
return NilVersion, err
}
- rows, err := m.stmts["getVersion"].Query()
- if rows.Next() {
- var version int64
- err = rows.Scan(&version)
+ var version int64
+ err = m.stmts["getVersion"].QueryRow().Scan(&version)
+ if err != nil {
if err == sql.ErrNoRows {
return NilVersion, nil
}
- if err == nil {
- return version, nil
- }
+ return NilVersion, err
}
- return 0, err
+ return version, nil
}
// Migrate migrates the database to the highest possible version
-func (m *Migrator) Migrate(cb ResultFunc) error {
+func (m *Migrator) Migrate() error {
err := m.prepareForMigration()
if err != nil {
return err
@@ -70,11 +68,11 @@ func (m *Migrator) Migrate(cb ResultFunc) error {
// Get the last available migration
v := m.migrations[len(m.migrations)-1].Version()
- return m.MigrateTo(v, cb)
+ return m.MigrateTo(v)
}
// MigrateTo migrates the database to the specified version
-func (m *Migrator) MigrateTo(toVersion int64, cb ResultFunc) error {
+func (m *Migrator) MigrateTo(toVersion int64) error {
err := m.prepareForMigration()
if err != nil {
return err
@@ -84,13 +82,11 @@ func (m *Migrator) MigrateTo(toVersion int64, cb ResultFunc) error {
currVersion, err := m.Version()
if err != nil {
- return err
+ return fmt.Errorf("migration %d failed: %s", currVersion, err)
}
if currVersion >= toVersion {
- if cb != nil {
- go cb(maxVersion, currVersion, nil)
- }
+ go m.callback(maxVersion, currVersion, nil)
return nil
}
@@ -109,45 +105,33 @@ func (m *Migrator) MigrateTo(toVersion int64, cb ResultFunc) error {
if currVersion < nextVersion && nextVersion <= toVersion {
err = func() error {
- fmt.Println("running migration", nextVersion)
// Start a transaction
tx, err := m.db.Begin()
if err != nil {
- if cb != nil {
- go cb(maxVersion, currVersion, err)
- }
- return err
+ return fmt.Errorf("migration %d failed: %s", currVersion, err)
}
- defer tx.Rollback()
+ defer tx.Commit()
// Run the migration
if err = mig.Run(tx); err != nil {
- if cb != nil {
- go cb(maxVersion, currVersion, err)
- }
- return err
+ tx.Rollback()
+ return fmt.Errorf("migration %d failed: %s", currVersion, err)
}
// Update the version entry
- fmt.Println("updating version")
if err = m.setVersion(tx, nextVersion); err != nil {
- if cb != nil {
- go cb(maxVersion, currVersion, err)
- }
- return err
+ tx.Rollback()
+ return fmt.Errorf("migration %d failed: %s", currVersion, err)
}
- // Commit the transaction
- fmt.Println("committing version")
return tx.Commit()
}()
+
+ if m.callback != nil {
+ go m.callback(maxVersion, currVersion, err)
+ }
+
if err != nil {
- if cb != nil {
- go cb(maxVersion, currVersion, err)
- }
return err
}
- if cb != nil {
- go cb(maxVersion, currVersion, nil)
- }
}
currVersion = nextVersion
}
@@ -204,12 +188,11 @@ func (m *Migrator) prepareStmts() error {
}
const (
- getVersionSQL = `select coalesce(max(version), %d) from %q`
- insertVersionSQL = `insert into %q (version, applied) values ($1, $2)`
+ getVersionSQL = `select coalesce(max(version), %d) from %s`
+ insertVersionSQL = `insert into %s (version, applied) values ($1, $2)`
// Use Unix timestamp for time so it works for SQLite and PostgreSQL
- createTableSQL = `create table if not exists %q (
+ createTableSQL = `create table if not exists %s (
version bigint not null primary key,
- applied int
- )`
+ applied int)`
)
diff --git a/migrate_test.go b/migrate_test.go
new file mode 100644
index 0000000..c857014
--- /dev/null
+++ b/migrate_test.go
@@ -0,0 +1,84 @@
+package migrate
+
+import (
+ "database/sql"
+ "testing"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+const testDB = "file:test?mode=memory&cache=shared"
+
+type mig struct {
+ version int64
+ sql string
+}
+
+func createMigration(v int64, sql string) Migration {
+ return &mig{version: v, sql: sql}
+}
+
+func (m mig) Version() int64 { return m.version }
+
+func (m mig) Run(tx *sql.Tx) error {
+ _, err := tx.Exec(m.sql)
+ return err
+}
+
+var testMigrations = []struct {
+ version int64
+ sql string
+}{
+ // Out of order please
+ {version: 2, sql: "insert into test1 (pk) values (1)"},
+ {version: 3, sql: "insert into test1 (pk) values (2)"},
+ {version: 1, sql: "create table if not exists test1 (pk bigint not null primary key)"},
+}
+
+func TestMigrate(t *testing.T) {
+ // Load migrations
+ var migrations []Migration
+ for _, m := range testMigrations {
+ migrations = append(migrations, createMigration(m.version, m.sql))
+ }
+
+ db, err := sql.Open("sqlite3", testDB)
+ if err != nil {
+ t.Fatalf("DB setup failed: %v", err)
+ }
+ defer db.Close()
+ //db.SetMaxOpenConns(1)
+
+ migrator := Migrator{
+ db: db,
+ migrations: migrations,
+ }
+
+ v, err := migrator.Version()
+ if err != nil {
+ t.Fatalf("Migrator.Version() failed: %v", err)
+ }
+ if v != NilVersion {
+ t.Fatalf("Migrator.Version() should be NilVersion, got %d", v)
+ }
+
+ err = migrator.Migrate()
+ if err != nil {
+ t.Fatalf("Migrator.Migrate() failed: %v", err)
+ }
+
+ v, err = migrator.Version()
+ if err != nil {
+ t.Fatalf("Migrator.Version() failed: %v", err)
+ }
+
+ if int(v) != len(migrations) {
+ t.Errorf("expected migration version %d, got %d", len(migrations), v)
+ }
+
+ var result int64
+ err = db.QueryRow(`select pk from test1`).Scan(&result)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/options.go b/options.go
index 61175dd..7856336 100644
--- a/options.go
+++ b/options.go
@@ -13,6 +13,14 @@ func SetVersionTable(vt string) Option {
}
}
+// SetCallback configures the table used for recording the schema version
+func SetCallback(cb ResultFunc) Option {
+ return func(m *Migrator) error {
+ m.callback = cb
+ return nil
+ }
+}
+
// SetContext configures the context for queries
/*
func SetContext(ctx context.Context) Option {