mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
38b8c9cc58
#15555 Probably the best way to review this is commit by commit: - First commit does the actual moving. - Second commit fixes golangci-lint issues (in the least-effort way, to avoid refactoring or rearranging some of the code). - Third commit moves a printf to before the migration step is executed. In the past, some customers hit migration issues (like migration steps hanging or taking a long time to execute) and wanted to know which step it was. The only way to find out was to look at the repository for the next migration after the last one logged. Checks: - [X] Manual QA for all new/changed functionality Manual tests: - `make fleet && make db-reset`. - Adding a new migration via `make migration name=Foobar` and then running `./build/fleet prepare db`. - Enrolling a new device to Fleet (smoke test).
237 lines
5.7 KiB
Go
237 lines
5.7 KiB
Go
package goose
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"path/filepath"
|
|
"runtime"
|
|
"sort"
|
|
)
|
|
|
|
var (
|
|
ErrNoCurrentVersion = errors.New("no current version found")
|
|
ErrNoNextVersion = errors.New("no next version found")
|
|
|
|
globalGoose = &Client{
|
|
TableName: "goose_db_version",
|
|
Dialect: &PostgresDialect{},
|
|
}
|
|
)
|
|
|
|
type Migrations []*Migration
|
|
|
|
// helpers so we can use pkg sort
|
|
func (ms Migrations) Len() int { return len(ms) }
|
|
func (ms Migrations) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
|
|
func (ms Migrations) Less(i, j int) bool {
|
|
if ms[i].Version == ms[j].Version {
|
|
log.Fatalf("goose: duplicate version %v detected:\n%v\n%v", ms[i].Version, ms[i].Source, ms[j].Source)
|
|
}
|
|
return ms[i].Version < ms[j].Version
|
|
}
|
|
|
|
func (ms Migrations) Current(current int64) (*Migration, error) {
|
|
for i, migration := range ms {
|
|
if migration.Version == current {
|
|
return ms[i], nil
|
|
}
|
|
}
|
|
|
|
return nil, ErrNoCurrentVersion
|
|
}
|
|
|
|
func (ms Migrations) Next(current int64) (*Migration, error) {
|
|
for i, migration := range ms {
|
|
if migration.Version > current {
|
|
return ms[i], nil
|
|
}
|
|
}
|
|
|
|
return nil, ErrNoNextVersion
|
|
}
|
|
|
|
func (ms Migrations) Last() (*Migration, error) {
|
|
if len(ms) == 0 {
|
|
return nil, ErrNoNextVersion
|
|
}
|
|
|
|
return ms[len(ms)-1], nil
|
|
}
|
|
|
|
func (ms Migrations) String() string {
|
|
str := ""
|
|
for _, m := range ms {
|
|
str += fmt.Sprintln(m)
|
|
}
|
|
return str
|
|
}
|
|
|
|
// AddMigration adds a new go migration to the goose struct.
|
|
func (c *Client) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
|
|
_, filename, _, _ := runtime.Caller(1)
|
|
v, _ := NumericComponent(filename)
|
|
migration := &Migration{Version: v, Next: -1, Previous: -1, UpFn: up, DownFn: down, Source: filename}
|
|
|
|
c.Migrations = append(c.Migrations, migration)
|
|
}
|
|
|
|
// AddMigration exists for legacy support of the package global use of Goose.
|
|
// Calling the AddMigration method on a Goose struct is the preferred method.
|
|
func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
|
|
// We can't just use globalGoose.AddMigration here because we need to
|
|
// correctly record the caller.
|
|
_, filename, _, _ := runtime.Caller(1)
|
|
v, _ := NumericComponent(filename)
|
|
migration := &Migration{Version: v, Next: -1, Previous: -1, UpFn: up, DownFn: down, Source: filename}
|
|
|
|
globalGoose.Migrations = append(globalGoose.Migrations, migration)
|
|
}
|
|
|
|
// collect all the valid looking migration scripts in the
|
|
// migrations folder and go func registry, and key them by version
|
|
func (c *Client) collectMigrations(dirpath string, current, target int64) (Migrations, error) {
|
|
var migrations Migrations
|
|
|
|
// only load migrations from file if a path is explicitly provided
|
|
if dirpath != "" {
|
|
// extract the numeric component of each migration,
|
|
// filter out any uninteresting files,
|
|
// and ensure we only have one file per migration version.
|
|
sqlMigrations, err := filepath.Glob(dirpath + "/*.sql")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, file := range sqlMigrations {
|
|
v, err := NumericComponent(file)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if versionFilter(v, current, target) {
|
|
migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file}
|
|
migrations = append(migrations, migration)
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, migration := range c.Migrations {
|
|
v, err := NumericComponent(migration.Source)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if versionFilter(v, current, target) {
|
|
migrations = append(migrations, migration)
|
|
}
|
|
}
|
|
|
|
migrations = sortAndConnectMigrations(migrations)
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
func sortAndConnectMigrations(migrations Migrations) Migrations {
|
|
sort.Sort(migrations)
|
|
|
|
// now that we're sorted in the appropriate direction,
|
|
// populate next and previous for each migration
|
|
for i, m := range migrations {
|
|
prev := int64(-1)
|
|
if i > 0 {
|
|
prev = migrations[i-1].Version
|
|
migrations[i-1].Next = m.Version
|
|
}
|
|
migrations[i].Previous = prev
|
|
}
|
|
|
|
return migrations
|
|
}
|
|
|
|
// versionFilter reports whether version v lies in the half-open range that a
// migration from current to target passes through:
//   - migrating up (target > current): current < v <= target
//   - migrating down (target < current): target < v <= current
//
// When target equals current, no version qualifies.
func versionFilter(v, current, target int64) bool {
	switch {
	case current < target:
		return current < v && v <= target
	case current > target:
		return target < v && v <= current
	default:
		return false
	}
}
|
|
|
|
// retrieve the current version for this DB.
|
|
// Create and initialize the DB version table if it doesn't exist.
|
|
func (c *Client) GetDBVersion(db *sql.DB) (int64, error) {
|
|
rows, err := c.Dialect.dbVersionQuery(db, c.TableName)
|
|
if err != nil {
|
|
return 0, c.createVersionTable(db)
|
|
}
|
|
defer rows.Close()
|
|
|
|
// The most recent record for each migration specifies
|
|
// whether it has been applied or rolled back.
|
|
// The first version we find that has been applied is the current version.
|
|
|
|
toSkip := make([]int64, 0)
|
|
|
|
for rows.Next() {
|
|
var row MigrationRecord
|
|
if err = rows.Scan(&row.VersionId, &row.IsApplied); err != nil {
|
|
log.Fatal("error scanning rows:", err)
|
|
}
|
|
|
|
// have we already marked this version to be skipped?
|
|
skip := false
|
|
for _, v := range toSkip {
|
|
if v == row.VersionId {
|
|
skip = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if skip {
|
|
continue
|
|
}
|
|
|
|
// if version has been applied we're done
|
|
if row.IsApplied {
|
|
return row.VersionId, nil
|
|
}
|
|
|
|
// latest version of migration has not been applied.
|
|
toSkip = append(toSkip, row.VersionId)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
panic("unreachable")
|
|
}
|
|
|
|
// Create the goose_db_version table
|
|
// and insert the initial 0 value into it
|
|
func (c *Client) createVersionTable(db *sql.DB) error {
|
|
txn, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
d := c.Dialect
|
|
|
|
if _, err := txn.Exec(d.createVersionTableSql(c.TableName)); err != nil {
|
|
txn.Rollback() //nolint:errcheck
|
|
return err
|
|
}
|
|
|
|
version := 0
|
|
applied := true
|
|
if _, err := txn.Exec(d.insertVersionSql(c.TableName), version, applied); err != nil {
|
|
txn.Rollback() //nolint:errcheck
|
|
return err
|
|
}
|
|
|
|
return txn.Commit()
|
|
}
|