mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 17:05:18 +00:00
1269 lines
38 KiB
Go
// Package mysql is a MySQL implementation of the Datastore interface.
|
|
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/VividCortex/mysqlerr"
|
|
"github.com/WatchBeam/clock"
|
|
"github.com/XSAM/otelsql"
|
|
"github.com/cenkalti/backoff/v4"
|
|
"github.com/doug-martin/goqu/v9"
|
|
"github.com/doug-martin/goqu/v9/exp"
|
|
"github.com/fleetdm/fleet/v4/server/config"
|
|
"github.com/fleetdm/fleet/v4/server/contexts/ctxdb"
|
|
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
|
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/data"
|
|
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/tables"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/goose"
|
|
"github.com/go-kit/kit/log"
|
|
"github.com/go-kit/kit/log/level"
|
|
"github.com/go-sql-driver/mysql"
|
|
"github.com/hashicorp/go-multierror"
|
|
"github.com/jmoiron/sqlx"
|
|
nanodep_client "github.com/micromdm/nanodep/client"
|
|
nanodep_mysql "github.com/micromdm/nanodep/storage/mysql"
|
|
nanomdm_mysql "github.com/micromdm/nanomdm/storage/mysql"
|
|
scep_depot "github.com/micromdm/scep/v2/depot"
|
|
"github.com/ngrok/sqlmw"
|
|
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
|
)
|
|
|
|
const (
	// defaultSelectLimit is the implicit LIMIT applied to unbounded list
	// queries so a query with many results cannot exhaust memory or hang.
	defaultSelectLimit = 1000000

	// mySQLTimestampFormat is the Go reference-time layout matching MySQL
	// DATETIME string values.
	mySQLTimestampFormat = "2006-01-02 15:04:05" // %Y/%m/%d %H:%M:%S
)
|
|
|
|
// columnCharsRegexp matches every character that is NOT a word character
// (`\w`), '-' or '.'; such characters are stripped out when sanitizing
// column names (see sanitizeColumn).
var columnCharsRegexp = regexp.MustCompile(`[^\w-.]`)
|
|
|
|
// dbReader is an interface that defines the methods required for reads.
// It deliberately exposes only query/prepare capabilities (no Exec), so a
// value of this type cannot be used to perform writes.
type dbReader interface {
	sqlx.QueryerContext
	sqlx.PreparerContext

	Close() error
	Rebind(string) string
}
|
|
|
|
// Datastore is an implementation of fleet.Datastore interface backed by
// MySQL
type Datastore struct {
	replica dbReader // so it cannot be used to perform writes
	primary *sqlx.DB // read-write connection; also serves reads when the primary is required

	logger log.Logger
	clock  clock.Clock
	config config.MysqlConfig

	// nil if no read replica
	readReplicaConfig *config.MysqlConfig

	// minimum interval between software last_opened_at timestamp to update the
	// database (see file software.go).
	minLastOpenedAtDiff time.Duration

	// writeCh queues serialized host writes consumed by writeChanLoop.
	writeCh chan itemToWrite

	// stmtCacheMu protects access to stmtCache.
	stmtCacheMu sync.Mutex
	// stmtCache holds statements for queries.
	stmtCache map[string]*sqlx.Stmt

	// for tests, set to override the default batch size.
	testDeleteMDMProfilesBatchSize int
	// for tests, set to override the default batch size.
	testUpsertMDMDesiredProfilesBatchSize int
}
|
|
|
|
// reader returns the DB instance to use for read-only statements, which is the
|
|
// replica unless the primary has been explicitly required via
|
|
// ctxdb.RequirePrimary.
|
|
func (ds *Datastore) reader(ctx context.Context) dbReader {
|
|
if ctxdb.IsPrimaryRequired(ctx) {
|
|
return ds.primary
|
|
}
|
|
return ds.replica
|
|
}
|
|
|
|
// writer returns the DB instance to use for write statements, which is always
// the primary.
//
// The ctx parameter is currently unused; it is kept for symmetry with reader
// and to allow future routing decisions without changing call sites.
func (ds *Datastore) writer(ctx context.Context) *sqlx.DB {
	return ds.primary
}
|
|
|
|
// loadOrPrepareStmt will load a statement from the statements cache.
|
|
// If not available, it will attempt to prepare (create) it.
|
|
//
|
|
// Returns nil if it failed to prepare a statement.
|
|
func (ds *Datastore) loadOrPrepareStmt(ctx context.Context, query string) *sqlx.Stmt {
|
|
// the cache is only available on the replica
|
|
if ctxdb.IsPrimaryRequired(ctx) {
|
|
return nil
|
|
}
|
|
|
|
ds.stmtCacheMu.Lock()
|
|
defer ds.stmtCacheMu.Unlock()
|
|
|
|
stmt, ok := ds.stmtCache[query]
|
|
if !ok {
|
|
var err error
|
|
stmt, err = sqlx.PreparexContext(ctx, ds.replica, query)
|
|
if err != nil {
|
|
level.Error(ds.logger).Log(
|
|
"msg", "failed to prepare statement",
|
|
"query", query,
|
|
"err", err,
|
|
)
|
|
return nil
|
|
}
|
|
ds.stmtCache[query] = stmt
|
|
}
|
|
return stmt
|
|
}
|
|
|
|
// NewSCEPDepot returns a scep_depot.Depot that uses the Datastore
// underlying MySQL writer *sql.DB, configured with the given CA
// certificate and private key (both PEM-encoded).
func (ds *Datastore) NewSCEPDepot(caCertPEM []byte, caKeyPEM []byte) (scep_depot.Depot, error) {
	return newSCEPDepot(ds.primary.DB, caCertPEM, caKeyPEM)
}
|
|
|
|
// NewMDMAppleMDMStorage returns a MySQL nanomdm storage that uses the Datastore
|
|
// underlying MySQL writer *sql.DB.
|
|
func (ds *Datastore) NewMDMAppleMDMStorage(pushCertPEM []byte, pushKeyPEM []byte) (*NanoMDMStorage, error) {
|
|
s, err := nanomdm_mysql.New(nanomdm_mysql.WithDB(ds.primary.DB))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &NanoMDMStorage{
|
|
MySQLStorage: s,
|
|
pushCertPEM: pushCertPEM,
|
|
pushKeyPEM: pushKeyPEM,
|
|
}, nil
|
|
}
|
|
|
|
// NanoMDMStorage wraps a *nanomdm_mysql.MySQLStorage and overrides further functionality.
type NanoMDMStorage struct {
	*nanomdm_mysql.MySQLStorage

	// pushCertPEM and pushKeyPEM hold the APNS push certificate/key pair,
	// loaded at startup and served from memory (see RetrievePushCert).
	pushCertPEM []byte
	pushKeyPEM  []byte
}
|
|
|
|
// RetrievePushCert partially implements nanomdm_storage.PushCertStore.
|
|
//
|
|
// Always returns "0" as stale token because we are not storing the APNS in MySQL storage,
|
|
// and instead loading them at startup, thus the APNS will never be considered stale.
|
|
func (s *NanoMDMStorage) RetrievePushCert(
|
|
ctx context.Context, topic string,
|
|
) (cert *tls.Certificate, staleToken string, err error) {
|
|
tlsCert, err := tls.X509KeyPair(s.pushCertPEM, s.pushKeyPEM)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
return &tlsCert, "0", nil
|
|
}
|
|
|
|
// IsPushCertStale partially implements nanomdm_storage.PushCertStore.
//
// Given that we are not storing the APNS certificate in MySQL storage, and instead loading
// them at startup (as env variables), the APNS will never be considered stale.
//
// TODO(lucas): Revisit solution to support changing the APNS.
func (s *NanoMDMStorage) IsPushCertStale(ctx context.Context, topic, staleToken string) (bool, error) {
	return false, nil
}
|
|
|
|
// StorePushCert partially implements nanomdm_storage.PushCertStore.
//
// Leaving this unimplemented as APNS certificate and key are not stored in MySQL storage,
// instead they are loaded to memory at startup.
func (s *NanoMDMStorage) StorePushCert(ctx context.Context, pemCert, pemKey []byte) error {
	return errors.New("unimplemented")
}
|
|
|
|
// NewMDMAppleDEPStorage returns a MySQL nanodep storage that uses the Datastore
|
|
// underlying MySQL writer *sql.DB.
|
|
func (ds *Datastore) NewMDMAppleDEPStorage(tok nanodep_client.OAuth1Tokens) (*NanoDEPStorage, error) {
|
|
s, err := nanodep_mysql.New(nanodep_mysql.WithDB(ds.primary.DB))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &NanoDEPStorage{
|
|
MySQLStorage: s,
|
|
tokens: tok,
|
|
}, nil
|
|
}
|
|
|
|
// NanoDEPStorage wraps a *nanodep_mysql.MySQLStorage and overrides functionality to load
// DEP auth tokens from memory.
type NanoDEPStorage struct {
	*nanodep_mysql.MySQLStorage

	// tokens holds the DEP auth tokens loaded at startup; see RetrieveAuthTokens.
	tokens nanodep_client.OAuth1Tokens
}
|
|
|
|
// RetrieveAuthTokens partially implements nanodep.AuthTokensRetriever.
//
// RetrieveAuthTokens returns the DEP auth tokens stored in memory.
// The name argument is ignored because a single token set is held.
func (s *NanoDEPStorage) RetrieveAuthTokens(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) {
	return &s.tokens, nil
}
|
|
|
|
// StoreAuthTokens partially implements nanodep.AuthTokensStorer.
//
// Leaving this unimplemented as DEP auth tokens are not stored in MySQL storage,
// instead they are loaded to memory at startup.
func (s *NanoDEPStorage) StoreAuthTokens(ctx context.Context, name string, tokens *nanodep_client.OAuth1Tokens) error {
	return errors.New("unimplemented")
}
|
|
|
|
// txFn is the function signature executed inside a database transaction by
// withTx / withRetryTxx; returning a non-nil error causes a rollback.
type txFn func(tx sqlx.ExtContext) error
|
|
|
|
// entity identifies a database table by name.
// NOTE(review): consumers are outside this chunk; presumably used to build
// table-specific errors — confirm against callers.
type entity struct {
	name string
}
|
|
|
|
// Commonly referenced tables, wrapped as entities.
var (
	hostsTable    = entity{"hosts"}
	invitesTable  = entity{"invites"}
	packsTable    = entity{"packs"}
	queriesTable  = entity{"queries"}
	sessionsTable = entity{"sessions"}
	usersTable    = entity{"users"}
)
|
|
|
|
// doRetryErr is a sentinel error a txFn can return (or wrap) to force
// withRetryTxx to retry the transaction (see retryableError).
var doRetryErr = errors.New("fleet datastore retry")
|
|
|
|
// retryableError determines whether a MySQL error can be retried. By default
|
|
// errors are considered non-retryable. Only errors that we know have a
|
|
// possibility of succeeding on a retry should return true in this function.
|
|
func retryableError(err error) bool {
|
|
base := ctxerr.Cause(err)
|
|
if b, ok := base.(*mysql.MySQLError); ok {
|
|
switch b.Number {
|
|
// Consider lock related errors to be retryable
|
|
case mysqlerr.ER_LOCK_DEADLOCK, mysqlerr.ER_LOCK_WAIT_TIMEOUT:
|
|
return true
|
|
}
|
|
}
|
|
if errors.Is(err, doRetryErr) {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// withRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff
//
// The whole transaction (begin, fn, commit) is retried until it succeeds, a
// non-retryable error occurs (wrapped in backoff.Permanent), or 5 seconds of
// elapsed retrying have passed.
func (ds *Datastore) withRetryTxx(ctx context.Context, fn txFn) (err error) {
	operation := func() error {
		tx, err := ds.writer(ctx).BeginTxx(ctx, nil)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "create transaction")
		}

		// Roll back (and re-panic) if fn panics, so the connection isn't
		// left with a dangling open transaction.
		defer func() {
			if p := recover(); p != nil {
				if err := tx.Rollback(); err != nil {
					ds.logger.Log("err", err, "msg", "error encountered during transaction panic rollback")
				}
				panic(p)
			}
		}()

		if err := fn(tx); err != nil {
			rbErr := tx.Rollback()
			if rbErr != nil && rbErr != sql.ErrTxDone {
				// Consider rollback errors to be non-retryable
				return backoff.Permanent(ctxerr.Wrapf(ctx, err, "got err '%s' rolling back after err", rbErr.Error()))
			}

			if retryableError(err) {
				return err
			}

			// Consider any other errors to be non-retryable
			return backoff.Permanent(err)
		}

		if err := tx.Commit(); err != nil {
			err = ctxerr.Wrap(ctx, err, "commit transaction")

			// Commit may itself fail with a retryable error (e.g. deadlock).
			if retryableError(err) {
				return err
			}

			return backoff.Permanent(err)
		}

		return nil
	}

	bo := backoff.NewExponentialBackOff()
	bo.MaxElapsedTime = 5 * time.Second
	return backoff.Retry(operation, bo)
}
|
|
|
|
// withTx provides a common way to commit/rollback a txFn
//
// The transaction is rolled back if fn returns an error or panics (the panic
// is re-raised after rollback); otherwise it is committed.
func (ds *Datastore) withTx(ctx context.Context, fn txFn) (err error) {
	tx, err := ds.writer(ctx).BeginTxx(ctx, nil)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "create transaction")
	}

	// Roll back (and re-panic) if fn panics, so the connection isn't left
	// with a dangling open transaction.
	defer func() {
		if p := recover(); p != nil {
			if err := tx.Rollback(); err != nil {
				ds.logger.Log("err", err, "msg", "error encountered during transaction panic rollback")
			}
			panic(p)
		}
	}()

	if err := fn(tx); err != nil {
		rbErr := tx.Rollback()
		if rbErr != nil && rbErr != sql.ErrTxDone {
			return ctxerr.Wrapf(ctx, err, "got err '%s' rolling back after err", rbErr.Error())
		}
		return err
	}

	if err := tx.Commit(); err != nil {
		return ctxerr.Wrap(ctx, err, "commit transaction")
	}

	return nil
}
|
|
|
|
// New creates an MySQL datastore.
//
// It applies the given DBOptions, validates the primary (and optional read
// replica) configuration, opens the connection(s), and starts the background
// write-channel goroutine before returning the Datastore.
func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore, error) {
	options := &dbOptions{
		minLastOpenedAtDiff: defaultMinLastOpenedAtDiff,
		maxAttempts:         defaultMaxAttempts,
		logger:              log.NewNopLogger(),
	}

	for _, setOpt := range opts {
		if setOpt != nil {
			if err := setOpt(options); err != nil {
				return nil, err
			}
		}
	}

	if err := checkConfig(&config); err != nil {
		return nil, err
	}
	if options.replicaConfig != nil {
		if err := checkConfig(options.replicaConfig); err != nil {
			return nil, fmt.Errorf("replica: %w", err)
		}
	}

	dbWriter, err := newDB(&config, options)
	if err != nil {
		return nil, err
	}
	// Without a replica configuration, reads are served by the primary.
	dbReader := dbWriter
	if options.replicaConfig != nil {
		dbReader, err = newDB(options.replicaConfig, options)
		if err != nil {
			return nil, err
		}
	}

	ds := &Datastore{
		primary:             dbWriter,
		replica:             dbReader,
		logger:              options.logger,
		clock:               c,
		config:              config,
		readReplicaConfig:   options.replicaConfig,
		writeCh:             make(chan itemToWrite),
		stmtCache:           make(map[string]*sqlx.Stmt),
		minLastOpenedAtDiff: options.minLastOpenedAtDiff,
	}

	go ds.writeChanLoop()

	return ds, nil
}
|
|
|
|
// itemToWrite is a single queued write processed by writeChanLoop; the
// outcome of the write is reported back on errCh.
type itemToWrite struct {
	ctx   context.Context
	errCh chan error
	// item is either a *fleet.Host or a hostXUpdatedAt (see writeChanLoop).
	item interface{}
}
|
|
|
|
// hostXUpdatedAt describes an update to a single timestamp column ("what")
// of a host row.
type hostXUpdatedAt struct {
	hostID    uint
	updatedAt time.Time
	what      string // column name to set; must come from trusted internal code
}
|
|
|
|
// writeChanLoop serializes host-related writes received on ds.writeCh,
// sending each result on the item's errCh. It is started as a goroutine by
// New and exits when writeCh is closed.
func (ds *Datastore) writeChanLoop() {
	for item := range ds.writeCh {
		switch actualItem := item.item.(type) {
		case *fleet.Host:
			item.errCh <- ds.UpdateHost(item.ctx, actualItem)
		case hostXUpdatedAt:
			// NOTE(review): actualItem.what is interpolated into SQL as a
			// column name (placeholders can't be used for identifiers); it
			// must only ever originate from trusted internal callers.
			query := fmt.Sprintf(`UPDATE hosts SET %s = ? WHERE id=?`, actualItem.what)
			_, err := ds.writer(item.ctx).ExecContext(item.ctx, query, actualItem.updatedAt, actualItem.hostID)
			item.errCh <- ctxerr.Wrap(item.ctx, err, "updating hosts label updated at")
		}
	}
}
|
|
|
|
// otelTracedDriverName is the name of the OpenTelemetry-instrumented MySQL
// driver registered at package init time; newDB selects it when otel tracing
// is enabled.
var otelTracedDriverName string

