diff options
| author | felix <felix> | 2018-06-22 14:51:30 +0000 |
|---|---|---|
| committer | felix <felix> | 2018-06-22 14:51:30 +0000 |
| commit | 68b150b104c9f832a624f2440b20dfe329a8a551 (patch) | |
| tree | c064530f258e470e623d1e090a883a8f7eb6efc9 | |
| parent | 3d2f2f2b0a357462bbda428d0626ee4b606b4ff2 (diff) | |
| download | migrate-68b150b104c9f832a624f2440b20dfe329a8a551.tar.gz migrate-68b150b104c9f832a624f2440b20dfe329a8a551.tar.bz2 | |
Add tests
| -rw-r--r-- | migrate.go | 87 | ||||
| -rw-r--r-- | migrate_test.go | 84 | ||||
| -rw-r--r-- | options.go | 8 |
3 files changed, 127 insertions, 52 deletions
@@ -13,15 +13,6 @@ const ( NilVersion = -1 ) -// A Migrator collates and runs migrations -type Migrator struct { - db *sql.DB - migrations []Migration - versionTable *string - stmts map[string]*sql.Stmt - prepared bool -} - // Migration interface type Migration interface { // The version of this migration @@ -33,6 +24,16 @@ type Migration interface { // ResultFunc is the callback signature type ResultFunc func(int64, int64, error) +// A Migrator collates and runs migrations +type Migrator struct { + db *sql.DB + migrations []Migration + versionTable *string + stmts map[string]*sql.Stmt + prepared bool + callback ResultFunc +} + // Sort those migrations type sorted []Migration @@ -47,22 +48,19 @@ func (m *Migrator) Version() (int64, error) { return NilVersion, err } - rows, err := m.stmts["getVersion"].Query() - if rows.Next() { - var version int64 - err = rows.Scan(&version) + var version int64 + err = m.stmts["getVersion"].QueryRow().Scan(&version) + if err != nil { if err == sql.ErrNoRows { return NilVersion, nil } - if err == nil { - return version, nil - } + return NilVersion, err } - return 0, err + return version, nil } // Migrate migrates the database to the highest possible version -func (m *Migrator) Migrate(cb ResultFunc) error { +func (m *Migrator) Migrate() error { err := m.prepareForMigration() if err != nil { return err @@ -70,11 +68,11 @@ func (m *Migrator) Migrate(cb ResultFunc) error { // Get the last available migration v := m.migrations[len(m.migrations)-1].Version() - return m.MigrateTo(v, cb) + return m.MigrateTo(v) } // MigrateTo migrates the database to the specified version -func (m *Migrator) MigrateTo(toVersion int64, cb ResultFunc) error { +func (m *Migrator) MigrateTo(toVersion int64) error { err := m.prepareForMigration() if err != nil { return err @@ -84,13 +82,11 @@ func (m *Migrator) MigrateTo(toVersion int64, cb ResultFunc) error { currVersion, err := m.Version() if err != nil { - return err + return fmt.Errorf("migration %d failed: %s", currVersion, err) } if currVersion >= toVersion { - if cb != nil { - go cb(maxVersion, currVersion, nil) - } + go m.callback(maxVersion, currVersion, nil) return nil } @@ -109,45 +105,33 @@ func (m *Migrator) MigrateTo(toVersion int64, cb ResultFunc) error { if currVersion < nextVersion && nextVersion <= toVersion { err = func() error { - fmt.Println("running migration", nextVersion) // Start a transaction tx, err := m.db.Begin() if err != nil { - if cb != nil { - go cb(maxVersion, currVersion, err) - } - return err + return fmt.Errorf("migration %d failed: %s", currVersion, err) } - defer tx.Rollback() + defer tx.Commit() // Run the migration if err = mig.Run(tx); err != nil { - if cb != nil { - go cb(maxVersion, currVersion, err) - } - return err + tx.Rollback() + return fmt.Errorf("migration %d failed: %s", currVersion, err) } // Update the version entry - fmt.Println("updating version") if err = m.setVersion(tx, nextVersion); err != nil { - if cb != nil { - go cb(maxVersion, currVersion, err) - } - return err + tx.Rollback() + return fmt.Errorf("migration %d failed: %s", currVersion, err) } - // Commit the transaction - fmt.Println("committing version") return tx.Commit() }() + + if m.callback != nil { + go m.callback(maxVersion, currVersion, err) + } + if err != nil { - if cb != nil { - go cb(maxVersion, currVersion, err) - } return err } - if cb != nil { - go cb(maxVersion, currVersion, nil) - } } currVersion = nextVersion } @@ -204,12 +188,11 @@ func (m *Migrator) prepareStmts() error { } const ( - getVersionSQL = `select coalesce(max(version), %d) from %q` - insertVersionSQL = `insert into %q (version, applied) values ($1, $2)` + getVersionSQL = `select coalesce(max(version), %d) from %s` + insertVersionSQL = `insert into %s (version, applied) values ($1, $2)` // Use Unix timestamp for time so it works for SQLite and PostgreSQL - createTableSQL = `create table if not exists %q ( + createTableSQL = `create table if not exists %s ( version bigint not null primary key, - applied int - )` + applied int)` ) diff --git a/migrate_test.go b/migrate_test.go new file mode 100644 index 0000000..c857014 --- /dev/null +++ b/migrate_test.go @@ -0,0 +1,84 @@ +package migrate + +import ( + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +const testDB = "file:test?mode=memory&cache=shared" + +type mig struct { + version int64 + sql string +} + +func createMigration(v int64, sql string) Migration { + return &mig{version: v, sql: sql} +} + +func (m mig) Version() int64 { return m.version } + +func (m mig) Run(tx *sql.Tx) error { + _, err := tx.Exec(m.sql) + return err +} + +var testMigrations = []struct { + version int64 + sql string +}{ + // Out of order please + {version: 2, sql: "insert into test1 (pk) values (1)"}, + {version: 3, sql: "insert into test1 (pk) values (2)"}, + {version: 1, sql: "create table if not exists test1 (pk bigint not null primary key)"}, +} + +func TestMigrate(t *testing.T) { + // Load migrations + var migrations []Migration + for _, m := range testMigrations { + migrations = append(migrations, createMigration(m.version, m.sql)) + } + + db, err := sql.Open("sqlite3", testDB) + if err != nil { + t.Fatalf("DB setup failed: %v", err) + } + defer db.Close() + //db.SetMaxOpenConns(1) + + migrator := Migrator{ + db: db, + migrations: migrations, + } + + v, err := migrator.Version() + if err != nil { + t.Fatalf("Migrator.Version() failed: %v", err) + } + if v != NilVersion { + t.Fatalf("Migrator.Version() should be NilVersion, got %d", v) + } + + err = migrator.Migrate() + if err != nil { + t.Fatalf("Migrator.Migrate() failed: %v", err) + } + + v, err = migrator.Version() + if err != nil { + t.Fatalf("Migrator.Version() failed: %v", err) + } + + if int(v) != len(migrations) { + t.Errorf("expected migration version %d, got %d", len(migrations), v) + } + + var result int64 + err = db.QueryRow(`select pk from test1`).Scan(&result) + if err != nil { + t.Fatal(err) + } +} @@ -13,6 +13,14 @@ func SetVersionTable(vt string) Option { } } +// SetCallback configures the table used for recording the schema version +func SetCallback(cb ResultFunc) Option { + return func(m *Migrator) error { + m.callback = cb + return nil + } +} + // SetContext configures the context for queries /* func SetContext(ctx context.Context) Option { |
