1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
package migrate
import (
"database/sql"
"fmt"
"io/ioutil"
"net/url"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
)
// validFilename matches migration file names of the form
// <digits>_<name>.<ext>, for example 123_name.ext.
//
// Capture groups: 1 is the leading run of digits (parsed as the version by
// fileMigration.Version), 2 is the name and 3 is the extension. Group 2 is
// greedy, so the split between name and extension happens at the last dot.
var validFilename = regexp.MustCompile(`^([0-9]+)_(.*)\.(.*)$`)
// fileMigration is a migration stored in a file; the value itself is just
// the path to that file. The version is derived from the file name and the
// file contents are what gets executed.
type fileMigration string
// NewFileMigrator builds a Migrator whose migrations are the files that
// readFiles discovers under path.
//
// The options in opts are applied to the new Migrator in the order given;
// the first option that returns an error aborts construction and that error
// is returned. An error from reading the migration files is returned as is.
// Each file migration is executed against a *sql.Tx (see fileMigration.Run).
func NewFileMigrator(db *sql.DB, path string, opts ...Option) (*Migrator, error) {
	found, err := readFiles(path)
	if err != nil {
		return nil, err
	}

	migrator := &Migrator{db: db, migrations: found}
	for i := range opts {
		if optErr := opts[i](migrator); optErr != nil {
			return nil, optErr
		}
	}
	return migrator, nil
}
// Version returns the migration version encoded in the file name.
// It implements the Migration interface.
//
// The base name of the path must match validFilename (<digits>_<name>.<ext>);
// the leading digits are parsed as a base-10 number that has to fit in 32
// bits. If the name does not match, or the number is out of range, -1 is
// returned.
func (fm fileMigration) Version() int {
	// filepath.Base understands the OS-native separator, so a path written
	// with backslashes on Windows still reduces to the bare file name;
	// path.Base would only split on '/'.
	parts := validFilename.FindStringSubmatch(filepath.Base(string(fm)))
	if len(parts) != 4 {
		// No match: FindStringSubmatch returned nil.
		return -1
	}

	version, err := strconv.ParseInt(parts[1], 10, 32)
	if err != nil {
		return -1
	}
	return int(version)
}
// Run loads the entire migration file into memory and executes its contents
// as a single statement string on tx.
// It implements the Migration interface.
//
// Any error from reading the file or from tx.Exec is returned unchanged.
func (fm fileMigration) Run(tx *sql.Tx) error {
	contents, err := ioutil.ReadFile(string(fm))
	if err != nil {
		return err
	}

	_, err = tx.Exec(string(contents))
	return err
}
// readFiles lists the directory named by uri and returns one fileMigration
// per regular entry whose name carries a valid, not yet seen version.
//
// uri is parsed with url.Parse; the directory is the URL host joined with
// the URL path. An empty result means the current working directory, and
// anything that does not start with `/` is made absolute. Sub-directories
// are skipped. Files whose Version() is -1, and files repeating a version
// that was already accepted, are reported on standard output and skipped.
//
// The slice is returned in directory-listing order; sorting by Version() is
// expected to be done by the Migrator. An error is returned when the uri
// cannot be parsed, the directory cannot be resolved or read, or no usable
// migration file is found.
func readFiles(uri string) (migrations []Migration, err error) {
	parsed, err := url.Parse(uri)
	if err != nil {
		return nil, err
	}

	// The host may be `.`, so it is glued back in front of the path.
	dir := parsed.Host + parsed.Path
	switch {
	case dir == "":
		// Nothing was given: fall back to the current directory.
		if dir, err = os.Getwd(); err != nil {
			return nil, err
		}
	case dir[0] != '/':
		// Not rooted at `/` (this includes a leading `.`): make it absolute.
		if dir, err = filepath.Abs(dir); err != nil {
			return nil, err
		}
	}

	// Walk every entry of the directory (no recursion).
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	taken := make(map[int]bool)
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}

		candidate := fileMigration(path.Join(dir, entry.Name()))
		version := candidate.Version()
		switch {
		case version == -1:
			// The file name does not follow the expected pattern.
			fmt.Printf("invalid version %d for %s\n", version, candidate)
		case taken[version]:
			// The first file seen with a given version wins.
			fmt.Printf("duplicate version %d for %s\n", version, candidate)
		default:
			migrations = append(migrations, candidate)
			taken[version] = true
		}
	}

	if len(migrations) == 0 {
		return nil, fmt.Errorf("no migrations found")
	}
	return migrations, nil
}
|