func init() {
	// Register the otel-wrapped mysql driver once at package load; failure
	// here is a programmer/config error, hence the panic.
	var err error
	otelTracedDriverName, err = otelsql.Register("mysql", semconv.DBSystemMySQL.Value.AsString())
	if err != nil {
		panic(err)
	}
}
|
|
|
|
func newDB(conf *config.MysqlConfig, opts *dbOptions) (*sqlx.DB, error) {
|
|
driverName := "mysql"
|
|
if opts.tracingConfig != nil && opts.tracingConfig.TracingEnabled {
|
|
if opts.tracingConfig.TracingType == "opentelemetry" {
|
|
driverName = otelTracedDriverName
|
|
} else {
|
|
driverName = "apm/mysql"
|
|
}
|
|
}
|
|
if opts.interceptor != nil {
|
|
driverName = "mysql-mw"
|
|
sql.Register(driverName, sqlmw.Driver(mysql.MySQLDriver{}, opts.interceptor))
|
|
}
|
|
if opts.sqlMode != "" {
|
|
conf.SQLMode = opts.sqlMode
|
|
}
|
|
|
|
dsn := generateMysqlConnectionString(*conf)
|
|
db, err := sqlx.Open(driverName, dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
db.SetMaxIdleConns(conf.MaxIdleConns)
|
|
db.SetMaxOpenConns(conf.MaxOpenConns)
|
|
db.SetConnMaxLifetime(time.Second * time.Duration(conf.ConnMaxLifetime))
|
|
|
|
var dbError error
|
|
for attempt := 0; attempt < opts.maxAttempts; attempt++ {
|
|
dbError = db.Ping()
|
|
if dbError == nil {
|
|
// we're connected!
|
|
break
|
|
}
|
|
interval := time.Duration(attempt) * time.Second
|
|
opts.logger.Log("mysql", fmt.Sprintf(
|
|
"could not connect to db: %v, sleeping %v", dbError, interval))
|
|
time.Sleep(interval)
|
|
}
|
|
|
|
if dbError != nil {
|
|
return nil, dbError
|
|
}
|
|
return db, nil
|
|
}
|
|
|
|
func checkConfig(conf *config.MysqlConfig) error {
|
|
if conf.PasswordPath != "" && conf.Password != "" {
|
|
return errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
|
|
}
|
|
|
|
// Check to see if the flag is populated
|
|
// Check if file exists on disk
|
|
// If file exists read contents
|
|
if conf.PasswordPath != "" {
|
|
fileContents, err := os.ReadFile(conf.PasswordPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conf.Password = strings.TrimSpace(string(fileContents))
|
|
}
|
|
|
|
if conf.TLSCA != "" {
|
|
conf.TLSConfig = "custom"
|
|
err := registerTLS(*conf)
|
|
if err != nil {
|
|
return fmt.Errorf("register TLS config for mysql: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// MigrateTables runs all pending schema ("table") migrations against the
// primary database.
func (ds *Datastore) MigrateTables(ctx context.Context) error {
	return tables.MigrationClient.Up(ds.writer(ctx).DB, "")
}
|
|
|
|
// MigrateData runs all pending data migrations against the primary database.
func (ds *Datastore) MigrateData(ctx context.Context) error {
	return data.MigrationClient.Up(ds.writer(ctx).DB, "")
}
|
|
|
|
// loadMigrations manually loads the applied migrations in ascending
// order (goose doesn't provide such functionality).
//
// Returns two lists of version IDs (one for "table" and one for "data").
func (ds *Datastore) loadMigrations(
	ctx context.Context,
	writer *sql.DB,
	reader dbReader,
) (tableRecs []int64, dataRecs []int64, err error) {
	// We need to run the following to trigger the creation of the migration status tables.
	_, err = tables.MigrationClient.GetDBVersion(writer)
	if err != nil {
		return nil, nil, err
	}
	_, err = data.MigrationClient.GetDBVersion(writer)
	if err != nil {
		return nil, nil, err
	}
	// version_id > 0 to skip the bootstrap migration that creates the migration tables.
	if err := sqlx.SelectContext(ctx, reader, &tableRecs,
		"SELECT version_id FROM "+tables.MigrationClient.TableName+" WHERE version_id > 0 AND is_applied ORDER BY id ASC",
	); err != nil {
		return nil, nil, err
	}
	if err := sqlx.SelectContext(ctx, reader, &dataRecs,
		"SELECT version_id FROM "+data.MigrationClient.TableName+" WHERE version_id > 0 AND is_applied ORDER BY id ASC",
	); err != nil {
		return nil, nil, err
	}
	return tableRecs, dataRecs, nil
}
|
|
|
|
// MigrationStatus will return the current status of the migrations
// comparing the known migrations in code and the applied migrations in the database.
//
// It assumes some deployments may have performed migrations out of order.
func (ds *Datastore) MigrationStatus(ctx context.Context) (*fleet.MigrationStatus, error) {
	if tables.MigrationClient.Migrations == nil || data.MigrationClient.Migrations == nil {
		return nil, errors.New("unexpected nil migrations list")
	}
	appliedTable, appliedData, err := ds.loadMigrations(ctx, ds.primary.DB, ds.replica)
	if err != nil {
		return nil, fmt.Errorf("cannot load migrations: %w", err)
	}
	return compareMigrations(
		tables.MigrationClient.Migrations,
		data.MigrationClient.Migrations,
		appliedTable,
		appliedData,
	), nil
}
|
|
|
|
// compareMigrations classifies the applied vs. known migrations into one of
// the fleet.MigrationStatus states: none applied, all applied, some missing,
// or unknown migrations present.
//
// It assumes some deployments may have performed migrations out of order.
func compareMigrations(knownTable goose.Migrations, knownData goose.Migrations, appliedTable, appliedData []int64) *fleet.MigrationStatus {
	if len(appliedTable) == 0 && len(appliedData) == 0 {
		return &fleet.MigrationStatus{
			StatusCode: fleet.NoMigrationsCompleted,
		}
	}

	missingTable, unknownTable, equalTable := compareVersions(
		getVersionsFromMigrations(knownTable),
		appliedTable,
		knownUnknownTableMigrations,
	)

	missingData, unknownData, equalData := compareVersions(
		getVersionsFromMigrations(knownData),
		appliedData,
		knownUnknownDataMigrations,
	)

	if equalData && equalTable {
		return &fleet.MigrationStatus{
			StatusCode: fleet.AllMigrationsCompleted,
		}
	}

	//
	// The following code assumes there cannot be migrations missing on
	// "table" and database being ahead on "data" (and vice-versa).
	//

	// Check for missing migrations first, as these are more important
	// to detect than the unknown migrations.
	if len(missingTable) > 0 || len(missingData) > 0 {
		return &fleet.MigrationStatus{
			StatusCode:   fleet.SomeMigrationsCompleted,
			MissingTable: missingTable,
			MissingData:  missingData,
		}
	}

	// len(unknownTable) > 0 || len(unknownData) > 0
	return &fleet.MigrationStatus{
		StatusCode:   fleet.UnknownMigrations,
		UnknownTable: unknownTable,
		UnknownData:  unknownData,
	}
}
|
|
|
|
// Migration versions that may appear as applied in deployed databases but are
// absent from the code; they are harmless and ignored by compareVersions.
var (
	knownUnknownTableMigrations = map[int64]struct{}{
		// This migration was introduced incorrectly in fleet-v4.4.0 and its
		// timestamp was changed in fleet-v4.4.1.
		20210924114500: {},
	}
	knownUnknownDataMigrations = map[int64]struct{}{
		// This migration was present in 2.0.0, and was removed on a subsequent release.
		// Was basically running `DELETE FROM packs WHERE deleted = 1`, (such `deleted`
		// column doesn't exist anymore).
		20171212182459: {},
		// Deleted in
		// https://github.com/fleetdm/fleet/commit/fd61dcab67f341c9e47fb6cb968171650c19a681
		20161223115449: {},
		20170309091824: {},
		20171027173700: {},
		20171212182458: {},
	}
)
|
|
|
|
// unknownUnknowns filters in down to the versions that do not appear in
// knownUnknowns, i.e. the truly unexpected migration versions.
func unknownUnknowns(in []int64, knownUnknowns map[int64]struct{}) []int64 {
	var unexpected []int64
	for _, version := range in {
		_, known := knownUnknowns[version]
		if known {
			continue
		}
		unexpected = append(unexpected, version)
	}
	return unexpected
}
|
|
|
|
// compareVersions returns any missing or extra elements in v2 with respect to v1
// (v1 or v2 need not be ordered). Versions listed in knownUnknowns are never
// reported as unknown. equal is true only when nothing is missing and nothing
// (beyond known unknowns) is extra.
func compareVersions(v1, v2 []int64, knownUnknowns map[int64]struct{}) (missing []int64, unknown []int64, equal bool) {
	inV1 := make(map[int64]struct{}, len(v1))
	for _, version := range v1 {
		inV1[version] = struct{}{}
	}
	inV2 := make(map[int64]struct{}, len(v2))
	for _, version := range v2 {
		inV2[version] = struct{}{}
	}

	// Elements of v1 absent from v2, preserving v1 order.
	for _, version := range v1 {
		if _, found := inV2[version]; !found {
			missing = append(missing, version)
		}
	}

	// Elements of v2 absent from v1 (and not known-unknown), preserving v2 order.
	for _, version := range v2 {
		if _, found := inV1[version]; found {
			continue
		}
		if _, ignorable := knownUnknowns[version]; !ignorable {
			unknown = append(unknown, version)
		}
	}

	if len(missing) == 0 && len(unknown) == 0 {
		return nil, nil, true
	}
	return missing, unknown, false
}
|
|
|
|
func getVersionsFromMigrations(migrations goose.Migrations) []int64 {
|
|
versions := make([]int64, len(migrations))
|
|
for i := range migrations {
|
|
versions[i] = migrations[i].Version
|
|
}
|
|
return versions
|
|
}
|
|
|
|
// HealthCheck returns an error if the MySQL backend is not healthy.
// It issues a trivial query against the primary, and — when a read replica
// is configured — against the replica as well.
func (ds *Datastore) HealthCheck() error {
	// NOTE: does not receive a context as argument here, because the HealthCheck
	// interface potentially affects more than the datastore layer, and I'm not
	// sure we can safely identify and change them all at this moment.
	if _, err := ds.primary.ExecContext(context.Background(), "select 1"); err != nil {
		return err
	}
	if ds.readReplicaConfig != nil {
		var dst int
		if err := sqlx.GetContext(context.Background(), ds.replica, &dst, "select 1"); err != nil {
			return err
		}
	}
	return nil
}
|
|
|
|
func (ds *Datastore) closeStmts() error {
|
|
ds.stmtCacheMu.Lock()
|
|
defer ds.stmtCacheMu.Unlock()
|
|
|
|
var err error
|
|
for query, stmt := range ds.stmtCache {
|
|
if errClose := stmt.Close(); errClose != nil {
|
|
err = multierror.Append(err, errClose)
|
|
}
|
|
delete(ds.stmtCache, query)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Close frees resources associated with underlying mysql connection
// (cached statements, the primary, and — if distinct — the replica),
// accumulating all errors into a multierror.
func (ds *Datastore) Close() error {
	var err error
	if errStmt := ds.closeStmts(); errStmt != nil {
		err = multierror.Append(err, errStmt)
	}
	if errWriter := ds.primary.Close(); errWriter != nil {
		err = multierror.Append(err, errWriter)
	}
	// Only close the replica when it is a separate connection; without a
	// replica config it aliases the primary, which was closed above.
	if ds.readReplicaConfig != nil {
		if errRead := ds.replica.Close(); errRead != nil {
			err = multierror.Append(err, errRead)
		}
	}
	return err
}
|
|
|
|
// sanitizeColumn is used to sanitize column names which can't be passed as placeholders when executing sql queries
|
|
func sanitizeColumn(col string) string {
|
|
col = columnCharsRegexp.ReplaceAllString(col, "")
|
|
oldParts := strings.Split(col, ".")
|
|
parts := oldParts[:0]
|
|
for _, p := range oldParts {
|
|
if len(p) != 0 {
|
|
parts = append(parts, p)
|
|
}
|
|
}
|
|
if len(parts) == 0 {
|
|
return ""
|
|
}
|
|
col = "`" + strings.Join(parts, "`.`") + "`"
|
|
return col
|
|
}
|
|
|
|
// appendListOptionsToSelect will apply the given list options to ds and
// return the new select dataset.
//
// NOTE: This is a copy of appendListOptionsToSQL that uses the goqu package.
func appendListOptionsToSelect(ds *goqu.SelectDataset, opts fleet.ListOptions) *goqu.SelectDataset {
	ds = appendOrderByToSelect(ds, opts)
	ds = appendLimitOffsetToSelect(ds, opts)
	return ds
}
|
|
|
|
func appendOrderByToSelect(ds *goqu.SelectDataset, opts fleet.ListOptions) *goqu.SelectDataset {
|
|
if opts.OrderKey != "" {
|
|
ordersKeys := strings.Split(opts.OrderKey, ",")
|
|
for _, key := range ordersKeys {
|
|
ident := goqu.I(key)
|
|
|
|
var orderedExpr exp.OrderedExpression
|
|
if opts.OrderDirection == fleet.OrderDescending {
|
|
orderedExpr = ident.Desc()
|
|
} else {
|
|
orderedExpr = ident.Asc()
|
|
}
|
|
|
|
ds = ds.OrderAppend(orderedExpr)
|
|
}
|
|
}
|
|
|
|
return ds
|
|
}
|
|
|
|
func appendLimitOffsetToSelect(ds *goqu.SelectDataset, opts fleet.ListOptions) *goqu.SelectDataset {
|
|
perPage := opts.PerPage
|
|
// If caller doesn't supply a limit apply a reasonably large default limit
|
|
// to insure that an unbounded query with many results doesn't consume too
|
|
// much memory or hang
|
|
if perPage == 0 {
|
|
perPage = defaultSelectLimit
|
|
}
|
|
ds = ds.Limit(perPage)
|
|
|
|
offset := perPage * opts.Page
|
|
if offset > 0 {
|
|
ds = ds.Offset(offset)
|
|
}
|
|
return ds
|
|
}
|
|
|
|
// appendListOptionsToSQL appends the list options SQL to the passed in SQL
// string. This appended SQL is determined by the passed in options.
//
// NOTE: this method will mutate the options argument if no explicit PerPage
// option is set (a default value will be provided) or if the cursor approach is used.
func appendListOptionsToSQL(sql string, opts *fleet.ListOptions) (string, []interface{}) {
	return appendListOptionsWithCursorToSQL(sql, nil, opts)
}
|
|
|
|
// appendListOptionsWithCursorToSQL appends the list options SQL to the passed
// in SQL string. This appended SQL is determined by the passed in options.
// This supports cursor options.
//
// NOTE: this method will mutate the options argument if no explicit PerPage option
// is set (a default value will be provided) or if the cursor approach is used.
func appendListOptionsWithCursorToSQL(sql string, params []interface{}, opts *fleet.ListOptions) (string, []interface{}) {
	orderKey := sanitizeColumn(opts.OrderKey)

	if opts.After != "" && orderKey != "" {
		// NOTE(review): substring match on "where" can false-positive on
		// queries containing the word in another position — confirm callers
		// only pass simple statements.
		afterSql := " WHERE "
		if strings.Contains(strings.ToLower(sql), "where") {
			afterSql = " AND "
		}
		if strings.HasSuffix(orderKey, "id") {
			// id-like keys are compared numerically; a non-numeric After
			// silently becomes 0 (Atoi error intentionally ignored).
			i, _ := strconv.Atoi(opts.After)
			params = append(params, i)
		} else {
			params = append(params, opts.After)
		}
		direction := ">" // ASC
		if opts.OrderDirection == fleet.OrderDescending {
			direction = "<" // DESC
		}
		sql = fmt.Sprintf("%s %s %s %s ?", sql, afterSql, orderKey, direction)

		// After existing supersedes Page, so we disable it
		opts.Page = 0
	}

	if orderKey != "" {
		direction := "ASC"
		if opts.OrderDirection == fleet.OrderDescending {
			direction = "DESC"
		}

		sql = fmt.Sprintf("%s ORDER BY %s %s", sql, orderKey, direction)
	}
	// REVIEW: If caller doesn't supply a limit apply a default limit to insure
	// that an unbounded query with many results doesn't consume too much memory
	// or hang
	if opts.PerPage == 0 {
		opts.PerPage = defaultSelectLimit
	}

	// When metadata is requested, fetch one extra row so the caller can tell
	// whether a next page exists.
	perPage := opts.PerPage
	if opts.IncludeMetadata {
		perPage++
	}
	sql = fmt.Sprintf("%s LIMIT %d", sql, perPage)

	offset := opts.PerPage * opts.Page

	if offset > 0 {
		sql = fmt.Sprintf("%s OFFSET %d", sql, offset)
	}

	return sql, params
}
|
|
|
|
// whereFilterHostsByTeams returns the appropriate condition to use in the WHERE
// clause to render only the appropriate teams.
//
// filter provides the filtering parameters that should be used. hostKey is the
// name/alias of the hosts table to use in generating the SQL.
func (ds *Datastore) whereFilterHostsByTeams(filter fleet.TeamFilter, hostKey string) string {
	if filter.User == nil {
		// This is likely unintentional, however we would like to return no
		// results rather than panicking or returning some other error. At least
		// log.
		level.Info(ds.logger).Log("err", "team filter missing user")
		return "FALSE"
	}

	// When a specific team is requested, the allow clause narrows to that
	// team; otherwise everything is allowed.
	defaultAllowClause := "TRUE"
	if filter.TeamID != nil {
		defaultAllowClause = fmt.Sprintf("%s.team_id = %d", hostKey, *filter.TeamID)
	}

	if filter.User.GlobalRole != nil {
		switch *filter.User.GlobalRole {
		case fleet.RoleAdmin, fleet.RoleMaintainer, fleet.RoleObserverPlus:
			return defaultAllowClause
		case fleet.RoleObserver:
			if filter.IncludeObserver {
				return defaultAllowClause
			}
			return "FALSE"
		default:
			// Fall through to specific teams
		}
	}

	// Collect matching teams
	var idStrs []string
	var teamIDSeen bool
	for _, team := range filter.User.Teams {
		if team.Role == fleet.RoleAdmin ||
			team.Role == fleet.RoleMaintainer ||
			team.Role == fleet.RoleObserverPlus ||
			(team.Role == fleet.RoleObserver && filter.IncludeObserver) {
			idStrs = append(idStrs, strconv.Itoa(int(team.ID)))
			if filter.TeamID != nil && *filter.TeamID == team.ID {
				teamIDSeen = true
			}
		}
	}

	if len(idStrs) == 0 {
		// User has no global role and no teams allowed by includeObserver.
		return "FALSE"
	}

	if filter.TeamID != nil {
		if teamIDSeen {
			// all good, this user has the right to see the requested team
			return defaultAllowClause
		}
		return "FALSE"
	}

	return fmt.Sprintf("%s.team_id IN (%s)", hostKey, strings.Join(idStrs, ","))
}
|
|
|
|
// whereFilterTeams returns the appropriate condition to use in the WHERE
// clause to render only the appropriate teams.
//
// filter provides the filtering parameters that should be used. teamKey is the
// name/alias of the teams table to use in generating the SQL.
func (ds *Datastore) whereFilterTeams(filter fleet.TeamFilter, teamKey string) string {
	if filter.User == nil {
		// This is likely unintentional, however we would like to return no
		// results rather than panicking or returning some other error. At least
		// log.
		level.Info(ds.logger).Log("err", "team filter missing user")
		return "FALSE"
	}

	if filter.User.GlobalRole != nil {
		switch *filter.User.GlobalRole {
		case fleet.RoleAdmin, fleet.RoleMaintainer, fleet.RoleObserverPlus:
			return "TRUE"
		case fleet.RoleObserver:
			if filter.IncludeObserver {
				return "TRUE"
			}
			return "FALSE"
		default:
			// Fall through to specific teams
		}
	}

	// Collect matching teams
	var idStrs []string
	for _, team := range filter.User.Teams {
		if team.Role == fleet.RoleAdmin ||
			team.Role == fleet.RoleMaintainer ||
			team.Role == fleet.RoleObserverPlus ||
			(team.Role == fleet.RoleObserver && filter.IncludeObserver) {
			idStrs = append(idStrs, strconv.Itoa(int(team.ID)))
		}
	}

	if len(idStrs) == 0 {
		// User has no global role and no teams allowed by includeObserver.
		return "FALSE"
	}

	return fmt.Sprintf("%s.id IN (%s)", teamKey, strings.Join(idStrs, ","))
}
|
|
|
|
// whereOmitIDs returns the appropriate condition to use in the WHERE
|
|
// clause to omit the provided IDs from the selection.
|
|
func (ds *Datastore) whereOmitIDs(colName string, omit []uint) string {
|
|
if len(omit) == 0 {
|
|
return "TRUE"
|
|
}
|
|
|
|
var idStrs []string
|
|
for _, id := range omit {
|
|
idStrs = append(idStrs, strconv.Itoa(int(id)))
|
|
}
|
|
|
|
return fmt.Sprintf("%s NOT IN (%s)", colName, strings.Join(idStrs, ","))
|
|
}
|
|
|
|
// registerTLS adds client certificate configuration to the mysql connection.
// It builds a *tls.Config from the cert/key/CA settings and registers it with
// the driver under the name conf.TLSConfig, for use via the "tls" DSN param.
func registerTLS(conf config.MysqlConfig) error {
	tlsCfg := config.TLS{
		TLSCert:       conf.TLSCert,
		TLSKey:        conf.TLSKey,
		TLSCA:         conf.TLSCA,
		TLSServerName: conf.TLSServerName,
	}
	cfg, err := tlsCfg.ToTLSConfig()
	if err != nil {
		return err
	}
	if err := mysql.RegisterTLSConfig(conf.TLSConfig, cfg); err != nil {
		return fmt.Errorf("register mysql tls config: %w", err)
	}
	return nil
}
|
|
|
|
// generateMysqlConnectionString returns a MySQL connection string using the
|
|
// provided configuration.
|
|
func generateMysqlConnectionString(conf config.MysqlConfig) string {
|
|
params := url.Values{
|
|
// using collation implicitly sets the charset too
|
|
// and it's the recommended way to do it per the
|
|
// driver documentation:
|
|
// https://github.com/go-sql-driver/mysql#charset
|
|
"collation": []string{"utf8mb4_unicode_ci"},
|
|
"parseTime": []string{"true"},
|
|
"loc": []string{"UTC"},
|
|
"time_zone": []string{"'-00:00'"},
|
|
"clientFoundRows": []string{"true"},
|
|
"allowNativePasswords": []string{"true"},
|
|
"group_concat_max_len": []string{"4194304"},
|
|
"multiStatements": []string{"true"},
|
|
}
|
|
if conf.TLSConfig != "" {
|
|
params.Set("tls", conf.TLSConfig)
|
|
}
|
|
if conf.SQLMode != "" {
|
|
params.Set("sql_mode", conf.SQLMode)
|
|
}
|
|
|
|
dsn := fmt.Sprintf(
|
|
"%s:%s@%s(%s)/%s?%s",
|
|
conf.Username,
|
|
conf.Password,
|
|
conf.Protocol,
|
|
conf.Address,
|
|
conf.Database,
|
|
params.Encode(),
|
|
)
|
|
|
|
return dsn
|
|
}
|
|
|
|
// isForeignKeyError checks if the provided error is a MySQL child foreign key
|
|
// error (Error #1452)
|
|
func isChildForeignKeyError(err error) bool {
|
|
err = ctxerr.Cause(err)
|
|
mysqlErr, ok := err.(*mysql.MySQLError)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
// https://dev.mysql.com/doc/refman/5.7/en/error-messages-server.html#error_er_no_referenced_row_2
|
|
const ER_NO_REFERENCED_ROW_2 = 1452
|
|
return mysqlErr.Number == ER_NO_REFERENCED_ROW_2
|
|
}
|
|
|
|
// patternReplacer transforms a raw match string into the pattern that is
// bound to a LIKE condition (see searchLikePattern).
type patternReplacer func(string) string
|
|
|
|
// likePattern returns a pattern to match m anywhere in a value with LIKE:
// the LIKE metacharacters "_" and "%" in m are escaped and the result is
// wrapped in "%" wildcards.
func likePattern(m string) string {
	m = strings.ReplaceAll(m, "_", "\\_")
	m = strings.ReplaceAll(m, "%", "\\%")
	return "%" + m + "%"
}
|
|
|
|
// noneReplacer returns its input unchanged; it is used when the match string
// is already a ready-to-use LIKE pattern.
func noneReplacer(pattern string) string {
	return pattern
}
|
|
|
|
// searchLike adds SQL and parameters for a "search" using LIKE syntax.
|
|
//
|
|
// The input columns must be sanitized if they are provided by the user.
|
|
func searchLike(sql string, params []interface{}, match string, columns ...string) (string, []interface{}) {
|
|
return searchLikePattern(sql, params, match, likePattern, columns...)
|
|
}
|
|
|
|
func searchLikePattern(sql string, params []interface{}, match string, replacer patternReplacer, columns ...string) (string, []interface{}) {
|
|
if len(columns) == 0 || len(match) == 0 {
|
|
return sql, params
|
|
}
|
|
|
|
pattern := replacer(match)
|
|
ors := make([]string, 0, len(columns))
|
|
for _, column := range columns {
|
|
ors = append(ors, column+" LIKE ?")
|
|
params = append(params, pattern)
|
|
}
|
|
|
|
sql += " AND (" + strings.Join(ors, " OR ") + ")"
|
|
return sql, params
|
|
}
|
|
|
|
// rxLooseEmail very loosely checks that a string looks like an email address:
// no spaces, a single @ character, a non-empty part before the @, and a part
// after the @ containing at least one dot with something after the dot. This
// is not perfectly correct, as the email format allows any chars (including
// spaces) inside double quotes, but that is an edge case unlikely to matter
// much in practice. Another option that would definitely not cut out any
// valid address is to just check for the presence of @, which is arguably the
// most important check in this.
var rxLooseEmail = regexp.MustCompile(`^[^\s@]+@[^\s@\.]+\..+$`)
|
|
|
|
/*
nonascii matches any occurrence of an ASCII character followed by one or more
non-ASCII characters. The first part `[[:ascii:]]` matches any character
within the ASCII range (0 to 127), while the second part `[^[:ascii:]]`
matches any character outside that range; combined, they match a sequence
whose first character is ASCII and whose following characters are not. The
named groups capture the ASCII prefix and the non-ASCII run respectively.

nonacsiiReplace (name typo kept: it is referenced elsewhere in this file)
matches a single non-ASCII character and is used by replaceMatchAny to
substitute such characters with a LIKE wildcard.
*/
var (
	nonascii        = regexp.MustCompile(`(?P<ascii>[[:ascii:]])(?P<nonascii>[^[:ascii:]]+)`)
	nonacsiiReplace = regexp.MustCompile(`[^[:ascii:]]`)
)
|
|
|
|
func hostSearchLike(sql string, params []interface{}, match string, columns ...string) (string, []interface{}, bool) {
|
|
var matchesEmail bool
|
|
base, args := searchLike(sql, params, match, columns...)
|
|
|
|
// special-case for hosts: if match looks like an email address, add searching
|
|
// in host_emails table as an option, in addition to the provided columns.
|
|
if rxLooseEmail.MatchString(match) {
|
|
matchesEmail = true
|
|
// remove the closing paren and add the email condition to the list
|
|
base = strings.TrimSuffix(base, ")") + " OR (" + ` EXISTS (SELECT 1 FROM host_emails he WHERE he.host_id = h.id AND he.email LIKE ?)))`
|
|
args = append(args, likePattern(match))
|
|
}
|
|
return base, args, matchesEmail
|
|
}
|
|
|
|
func hostSearchLikeAny(sql string, params []interface{}, match string, columns ...string) (string, []interface{}) {
|
|
return searchLikePattern(sql, params, buildWildcardMatchPhrase(match), noneReplacer, columns...)
|
|
}
|
|
|
|
func buildWildcardMatchPhrase(matchQuery string) string {
|
|
return replaceMatchAny(likePattern(matchQuery))
|
|
}
|
|
|
|
func hasNonASCIIRegex(s string) bool {
|
|
return nonascii.MatchString(s)
|
|
}
|
|
|
|
func replaceMatchAny(s string) string {
|
|
return nonacsiiReplace.ReplaceAllString(s, "_")
|
|
}
|
|
|
|
func (ds *Datastore) InnoDBStatus(ctx context.Context) (string, error) {
|
|
status := struct {
|
|
Type string `db:"Type"`
|
|
Name string `db:"Name"`
|
|
Status string `db:"Status"`
|
|
}{}
|
|
// using the writer even when doing a read to get the data from the main db node
|
|
err := ds.writer(ctx).GetContext(ctx, &status, "show engine innodb status")
|
|
if err != nil {
|
|
return "", ctxerr.Wrap(ctx, err, "Getting innodb status")
|
|
}
|
|
return status.Status, nil
|
|
}
|
|
|
|
func (ds *Datastore) ProcessList(ctx context.Context) ([]fleet.MySQLProcess, error) {
|
|
var processList []fleet.MySQLProcess
|
|
// using the writer even when doing a read to get the data from the main db node
|
|
err := ds.writer(ctx).SelectContext(ctx, &processList, "show processlist")
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "Getting process list")
|
|
}
|
|
return processList, nil
|
|
}
|
|
|
|
// insertOnDuplicateDidInsert reports whether an INSERT ... ON DUPLICATE KEY
// UPDATE statement created a brand new row (as opposed to touching an
// existing one).
//
// The connection string sets CLIENT_FOUND_ROWS (see
// generateMysqlConnectionString in this package), so when a row is set to its
// current values LastInsertId is 0 and RowsAffected is 1; a true insert
// yields a non-zero LastInsertId together with exactly one affected row.
//
// See [the docs][1] or the comment in insertOnDuplicateDidUpdate below for
// more details.
//
// [1]: https://dev.mysql.com/doc/refman/5.7/en/insert-on-duplicate.html
func insertOnDuplicateDidInsert(res sql.Result) bool {
	// With our driver these never return an error; the values were captured
	// at Exec time.
	newID, _ := res.LastInsertId()
	rows, _ := res.RowsAffected()
	return newID != 0 && rows == 1
}
|
|
|
|
// insertOnDuplicateDidUpdate reports whether an INSERT ... ON DUPLICATE KEY
// UPDATE statement updated (or no-op touched) an existing row rather than
// inserting a new one.
//
// From mysql's documentation:
//
// With ON DUPLICATE KEY UPDATE, the affected-rows value per row is 1 if
// the row is inserted as a new row, 2 if an existing row is updated, and
// 0 if an existing row is set to its current values. If you specify the
// CLIENT_FOUND_ROWS flag to the mysql_real_connect() C API function when
// connecting to mysqld, the affected-rows value is 1 (not 0) if an
// existing row is set to its current values.
//
// https://dev.mysql.com/doc/refman/5.7/en/insert-on-duplicate.html
//
// Note that connection string sets CLIENT_FOUND_ROWS (see
// generateMysqlConnectionString in this package), so it does return 1 when
// an existing row is set to its current values, but with a last inserted id
// of 0.
//
// Also note that with our mysql driver, Result.LastInsertId and
// Result.RowsAffected can never return an error, they are retrieved at the
// time of the Exec call, and the result simply returns the integers it
// already holds:
// https://github.com/go-sql-driver/mysql/blob/bcc459a906419e2890a50fc2c99ea6dd927a88f2/result.go
//
// TODO(mna): would that work on mariadb too?
func insertOnDuplicateDidUpdate(res sql.Result) bool {
	insertedID, _ := res.LastInsertId()
	rowCount, _ := res.RowsAffected()
	// A no-op touch has id 0; a real update has affected-rows != 1 (i.e. 2).
	return insertedID == 0 || rowCount != 1
}
|
|
|
|
// parameterizedStmt pairs a SQL statement with the arguments for its
// placeholders, so both can be passed around as a single value (see
// optimisticGetOrInsert).
type parameterizedStmt struct {
	// Statement is the SQL text, with ? placeholders.
	Statement string
	// Args holds the values bound to the placeholders, in order.
	Args []interface{}
}
|
|
|
|
// optimisticGetOrInsert encodes an efficient pattern of looking up a row's ID
// for a unique key that is more likely to already exist (i.e. the insert
// should be infrequent, the read should succeed most of the time).
// It proceeds as follows:
//  1. Try to read the ID from the read replica.
//  2. If it does not exist, try to insert the row in the primary.
//  3. If it fails due to a duplicate key, try to read the ID again, this
//     time from the primary.
//
// The read statement must only SELECT the id column.
func (ds *Datastore) optimisticGetOrInsert(ctx context.Context, readStmt, insertStmt *parameterizedStmt) (id uint, err error) {
	// readID runs the read statement against q and scans the single id column.
	readID := func(q sqlx.QueryerContext) (uint, error) {
		var id uint
		err := sqlx.GetContext(ctx, q, &id, readStmt.Statement, readStmt.Args...)
		return id, err
	}

	// 1. read from the read replica, as it is likely to already exist
	id, err = readID(ds.reader(ctx))
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// this does not exist yet, try to insert it
			res, err := ds.writer(ctx).ExecContext(ctx, insertStmt.Statement, insertStmt.Args...)
			if err != nil {
				if isDuplicate(err) {
					// it might've been created between the select and the insert, read
					// again this time from the primary database connection.
					// (err and id are deliberately shadowed here; the outer values are
					// not used again on this path.)
					id, err := readID(ds.writer(ctx))
					if err != nil {
						return 0, ctxerr.Wrap(ctx, err, "get id from writer")
					}
					return id, nil
				}
				return 0, ctxerr.Wrap(ctx, err, "insert")
			}
			// insert succeeded; with our driver LastInsertId never errors (the
			// value was captured at Exec time).
			id, _ := res.LastInsertId()
			return uint(id), nil
		}
		// any read error other than "no rows" is returned as-is (wrapped).
		return 0, ctxerr.Wrap(ctx, err, "get id from reader")
	}
	return id, nil
}
|