diff options
| author | felix <felix> | 2018-06-22 14:51:30 +0000 |
|---|---|---|
| committer | felix <felix> | 2018-06-22 14:51:30 +0000 |
| commit | 68b150b104c9f832a624f2440b20dfe329a8a551 (patch) | |
| tree | c064530f258e470e623d1e090a883a8f7eb6efc9 /migrate_test.go | |
| parent | 3d2f2f2b0a357462bbda428d0626ee4b606b4ff2 (diff) | |
| download | migrate-68b150b104c9f832a624f2440b20dfe329a8a551.tar.gz migrate-68b150b104c9f832a624f2440b20dfe329a8a551.tar.bz2 | |
Add tests
Diffstat (limited to 'migrate_test.go')
| -rw-r--r-- | migrate_test.go | 84 |
1 files changed, 84 insertions, 0 deletions
diff --git a/migrate_test.go b/migrate_test.go new file mode 100644 index 0000000..c857014 --- /dev/null +++ b/migrate_test.go @@ -0,0 +1,84 @@ +package migrate + +import ( + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +const testDB = "file:test?mode=memory&cache=shared" + +type mig struct { + version int64 + sql string +} + +func createMigration(v int64, sql string) Migration { + return &mig{version: v, sql: sql} +} + +func (m mig) Version() int64 { return m.version } + +func (m mig) Run(tx *sql.Tx) error { + _, err := tx.Exec(m.sql) + return err +} + +var testMigrations = []struct { + version int64 + sql string +}{ + // Out of order please + {version: 2, sql: "insert into test1 (pk) values (1)"}, + {version: 3, sql: "insert into test1 (pk) values (2)"}, + {version: 1, sql: "create table if not exists test1 (pk bigint not null primary key)"}, +} + +func TestMigrate(t *testing.T) { + // Load migrations + var migrations []Migration + for _, m := range testMigrations { + migrations = append(migrations, createMigration(m.version, m.sql)) + } + + db, err := sql.Open("sqlite3", testDB) + if err != nil { + t.Fatalf("DB setup failed: %v", err) + } + defer db.Close() + //db.SetMaxOpenConns(1) + + migrator := Migrator{ + db: db, + migrations: migrations, + } + + v, err := migrator.Version() + if err != nil { + t.Fatalf("Migrator.Version() failed: %v", err) + } + if v != NilVersion { + t.Fatalf("Migrator.Version() should be NilVersion, got %d", v) + } + + err = migrator.Migrate() + if err != nil { + t.Fatalf("Migrator.Migrate() failed: %v", err) + } + + v, err = migrator.Version() + if err != nil { + t.Fatalf("Migrator.Version() failed: %v", err) + } + + if int(v) != len(migrations) { + t.Errorf("expected migration version %d, got %d", len(migrations), v) + } + + var result int64 + err = db.QueryRow(`select pk from test1`).Scan(&result) + if err != nil { + t.Fatal(err) + } +} |
