aboutsummaryrefslogtreecommitdiff
path: root/migrate_test.go
diff options
context:
space:
mode:
authorfelix <felix>2018-06-22 14:51:30 +0000
committerfelix <felix>2018-06-22 14:51:30 +0000
commit68b150b104c9f832a624f2440b20dfe329a8a551 (patch)
treec064530f258e470e623d1e090a883a8f7eb6efc9 /migrate_test.go
parent3d2f2f2b0a357462bbda428d0626ee4b606b4ff2 (diff)
downloadmigrate-68b150b104c9f832a624f2440b20dfe329a8a551.tar.gz
migrate-68b150b104c9f832a624f2440b20dfe329a8a551.tar.bz2
Add tests
Diffstat (limited to 'migrate_test.go')
-rw-r--r--migrate_test.go84
1 files changed, 84 insertions, 0 deletions
diff --git a/migrate_test.go b/migrate_test.go
new file mode 100644
index 0000000..c857014
--- /dev/null
+++ b/migrate_test.go
@@ -0,0 +1,84 @@
+package migrate
+
+import (
+ "database/sql"
+ "testing"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+const testDB = "file:test?mode=memory&cache=shared"
+
+type mig struct {
+ version int64
+ sql string
+}
+
+func createMigration(v int64, sql string) Migration {
+ return &mig{version: v, sql: sql}
+}
+
+func (m mig) Version() int64 { return m.version }
+
+func (m mig) Run(tx *sql.Tx) error {
+ _, err := tx.Exec(m.sql)
+ return err
+}
+
+var testMigrations = []struct {
+ version int64
+ sql string
+}{
+ // Out of order please
+ {version: 2, sql: "insert into test1 (pk) values (1)"},
+ {version: 3, sql: "insert into test1 (pk) values (2)"},
+ {version: 1, sql: "create table if not exists test1 (pk bigint not null primary key)"},
+}
+
+func TestMigrate(t *testing.T) {
+ // Load migrations
+ var migrations []Migration
+ for _, m := range testMigrations {
+ migrations = append(migrations, createMigration(m.version, m.sql))
+ }
+
+ db, err := sql.Open("sqlite3", testDB)
+ if err != nil {
+ t.Fatalf("DB setup failed: %v", err)
+ }
+ defer db.Close()
+ //db.SetMaxOpenConns(1)
+
+ migrator := Migrator{
+ db: db,
+ migrations: migrations,
+ }
+
+ v, err := migrator.Version()
+ if err != nil {
+ t.Fatalf("Migrator.Version() failed: %v", err)
+ }
+ if v != NilVersion {
+ t.Fatalf("Migrator.Version() should be NilVersion, got %d", v)
+ }
+
+ err = migrator.Migrate()
+ if err != nil {
+ t.Fatalf("Migrator.Migrate() failed: %v", err)
+ }
+
+ v, err = migrator.Version()
+ if err != nil {
+ t.Fatalf("Migrator.Version() failed: %v", err)
+ }
+
+ if int(v) != len(migrations) {
+ t.Errorf("expected migration version %d, got %d", len(migrations), v)
+ }
+
+ var result int64
+ err = db.QueryRow(`select pk from test1`).Scan(&result)
+ if err != nil {
+ t.Fatal(err)
+ }
+}