aboutsummaryrefslogtreecommitdiff
path: root/file.go
blob: 28673e0cc503d0bcec5fc9cb0d64b6e3cd43aafa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package migrate

import (
	"database/sql"
	"fmt"
	"io/ioutil"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"strconv"
)

var (
	// validFilename recognises migration file names shaped like
	// "<digits>_<name>.<ext>", for example "123_name.ext".
	// Submatch 1 holds the version digits, submatch 2 the name (which may
	// itself contain dots) and submatch 3 the text after the last dot.
	validFilename = regexp.MustCompile(`^([0-9]+)_(.*)\.(.*)$`)
)

// fileMigration is the path of a single migration file on disk.
type fileMigration string

// NewFileMigrator returns a Migrator whose migrations are the files found by
// readFiles under path. Each migration's Run receives a *sql.Tx, so every
// file is executed inside a transaction.
//
// The options are applied to the new Migrator in the order given. The first
// option that returns an error aborts construction, and that error is
// returned with a nil Migrator. Any error from reading the migration files is
// returned the same way.
func NewFileMigrator(db *sql.DB, path string, opts ...Option) (*Migrator, error) {
	found, err := readFiles(path)
	if err != nil {
		return nil, err
	}

	migrator := &Migrator{db: db, migrations: found}
	for _, apply := range opts {
		err = apply(migrator)
		if err != nil {
			return nil, err
		}
	}
	return migrator, nil
}

// Version returns the numeric prefix of the migration's file name, e.g. 123
// for ".../123_name.ext". It returns -1 when the base name does not match
// validFilename, or when the prefix does not fit in a signed 32-bit integer.
// It implements the Migration interface.
func (fm fileMigration) Version() int {
	// fm is a filesystem path. filepath.Base splits on the host OS separator
	// (both '/' and '\' on Windows); path.Base would only split on '/'.
	m := validFilename.FindStringSubmatch(filepath.Base(string(fm)))
	if len(m) != 4 {
		return -1
	}
	version, err := strconv.ParseInt(m[1], 10, 32)
	if err != nil {
		// Out of int32 range: treat it the same as an invalid name.
		return -1
	}
	return int(version)
}

// Run loads the entire migration file and passes its contents to tx.Exec as a
// single query string. File and Exec errors are returned unchanged.
// NOTE(review): whether one file may contain several statements depends on
// the SQL driver in use.
// It implements the Migration interface.
func (fm fileMigration) Run(tx *sql.Tx) error {
	script, err := ioutil.ReadFile(string(fm))
	if err != nil {
		return err
	}
	_, err = tx.Exec(string(script))
	return err
}

// readFiles resolves uri to a directory and returns one migration per
// non-directory entry whose name matches validFilename.
//
// uri is parsed with url.Parse, and its host and path are concatenated:
// "file://./migrations" yields "./migrations", while a plain path such as
// "migrations" or "/srv/migrations" passes through unchanged.
// NOTE(review): a path containing '?', '#' or '%' is interpreted as URL syntax.
// An empty result means the current working directory. Relative paths are
// made absolute with filepath.Abs.
//
// Entries with an invalid name, or whose version was already taken by an
// earlier entry, are skipped and a message is printed to stdout. Entries are
// visited in ioutil.ReadDir's filename order. That is not numeric version
// order ("10_b" sorts before "2_a"), so callers are expected to sort by
// Version().
//
// An error is returned if uri cannot be parsed or the directory cannot be
// read. An error is also returned if no valid migrations were found.
func readFiles(uri string) (migrations []Migration, err error) {
	u, err := url.Parse(uri)
	if err != nil {
		return nil, err
	}

	// For "file://./dir" the leading "." lands in Host, so both parts are needed.
	dir := u.Host + u.Path

	switch {
	case dir == "":
		// Default to the current working directory.
		if dir, err = os.Getwd(); err != nil {
			return nil, err
		}
	case !filepath.IsAbs(dir):
		// Relative paths, including "./..." forms, are resolved against the
		// working directory. IsAbs replaces a hard-coded '/' prefix check.
		if dir, err = filepath.Abs(dir); err != nil {
			return nil, err
		}
	}

	files, err := ioutil.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	seen := make(map[int]bool, len(files))
	for _, fi := range files {
		if fi.IsDir() {
			continue
		}
		// Joined with path.Join ('/' separator). fileMigration.Version extracts
		// the base name from this string, so keep the two in agreement.
		fm := fileMigration(path.Join(dir, fi.Name()))

		switch v := fm.Version(); {
		case v == -1:
			fmt.Printf("invalid version %d for %s\n", v, fm)
		case seen[v]:
			// The first file claiming a version wins; later ones are skipped.
			fmt.Printf("duplicate version %d for %s\n", v, fm)
		default:
			seen[v] = true
			migrations = append(migrations, fm)
		}
	}

	if len(migrations) == 0 {
		return nil, fmt.Errorf("no migrations found")
	}
	return migrations, nil
}