diff --git a/Makefile b/Makefile index 366dcf062..02a5d921b 100644 --- a/Makefile +++ b/Makefile @@ -199,7 +199,7 @@ update-go-cloner: go run ./tools/cloner-check/main.go --update migration: - go run github.com/fleetdm/goose/cmd/goose -dir server/datastore/mysql/migrations/tables create $(name) + go run ./server/goose/cmd/goose -dir server/datastore/mysql/migrations/tables create $(name) gofmt -w server/datastore/mysql/migrations/tables/*_$(name)*.go clean: clean-assets diff --git a/go.mod b/go.mod index 2ea831472..dd036cddd 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,6 @@ require ( github.com/elazarl/go-bindata-assetfs v1.0.1 github.com/facebookincubator/nvdtools v0.1.6-0.20231010102659-d14ce526f176 github.com/fatih/color v1.15.0 - github.com/fleetdm/goose v0.0.0-20221011170007-06aacf8ac547 github.com/getsentry/sentry-go v0.18.0 github.com/ghodss/yaml v1.0.0 github.com/go-ini/ini v1.67.0 @@ -59,6 +58,7 @@ require ( github.com/kevinburke/go-bindata v3.24.0+incompatible github.com/kolide/kit v0.0.0-20221107170827-fb85e3d59eab github.com/kolide/launcher v1.0.12 + github.com/lib/pq v1.10.9 github.com/macadmins/osquery-extension v0.0.15 github.com/mattermost/xml-roundtrip-validator v0.0.0-20201213122252-bcd7e1b9601e github.com/mattn/go-sqlite3 v1.14.13 @@ -99,6 +99,7 @@ require ( github.com/ulikunitz/xz v0.5.10 github.com/urfave/cli/v2 v2.23.5 github.com/valyala/fasthttp v1.40.0 + github.com/ziutek/mymysql v1.5.4 go.elastic.co/apm/module/apmgorilla/v2 v2.3.0 go.elastic.co/apm/module/apmsql/v2 v2.4.3 go.elastic.co/apm/v2 v2.4.3 diff --git a/go.sum b/go.sum index 8da3240df..0d8e270c1 100644 --- a/go.sum +++ b/go.sum @@ -445,8 +445,6 @@ github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/fleetdm/goose v0.0.0-20221011170007-06aacf8ac547 h1:3Vlgx6mJYFlj3GPB3CgoQrR7URgE0GQGnKYNfoXxuUo= -github.com/fleetdm/goose v0.0.0-20221011170007-06aacf8ac547/go.mod h1:d7Q+0eCENnKQUhkfAUVLfGnD4QcgJMF/uB9WRTN9TDI= github.com/fleetdm/kolide-kit v0.0.0-20230519160117-86cc9441f9c1 h1:9JGbRO6QKpHr5HO5t6g3/5EspV5eDWaLKbRH7xKqL/c= github.com/fleetdm/kolide-kit v0.0.0-20230519160117-86cc9441f9c1/go.mod h1:HHtqF91JHl66L+Ms8aswzqVb2eEU5O3DRNiFmUzOf60= github.com/fleetdm/nanodep v0.1.1-0.20221221202251-71b67ab1da24 h1:XhczaxKV3J4NjztroidSnYKyq5xtxF+amBYdBWeik58= @@ -525,7 +523,6 @@ github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+ github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-redis/redis v6.15.8+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= @@ -856,7 +853,6 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.9.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= @@ -900,7 +896,6 @@ github.com/mattn/go-runewidth v0.0.6/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.13 h1:1tj15ngiFfcZzii7yd82foL+ks+ouQcj8j/TPq3fk1I= github.com/mattn/go-sqlite3 v1.14.13/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= @@ -1233,6 +1228,7 @@ github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPR github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zclconf/go-cty v1.1.0/go.mod h1:xnAOWiHeOqg2nWS62VtQ7pbOu17FtxJNW8RLEih+O3s= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs= github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= go.elastic.co/apm/module/apmgorilla/v2 v2.3.0 h1:jHw8N252UTwKTk945+Am8AaawhHC6DWpFVeTXQO8Gko= go.elastic.co/apm/module/apmgorilla/v2 v2.3.0/go.mod h1:2LXDBbVhFf9rF65jZecvl78IZMuvSRldQ+9A/fjfIo0= diff --git a/server/datastore/mysql/migrations/data/migration.go b/server/datastore/mysql/migrations/data/migration.go index 743b06230..6185cc832 100644 --- a/server/datastore/mysql/migrations/data/migration.go +++ b/server/datastore/mysql/migrations/data/migration.go @@ -1,7 +1,5 @@ package data -import "github.com/fleetdm/goose" +import "github.com/fleetdm/fleet/v4/server/goose" -var ( - MigrationClient = goose.New("migration_status_data", goose.MySqlDialect{}) -) +var MigrationClient = goose.New("migration_status_data", goose.MySqlDialect{}) diff --git a/server/datastore/mysql/migrations/tables/migration.go b/server/datastore/mysql/migrations/tables/migration.go index 4359791a9..a7872b5f7 100644 --- a/server/datastore/mysql/migrations/tables/migration.go +++ b/server/datastore/mysql/migrations/tables/migration.go @@ -5,7 +5,7 @@ import ( "encoding/json" "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/goose" + "github.com/fleetdm/fleet/v4/server/goose" "github.com/jmoiron/sqlx" "github.com/pkg/errors" ) @@ -59,7 +59,6 @@ WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ? `, table, index).Scan(&count) - if err != nil { return false } diff --git a/server/datastore/mysql/mysql.go b/server/datastore/mysql/mysql.go index 78975dbf4..2fbd34f03 100644 --- a/server/datastore/mysql/mysql.go +++ b/server/datastore/mysql/mysql.go @@ -27,7 +27,7 @@ import ( "github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/data" "github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/tables" "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/goose" + "github.com/fleetdm/fleet/v4/server/goose" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "github.com/go-sql-driver/mysql" diff --git a/server/goose/LICENSE b/server/goose/LICENSE new file mode 100644 index 000000000..00c71c7ce --- /dev/null +++ b/server/goose/LICENSE @@ -0,0 +1,23 @@ +Modified work Copyright (c) 2021-present Fleet Device Management Inc +Modified work Copyright (c) 2016 Vojtech Vitek +Original work Copyright (c) 2012 Liam Staskawicz + + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/server/goose/README.md b/server/goose/README.md new file mode 100644 index 000000000..8f977aeba --- /dev/null +++ b/server/goose/README.md @@ -0,0 +1,7 @@ +# goose + +Goose is a database migration tool. Manage your database's evolution by creating incremental SQL files or Go functions. + +## Contents + +The contents in this directory are a copy (made in December 2023) of https://github.com/fleetdm/goose which is a fork of https://github.com/pressly/goose with some customizations for working with [Fleet](https://github.com/fleetdm/fleet). \ No newline at end of file diff --git a/server/goose/cmd/goose/main.go b/server/goose/cmd/goose/main.go new file mode 100644 index 000000000..a9c0993ed --- /dev/null +++ b/server/goose/cmd/goose/main.go @@ -0,0 +1,107 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "log" + "os" + + "github.com/fleetdm/fleet/v4/server/goose" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + _ "github.com/ziutek/mymysql/godrv" +) + +var ( + flags = flag.NewFlagSet("goose", flag.ExitOnError) + dir = flags.String("dir", ".", "directory with migration files") +) + +func main() { + flags.Usage = usage + if err := flags.Parse(os.Args[1:]); err != nil { + log.Fatalf("flags parse: %s", err) + } + + args := flags.Args() + + if len(args) > 1 && args[0] == "create" { + if err := goose.Run("create", nil, *dir, args[1:]...); err != nil { + log.Fatalf("goose run: %v", err) + } + return + } + + if len(args) < 3 { + flags.Usage() + return + } + + if args[0] == "-h" || args[0] == "--help" { + flags.Usage() + return + } + + driver, dbstring, command := args[0], args[1], args[2] + + switch driver { + case "postgres", "mysql", "sqlite3": + if err := goose.SetDialect(driver); err != nil { + log.Fatal(err) + } + default: + log.Fatalf("%q driver not supported\n", driver) + } + + switch dbstring { + case "": + log.Fatalf("-dbstring=%q not supported\n", dbstring) + default: + } + + db, err := sql.Open(driver, dbstring) + if err != nil { + log.Fatalf("-dbstring=%q: %v\n", dbstring, err) + } + + arguments := []string{} + if len(args) > 3 { + arguments = append(arguments, args[3:]...) + } + + if err := goose.Run(command, db, *dir, arguments...); err != nil { + log.Fatalf("goose run: %v", err) + } +} + +func usage() { + fmt.Print(usagePrefix) + flags.PrintDefaults() + fmt.Print(usageCommands) +} + +var ( + usagePrefix = `Usage: goose [OPTIONS] DRIVER DBSTRING COMMAND + +Examples: + goose postgres "user=postgres dbname=postgres sslmode=disable" up + goose mysql "user:password@/dbname" down + goose sqlite3 ./foo.db status + goose postgres "user=postgres dbname=postgres sslmode=disable" create init sql + +Options: +` + + usageCommands = ` +Commands: + up Migrate the DB to the most recent version available + down Roll back the version by 1 + redo Re-run the latest migration + status Dump the migration status for the current DB + version Print the current version of the database + create Creates a blank migration template +` +) diff --git a/server/goose/create.go b/server/goose/create.go new file mode 100644 index 000000000..5dd943044 --- /dev/null +++ b/server/goose/create.go @@ -0,0 +1,18 @@ +package goose + +import ( + "database/sql" + "fmt" + "time" +) + +// Create writes a new blank migration file. +func Create(db *sql.DB, dir, name, migrationType string) error { + paths, err := CreateMigration(name, migrationType, dir, time.Now()) + if err != nil { + return err + } + fmt.Printf("Created %s migration files at %v\n", migrationType, paths) + + return nil +} diff --git a/server/goose/dialect.go b/server/goose/dialect.go new file mode 100644 index 000000000..bfa5f879c --- /dev/null +++ b/server/goose/dialect.go @@ -0,0 +1,126 @@ +package goose + +import ( + "database/sql" + "fmt" +) + +// SqlDialect abstracts the details of specific SQL dialects +// for goose's few SQL specific statements +type SqlDialect interface { + createVersionTableSql(name string) string // sql string to create the goose_db_version table + insertVersionSql(name string) string // sql string to insert the initial version table row + dbVersionQuery(db *sql.DB, name string) (*sql.Rows, error) +} + +func GetDialect() SqlDialect { + return globalGoose.Dialect +} + +func (c *Client) SetDialect(d string) error { + switch d { + case "postgres": + c.Dialect = &PostgresDialect{} + case "mysql": + c.Dialect = &MySqlDialect{} + case "sqlite3": + c.Dialect = &Sqlite3Dialect{} + default: + return fmt.Errorf("%q: unknown dialect", d) + } + + return nil +} + +func SetDialect(d string) error { + return globalGoose.SetDialect(d) +} + +//////////////////////////// +// Postgres +//////////////////////////// + +type PostgresDialect struct{} + +func (pg PostgresDialect) createVersionTableSql(name string) string { + return `CREATE TABLE ` + name + ` ( + id serial NOT NULL, + version_id bigint NOT NULL, + is_applied boolean NOT NULL, + tstamp timestamp NULL default now(), + PRIMARY KEY(id) + );` +} + +func (pg PostgresDialect) insertVersionSql(name string) string { + return "INSERT INTO " + name + " (version_id, is_applied) VALUES ($1, $2);" +} + +func (pg PostgresDialect) dbVersionQuery(db *sql.DB, name string) (*sql.Rows, error) { + /* #nosec G202 -- name is actually well defined */ + rows, err := db.Query("SELECT version_id, is_applied from " + name + " ORDER BY id DESC") + if err != nil { + return nil, err + } + + return rows, err +} + +//////////////////////////// +// MySQL +//////////////////////////// + +type MySqlDialect struct{} + +func (m MySqlDialect) createVersionTableSql(name string) string { + return `CREATE TABLE ` + name + ` ( + id serial NOT NULL, + version_id bigint NOT NULL, + is_applied boolean NOT NULL, + tstamp timestamp NULL default now(), + PRIMARY KEY(id) + );` +} + +func (m MySqlDialect) insertVersionSql(name string) string { + return "INSERT INTO " + name + " (version_id, is_applied) VALUES (?, ?);" +} + +func (m MySqlDialect) dbVersionQuery(db *sql.DB, name string) (*sql.Rows, error) { + /* #nosec G202 -- name is actually well defined */ + rows, err := db.Query("SELECT version_id, is_applied from " + name + " ORDER BY id DESC") + if err != nil { + return nil, err + } + + return rows, err +} + +//////////////////////////// +// sqlite3 +//////////////////////////// + +type Sqlite3Dialect struct{} + +func (m Sqlite3Dialect) createVersionTableSql(name string) string { + return `CREATE TABLE ` + name + ` ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + version_id INTEGER NOT NULL, + is_applied INTEGER NOT NULL, + tstamp TIMESTAMP DEFAULT (datetime('now')) + );` +} + +func (m Sqlite3Dialect) insertVersionSql(name string) string { + return "INSERT INTO " + name + " (version_id, is_applied) VALUES (?, ?);" +} + +func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB, name string) (*sql.Rows, error) { + /* #nosec G202 -- name is actually well defined */ + rows, err := db.Query("SELECT version_id, is_applied from " + name + " ORDER BY id DESC") + if err != nil { + return nil, err + } + + return rows, err +} diff --git a/server/goose/down.go b/server/goose/down.go new file mode 100644 index 000000000..b23a1c1ec --- /dev/null +++ b/server/goose/down.go @@ -0,0 +1,25 @@ +package goose + +import ( + "database/sql" + "fmt" +) + +func (c *Client) Down(db *sql.DB, dir string) error { + currentVersion, err := c.GetDBVersion(db) + if err != nil { + return err + } + + migrations, err := c.collectMigrations(dir, minVersion, maxVersion) + if err != nil { + return err + } + + current, err := migrations.Current(currentVersion) + if err != nil { + return fmt.Errorf("no migration %v", currentVersion) + } + + return c.runMigration(db, current, migrateDown) +} diff --git a/server/goose/goose.go b/server/goose/goose.go new file mode 100644 index 000000000..e039bed8a --- /dev/null +++ b/server/goose/goose.go @@ -0,0 +1,76 @@ +package goose + +import ( + "database/sql" + "errors" + "fmt" +) + +var ( + minVersion = int64(0) + maxVersion = int64((1 << 63) - 1) +) + +// Client stores the migration state and preferences. Prefer interacting with +// the Goose API through a Client struct created with New rather than using the +// global Client and functions. +type Client struct { + // TableName is the name of the table used to store migration status + // for this client. + TableName string + // Dialect is the SqlDialect to use. + Dialect SqlDialect + // Migrations is the list of migrations. + Migrations Migrations +} + +func New(tableName string, dialect SqlDialect) *Client { + return &Client{ + TableName: tableName, + Dialect: dialect, + } +} + +func Run(command string, db *sql.DB, dir string, args ...string) error { + switch command { + case "up": + if err := globalGoose.Up(db, dir); err != nil { + return err + } + case "up-by-one": + if err := globalGoose.UpByOne(db, dir); err != nil { + return err + } + case "create": + if len(args) == 0 { + return errors.New("create must be of form: goose [OPTIONS] DRIVER DBSTRING create NAME [go|sql]") + } + + migrationType := "go" + if len(args) == 2 { + migrationType = args[1] + } + if err := Create(db, dir, args[0], migrationType); err != nil { + return err + } + case "down": + if err := globalGoose.Down(db, dir); err != nil { + return err + } + case "redo": + if err := globalGoose.Redo(db, dir); err != nil { + return err + } + case "status": + if err := globalGoose.Status(db, dir); err != nil { + return err + } + case "version": + if err := globalGoose.Version(db, dir); err != nil { + return err + } + default: + return fmt.Errorf("%q: no such command", command) + } + return nil +} diff --git a/server/goose/migrate.go b/server/goose/migrate.go new file mode 100644 index 000000000..bf5ca921b --- /dev/null +++ b/server/goose/migrate.go @@ -0,0 +1,236 @@ +package goose + +import ( + "database/sql" + "errors" + "fmt" + "log" + "path/filepath" + "runtime" + "sort" +) + +var ( + ErrNoCurrentVersion = errors.New("no current version found") + ErrNoNextVersion = errors.New("no next version found") + + globalGoose = &Client{ + TableName: "goose_db_version", + Dialect: &PostgresDialect{}, + } +) + +type Migrations []*Migration + +// helpers so we can use pkg sort +func (ms Migrations) Len() int { return len(ms) } +func (ms Migrations) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] } +func (ms Migrations) Less(i, j int) bool { + if ms[i].Version == ms[j].Version { + log.Fatalf("goose: duplicate version %v detected:\n%v\n%v", ms[i].Version, ms[i].Source, ms[j].Source) + } + return ms[i].Version < ms[j].Version +} + +func (ms Migrations) Current(current int64) (*Migration, error) { + for i, migration := range ms { + if migration.Version == current { + return ms[i], nil + } + } + + return nil, ErrNoCurrentVersion +} + +func (ms Migrations) Next(current int64) (*Migration, error) { + for i, migration := range ms { + if migration.Version > current { + return ms[i], nil + } + } + + return nil, ErrNoNextVersion +} + +func (ms Migrations) Last() (*Migration, error) { + if len(ms) == 0 { + return nil, ErrNoNextVersion + } + + return ms[len(ms)-1], nil +} + +func (ms Migrations) String() string { + str := "" + for _, m := range ms { + str += fmt.Sprintln(m) + } + return str +} + +// AddMigration adds a new go migration to the goose struct. +func (c *Client) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { + _, filename, _, _ := runtime.Caller(1) + v, _ := NumericComponent(filename) + migration := &Migration{Version: v, Next: -1, Previous: -1, UpFn: up, DownFn: down, Source: filename} + + c.Migrations = append(c.Migrations, migration) +} + +// AddMigration exists for legacy support of the package global use of Goose. +// Calling the AddMigration method on a Goose struct is the preferred method. +func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { + // We can't just use globalGoose.AddMigration here because we need to + // correctly record the caller. + _, filename, _, _ := runtime.Caller(1) + v, _ := NumericComponent(filename) + migration := &Migration{Version: v, Next: -1, Previous: -1, UpFn: up, DownFn: down, Source: filename} + + globalGoose.Migrations = append(globalGoose.Migrations, migration) +} + +// collect all the valid looking migration scripts in the +// migrations folder and go func registry, and key them by version +func (c *Client) collectMigrations(dirpath string, current, target int64) (Migrations, error) { + var migrations Migrations + + // only load migrations from file if a path is explicitly provided + if dirpath != "" { + // extract the numeric component of each migration, + // filter out any uninteresting files, + // and ensure we only have one file per migration version. + sqlMigrations, err := filepath.Glob(dirpath + "/*.sql") + if err != nil { + return nil, err + } + + for _, file := range sqlMigrations { + v, err := NumericComponent(file) + if err != nil { + return nil, err + } + if versionFilter(v, current, target) { + migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file} + migrations = append(migrations, migration) + } + } + } + + for _, migration := range c.Migrations { + v, err := NumericComponent(migration.Source) + if err != nil { + return nil, err + } + if versionFilter(v, current, target) { + migrations = append(migrations, migration) + } + } + + migrations = sortAndConnectMigrations(migrations) + + return migrations, nil +} + +func sortAndConnectMigrations(migrations Migrations) Migrations { + sort.Sort(migrations) + + // now that we're sorted in the appropriate direction, + // populate next and previous for each migration + for i, m := range migrations { + prev := int64(-1) + if i > 0 { + prev = migrations[i-1].Version + migrations[i-1].Next = m.Version + } + migrations[i].Previous = prev + } + + return migrations +} + +func versionFilter(v, current, target int64) bool { + if target > current { + return v > current && v <= target + } + + if target < current { + return v <= current && v > target + } + + return false +} + +// retrieve the current version for this DB. +// Create and initialize the DB version table if it doesn't exist. +func (c *Client) GetDBVersion(db *sql.DB) (int64, error) { + rows, err := c.Dialect.dbVersionQuery(db, c.TableName) + if err != nil { + return 0, c.createVersionTable(db) + } + defer rows.Close() + + // The most recent record for each migration specifies + // whether it has been applied or rolled back. + // The first version we find that has been applied is the current version. + + toSkip := make([]int64, 0) + + for rows.Next() { + var row MigrationRecord + if err = rows.Scan(&row.VersionId, &row.IsApplied); err != nil { + log.Fatal("error scanning rows:", err) + } + + // have we already marked this version to be skipped? + skip := false + for _, v := range toSkip { + if v == row.VersionId { + skip = true + break + } + } + + if skip { + continue + } + + // if version has been applied we're done + if row.IsApplied { + return row.VersionId, nil + } + + // latest version of migration has not been applied. + toSkip = append(toSkip, row.VersionId) + } + + if err := rows.Err(); err != nil { + return 0, err + } + + panic("unreachable") +} + +// Create the goose_db_version table +// and insert the initial 0 value into it +func (c *Client) createVersionTable(db *sql.DB) error { + txn, err := db.Begin() + if err != nil { + return err + } + + d := c.Dialect + + if _, err := txn.Exec(d.createVersionTableSql(c.TableName)); err != nil { + txn.Rollback() //nolint:errcheck + return err + } + + version := 0 + applied := true + if _, err := txn.Exec(d.insertVersionSql(c.TableName), version, applied); err != nil { + txn.Rollback() //nolint:errcheck + return err + } + + return txn.Commit() +} diff --git a/server/goose/migrate_test.go b/server/goose/migrate_test.go new file mode 100644 index 000000000..97a0a1fee --- /dev/null +++ b/server/goose/migrate_test.go @@ -0,0 +1,56 @@ +package goose + +import "testing" + +func newMigration(v int64, src string) *Migration { + return &Migration{Version: v, Previous: -1, Next: -1, Source: src} +} + +func TestMigrationSort(t *testing.T) { + + ms := Migrations{} + + // insert in any order + ms = append(ms, newMigration(20120000, "test")) + ms = append(ms, newMigration(20128000, "test")) + ms = append(ms, newMigration(20129000, "test")) + ms = append(ms, newMigration(20127000, "test")) + + ms = sortAndConnectMigrations(ms) + + sorted := []int64{20120000, 20127000, 20128000, 20129000} + + validateMigrationSort(t, ms, sorted) +} + +func validateMigrationSort(t *testing.T, ms Migrations, sorted []int64) { + + for i, m := range ms { + if sorted[i] != m.Version { + t.Error("incorrect sorted version") + } + + var next, prev int64 + + if i == 0 { + prev = -1 + next = ms[i+1].Version + } else if i == len(ms)-1 { + prev = ms[i-1].Version + next = -1 + } else { + prev = ms[i-1].Version + next = ms[i+1].Version + } + + if m.Next != next { + t.Errorf("mismatched Next. v: %v, got %v, wanted %v\n", m, m.Next, next) + } + + if m.Previous != prev { + t.Errorf("mismatched Previous v: %v, got %v, wanted %v\n", m, m.Previous, prev) + } + } + + t.Log(ms) +} diff --git a/server/goose/migration.go b/server/goose/migration.go new file mode 100644 index 000000000..28767a5b1 --- /dev/null +++ b/server/goose/migration.go @@ -0,0 +1,219 @@ +package goose + +import ( + "database/sql" + "errors" + "fmt" + "log" + "path/filepath" + "regexp" + "strconv" + "strings" + "text/template" + "time" +) + +type MigrationRecord struct { + VersionId int64 + TStamp time.Time + IsApplied bool // was this a result of up() or down() +} + +type Migration struct { + Version int64 + Next int64 // next version, or -1 if none + Previous int64 // previous version, -1 if none + Source string // path to .sql script + UpFn func(*sql.Tx) error // Up go migration function + DownFn func(*sql.Tx) error // Down go migration function +} + +const ( + migrateUp = true + migrateDown = !migrateUp +) + +func (m *Migration) String() string { + return fmt.Sprintf(m.Source) +} + +func (c *Client) runMigration(db *sql.DB, m *Migration, direction bool) error { + switch filepath.Ext(m.Source) { + case ".sql": + if err := c.runSQLMigration(db, m.Source, m.Version, direction); err != nil { + return fmt.Errorf("failed to run migration: %w", err) + } + + case ".go": + name, date := parseNameAndDate(m.Source) + log.Printf("[%s] %s\n", date, name) + + tx, err := db.Begin() + if err != nil { + log.Fatal("db.Begin: ", err) + } + + fn := m.UpFn + if !direction { + fn = m.DownFn + } + if fn != nil { + if err := fn(tx); err != nil { + tx.Rollback() //nolint:errcheck + log.Fatalf("FAIL %s (%v), quitting migration.", filepath.Base(m.Source), err) + return err + } + } + + if err = c.FinalizeMigration(tx, direction, m.Version); err != nil { + log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(m.Source), err) + } + } + + return nil +} + +var ( + upperReplace = regexp.MustCompile("([a-z])([A-Z])") // e.g. UpdateBuiltin -> Update Builtin + allUpperWordsReplace = regexp.MustCompile("([A-Z]+)([A-Z][a-z])") // e.g. IDIn -> ID In +) + +func parseNameAndDate(source string) (name string, date string) { + parts := strings.SplitN(strings.TrimSuffix(filepath.Base(source), ".go"), "_", 2) + // Stripping seconds [:8] because Fleet developers add seconds when re-arranging new migrations + // e.g.: 2022/10/10 15:43:46 fail to parse time: parsing time "20201021104586": second out of range + datePart := parts[0][:8] + mt, err := time.Parse("20060102", datePart) + if err != nil { + log.Fatalf("fail to parse time: %s", err) + } + name = upperReplace.ReplaceAllString(parts[1], "$1 $2") // add spaces in the filename + name = allUpperWordsReplace.ReplaceAllString(name, "$1 $2") // add spaces in the filename + date = mt.Format("2006-01-02") + return +} + +// look for migration scripts with names in the form: +// +// XXX_descriptivename.ext +// +// where XXX specifies the version number +// and ext specifies the type of migration +func NumericComponent(name string) (int64, error) { + base := filepath.Base(name) + + if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { + return 0, errors.New("not a recognized migration file type") + } + + idx := strings.Index(base, "_") + if idx < 0 { + return 0, errors.New("no separator found") + } + + n, e := strconv.ParseInt(base[:idx], 10, 64) + if e == nil && n <= 0 { + return 0, errors.New("migration IDs must be greater than zero") + } + + return n, e +} + +func CreateMigration(name, migrationType, dir string, t time.Time) ([]string, error) { + if migrationType != "go" && migrationType != "sql" { + return nil, errors.New("migration type must be 'go' or 'sql'") + } + + timestamp := t.Format("20060102150405") + filename := fmt.Sprintf("%s_%s.%s", timestamp, name, migrationType) + + fpath := filepath.Join(dir, filename) + tmpl := sqlMigrationTemplate + if migrationType == "go" { + tmpl = goSqlMigrationTemplate + } + + var paths []string + + migrationPath, err := writeTemplateToFile(fpath, tmpl, timestamp) + if err != nil { + return nil, err + } + paths = append(paths, migrationPath) + + if migrationType == "go" { + fpath := strings.Replace(filepath.Join(dir, filename), ".go", "_test.go", 1) + migrationTestPath, err := writeTemplateToFile(fpath, goSqlMigrationTestTemplate, timestamp) + if err != nil { + return nil, err + } + paths = append(paths, migrationTestPath) + } + + return paths, nil +} + +// Update the version table for the given migration, +// and finalize the transaction. +func (c *Client) FinalizeMigration(tx *sql.Tx, direction bool, v int64) error { + // XXX: drop goose_db_version table on some minimum version number? + stmt := c.Dialect.insertVersionSql(c.TableName) + if _, err := tx.Exec(stmt, v, direction); err != nil { + tx.Rollback() //nolint:errcheck + return err + } + + return tx.Commit() +} + +var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(` +-- +goose Up +-- SQL in section 'Up' is executed when this migration is applied + + +-- +goose Down +-- SQL section 'Down' is executed when this migration is rolled back + +`)) + +var goSqlMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(` +package tables + +import ( + "database/sql" +) + +func init() { + MigrationClient.AddMigration(Up_{{.}}, Down_{{.}}) +} + +func Up_{{.}}(tx *sql.Tx) error { + return nil +} + +func Down_{{.}}(tx *sql.Tx) error { + return nil +} +`)) + +var goSqlMigrationTestTemplate = template.Must(template.New("goose.go-migration").Parse(` +package tables + +import "testing" + +func TestUp_{{.}}(t *testing.T) { + db := applyUpToPrev(t) + + // + // Insert data to test the migration + // + // ... + + // Apply current migration. + applyNext(t, db) + + // + // Check data, insert new entries, e.g. to verify migration is safe. + // + // ... +}`)) diff --git a/server/goose/migration_sql.go b/server/goose/migration_sql.go new file mode 100644 index 000000000..a2d5bafce --- /dev/null +++ b/server/goose/migration_sql.go @@ -0,0 +1,165 @@ +package goose + +import ( + "bufio" + "bytes" + "database/sql" + "io" + "log" + "os" + "path/filepath" + "strings" +) + +const sqlCmdPrefix = "-- +goose " + +// Checks the line to see if the line has a statement-ending semicolon +// or if the line contains a double-dash comment. +func endsWithSemicolon(line string) bool { + prev := "" + scanner := bufio.NewScanner(strings.NewReader(line)) + scanner.Split(bufio.ScanWords) + + for scanner.Scan() { + word := scanner.Text() + if strings.HasPrefix(word, "--") { + break + } + prev = word + } + + return strings.HasSuffix(prev, ";") +} + +// Split the given sql script into individual statements. +// +// The base case is to simply split on semicolons, as these +// naturally terminate a statement. +// +// However, more complex cases like pl/pgsql can have semicolons +// within a statement. For these cases, we provide the explicit annotations +// 'StatementBegin' and 'StatementEnd' to allow the script to +// tell us to ignore semicolons. +func splitSQLStatements(r io.Reader, direction bool) (stmts []string) { + var buf bytes.Buffer + scanner := bufio.NewScanner(r) + + // track the count of each section + // so we can diagnose scripts with no annotations + upSections := 0 + downSections := 0 + + statementEnded := false + ignoreSemicolons := false + directionIsActive := false + + for scanner.Scan() { + + line := scanner.Text() + + // handle any goose-specific commands + if strings.HasPrefix(line, sqlCmdPrefix) { + cmd := strings.TrimSpace(line[len(sqlCmdPrefix):]) + switch cmd { + case "Up": + directionIsActive = (direction == true) + upSections++ + break + + case "Down": + directionIsActive = (direction == false) + downSections++ + break + + case "StatementBegin": + if directionIsActive { + ignoreSemicolons = true + } + break + + case "StatementEnd": + if directionIsActive { + statementEnded = (ignoreSemicolons == true) + ignoreSemicolons = false + } + break + } + } + + if !directionIsActive { + continue + } + + if _, err := buf.WriteString(line + "\n"); err != nil { + log.Fatalf("io err: %v", err) + } + + // Wrap up the two supported cases: 1) basic with semicolon; 2) psql statement + // Lines that end with semicolon that are in a statement block + // do not conclude statement. + if (!ignoreSemicolons && endsWithSemicolon(line)) || statementEnded { + statementEnded = false + stmts = append(stmts, buf.String()) + buf.Reset() + } + } + + if err := scanner.Err(); err != nil { + log.Fatalf("scanning migration: %v", err) + } + + // diagnose likely migration script errors + if ignoreSemicolons { + log.Println("WARNING: saw '-- +goose StatementBegin' with no matching '-- +goose StatementEnd'") + } + + if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 { + log.Printf("WARNING: Unexpected unfinished SQL query: %s. Missing a semicolon?\n", bufferRemaining) + } + + if upSections == 0 && downSections == 0 { + log.Fatalf(`ERROR: no Up/Down annotations found, so no statements were executed. + See https://bitbucket.org/liamstask/goose/overview for details.`) + } + + return +} + +// Run a migration specified in raw SQL. +// +// Sections of the script can be annotated with a special comment, +// starting with "-- +goose" to specify whether the section should +// be applied during an Up or Down migration +// +// All statements following an Up or Down directive are grouped together +// until another direction directive is found. +func (c *Client) runSQLMigration(db *sql.DB, scriptFile string, v int64, direction bool) error { + tx, err := db.Begin() + if err != nil { + log.Fatal("db.Begin:", err) + } + + f, err := os.Open(scriptFile) + if err != nil { + log.Fatal(err) + } + + // find each statement, checking annotations for up/down direction + // and execute each of them in the current transaction. + // Commits the transaction if successfully applied each statement and + // records the version into the version table or returns an error and + // rolls back the transaction. + for _, query := range splitSQLStatements(f, direction) { + if _, err = tx.Exec(query); err != nil { + tx.Rollback() //nolint:errcheck + log.Fatalf("FAIL %s (%v), quitting migration.", filepath.Base(scriptFile), err) + return err + } + } + + if err = c.FinalizeMigration(tx, direction, v); err != nil { + log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(scriptFile), err) + } + + return nil +} diff --git a/server/goose/migration_sql_test.go b/server/goose/migration_sql_test.go new file mode 100644 index 000000000..5852960b1 --- /dev/null +++ b/server/goose/migration_sql_test.go @@ -0,0 +1,147 @@ +package goose + +import ( + "strings" + "testing" +) + +func TestSemicolons(t *testing.T) { + + type testData struct { + line string + result bool + } + + tests := []testData{ + { + line: "END;", + result: true, + }, + { + line: "END; -- comment", + result: true, + }, + { + line: "END ; -- comment", + result: true, + }, + { + line: "END -- comment", + result: false, + }, + { + line: "END -- comment ;", + result: false, + }, + { + line: "END \" ; \" -- comment", + result: false, + }, + } + + for _, test := range tests { + r := endsWithSemicolon(test.line) + if r != test.result { + t.Errorf("incorrect semicolon. got %v, want %v", r, test.result) + } + } +} + +func TestSplitStatements(t *testing.T) { + + type testData struct { + sql string + direction bool + count int + } + + tests := []testData{ + { + sql: functxt, + direction: true, + count: 2, + }, + { + sql: functxt, + direction: false, + count: 2, + }, + { + sql: multitxt, + direction: true, + count: 2, + }, + { + sql: multitxt, + direction: false, + count: 2, + }, + } + + for _, test := range tests { + stmts := splitSQLStatements(strings.NewReader(test.sql), test.direction) + if len(stmts) != test.count { + t.Errorf("incorrect number of stmts. got %v, want %v", len(stmts), test.count) + } + } +} + +var functxt = `-- +goose Up +CREATE TABLE IF NOT EXISTS histories ( + id BIGSERIAL PRIMARY KEY, + current_value varchar(2000) NOT NULL, + created_at timestamp with time zone NOT NULL +); + +-- +goose StatementBegin +CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE ) +returns void AS $$ +DECLARE + create_query text; +BEGIN + FOR create_query IN SELECT + 'CREATE TABLE IF NOT EXISTS histories_' + || TO_CHAR( d, 'YYYY_MM' ) + || ' ( CHECK( created_at >= timestamp ''' + || TO_CHAR( d, 'YYYY-MM-DD 00:00:00' ) + || ''' AND created_at < timestamp ''' + || TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' ) + || ''' ) ) inherits ( histories );' + FROM generate_series( $1, $2, '1 month' ) AS d + LOOP + EXECUTE create_query; + END LOOP; -- LOOP END +END; -- FUNCTION END +$$ +language plpgsql; +-- +goose StatementEnd + +-- +goose Down +drop function histories_partition_creation(DATE, DATE); +drop TABLE histories; +` + +// test multiple up/down transitions in a single script +var multitxt = `-- +goose Up +CREATE TABLE post ( + id int NOT NULL, + title text, + body text, + PRIMARY KEY(id) +); + +-- +goose Down +DROP TABLE post; + +-- +goose Up +CREATE TABLE fancier_post ( + id int NOT NULL, + title text, + body text, + created_on timestamp without time zone, + PRIMARY KEY(id) +); + +-- +goose Down +DROP TABLE fancier_post; +` diff --git a/server/goose/redo.go b/server/goose/redo.go new file mode 100644 index 000000000..bb1f8dd9f --- /dev/null +++ b/server/goose/redo.go @@ -0,0 +1,37 @@ +package goose + +import ( + "database/sql" +) + +func (c *Client) Redo(db *sql.DB, dir string) error { + currentVersion, err := c.GetDBVersion(db) + if err != nil { + return err + } + + migrations, err := c.collectMigrations(dir, minVersion, maxVersion) + if err != nil { + return err + } + + current, err := migrations.Current(currentVersion) + if err != nil { + return err + } + + previous, err := migrations.Next(currentVersion) + if err != nil { + return err + } + + if err := c.runMigration(db, previous, migrateUp); err != nil { + return err + } + + if err := c.runMigration(db, current, migrateUp); err != nil { + return err + } + + return nil +} diff --git a/server/goose/status.go b/server/goose/status.go new file mode 100644 index 000000000..bd7567105 --- /dev/null +++ b/server/goose/status.go @@ -0,0 +1,50 @@ +package goose + +import ( + "database/sql" + "fmt" + "log" + "path/filepath" + "time" +) + +func (c *Client) Status(db *sql.DB, dir string) error { + // collect all migrations + migrations, err := c.collectMigrations(dir, minVersion, maxVersion) + if err != nil { + return err + } + + // must ensure that the version table exists if we're running on a pristine DB + if _, err := c.GetDBVersion(db); err != nil { + return err + } + + fmt.Println(" Applied At Migration") + fmt.Println(" =======================================") + for _, migration := range migrations { + printMigrationStatus(db, migration.Version, filepath.Base(migration.Source)) + } + + return nil +} + +func printMigrationStatus(db *sql.DB, version int64, script string) { + var row MigrationRecord + q := fmt.Sprintf("SELECT tstamp, is_applied FROM goose_db_version WHERE version_id=%d ORDER BY tstamp DESC LIMIT 1", version) + e := db.QueryRow(q).Scan(&row.TStamp, &row.IsApplied) + + if e != nil && e != sql.ErrNoRows { + log.Fatal(e) + } + + var appliedAt string + + if row.IsApplied { + appliedAt = row.TStamp.Format(time.ANSIC) + } else { + appliedAt = "Pending" + } + + fmt.Printf(" %-24s -- %v\n", appliedAt, script) +} diff --git a/server/goose/up.go b/server/goose/up.go new file mode 100644 index 000000000..f9da1aa59 --- /dev/null +++ b/server/goose/up.go @@ -0,0 +1,54 @@ +package goose + +import ( + "database/sql" +) + +func (c *Client) Up(db *sql.DB, dir string) error { + migrations, err := c.collectMigrations(dir, minVersion, maxVersion) + if err != nil { + return err + } + + for { + current, err := c.GetDBVersion(db) + if err != nil { + return err + } + + next, err := migrations.Next(current) + if err != nil { + if err == ErrNoNextVersion { + return nil + } + return err + } + + if err = c.runMigration(db, next, migrateUp); err != nil { + return err + } + } +} + +func (c *Client) UpByOne(db *sql.DB, dir string) error { + migrations, err := c.collectMigrations(dir, minVersion, maxVersion) + if err != nil { + return err + } + + currentVersion, err := c.GetDBVersion(db) + if err != nil { + return err + } + + next, err := migrations.Next(currentVersion) + if err != nil { + return err + } + + if err = c.runMigration(db, next, migrateUp); err != nil { + return err + } + + return nil +} diff --git a/server/goose/util.go b/server/goose/util.go new file mode 100644 index 000000000..151b365e3 --- /dev/null +++ b/server/goose/util.go @@ -0,0 +1,23 @@ +package goose + +import ( + "os" + "text/template" +) + +// common routines + +func writeTemplateToFile(path string, t *template.Template, data interface{}) (string, error) { + f, e := os.Create(path) + if e != nil { + return "", e + } + defer f.Close() + + e = t.Execute(f, data) + if e != nil { + return "", e + } + + return f.Name(), nil +} diff --git a/server/goose/version.go b/server/goose/version.go new file mode 100644 index 000000000..b9b0c18de --- /dev/null +++ b/server/goose/version.go @@ -0,0 +1,16 @@ +package goose + +import ( + "database/sql" + "fmt" +) + +func (c *Client) Version(db *sql.DB, dir string) error { + current, err := c.GetDBVersion(db) + if err != nil { + return err + } + + fmt.Printf("goose: dbversion %v\n", current) + return nil +} diff --git a/tools.go b/tools.go index 65e7fa706..43f546466 100644 --- a/tools.go +++ b/tools.go @@ -4,7 +4,7 @@ package tools import ( - _ "github.com/fleetdm/goose" + _ "github.com/fleetdm/fleet/v4/server/goose" _ "github.com/goreleaser/goreleaser" _ "github.com/kevinburke/go-bindata" _ "github.com/quasilyte/go-ruleguard/dsl"