Add support for MySQL read replicas (#1786)

Adds configuration options to use a read-only MySQL replica,
and uses it instead of the primary for reads.
This commit is contained in:
Martin Angers 2021-09-01 15:50:52 -04:00 committed by GitHub
parent 2116217b6e
commit 3755a58070
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 630 additions and 319 deletions

View File

@ -0,0 +1 @@
* Add support for MySQL read replica.

View File

@ -149,7 +149,11 @@ the way that the Fleet server works.
var carveStore fleet.CarveStore var carveStore fleet.CarveStore
mailService := mail.NewService() mailService := mail.NewService()
ds, err = mysql.New(config.Mysql, clock.C, mysql.Logger(logger)) var replicaOpt mysql.DBOption
if config.MysqlReadReplica.Address != "" {
replicaOpt = mysql.Replica(&config.MysqlReadReplica)
}
ds, err = mysql.New(config.Mysql, clock.C, mysql.Logger(logger), replicaOpt)
if err != nil { if err != nil {
initFatal(err, "initializing datastore") initFatal(err, "initializing datastore")
} }

View File

@ -25,7 +25,7 @@ Your Fleet server's two main purposes are:
- To serve as your [osquery TLS server](https://osquery.readthedocs.io/en/stable/deployment/remote/) - To serve as your [osquery TLS server](https://osquery.readthedocs.io/en/stable/deployment/remote/)
- To serve the Fleet web UI, which allows you to manage osquery configuration, query hosts, etc. - To serve the Fleet web UI, which allows you to manage osquery configuration, query hosts, etc.
The Fleet server allows you persist configuration, manage users, etc. Thus, it needs a database. Fleet uses MySQL and requires you to supply configurations to connect to a MySQL server. Fleet also uses Redis to perform some more high-speed data access action throughout the lifecycle of the application (for example, distributed query result ingestion). Thus, Fleet also requires that you supply Redis connection configurations. The Fleet server allows you persist configuration, manage users, etc. Thus, it needs a database. Fleet uses MySQL and requires you to supply configurations to connect to a MySQL server. It is also possible to configure connection to a MySQL replica in addition to the primary, to be used for reading only. Fleet also uses Redis to perform some more high-speed data access action throughout the lifecycle of the application (for example, distributed query result ingestion). Thus, Fleet also requires that you supply Redis connection configurations.
> Fleet does not support Redis Cluster or Redis Sentinel. Fleet can scale to hundreds of thousands of devices with a single Redis instance. > Fleet does not support Redis Cluster or Redis Sentinel. Fleet can scale to hundreds of thousands of devices with a single Redis instance.
@ -118,12 +118,21 @@ mysql:
address: 127.0.0.1:3306 address: 127.0.0.1:3306
``` ```
And `mysql_read_replica_address` would be:
```
mysql_read_replica:
address: 127.0.0.1:3307
```
Basically, just capitalize the option and prepend `FLEET_` to it in order to get the environment variable. The conversion works the same the opposite way. Basically, just capitalize the option and prepend `FLEET_` to it in order to get the environment variable. The conversion works the same the opposite way.
All duration-based settings accept valid time units of `s`, `m`, `h`. All duration-based settings accept valid time units of `s`, `m`, `h`.
##### MySQL ##### MySQL
This section describes the configuration options for the primary - if you also want to setup a read replica, the options are the same, except that the yaml section is `mysql_read_replica`, and the flags have the `mysql_read_replica_` prefix instead of `mysql_` (the corresponding environment variables follow the same transformation). Note that there is no default value for `mysql_read_replica_address`, it must be set explicitly for fleet to use a read replica.
###### `mysql_address` ###### `mysql_address`
The address of the MySQL server which Fleet should connect to. Include the hostname and port. The address of the MySQL server which Fleet should connect to. Include the hostname and port.

View File

@ -177,6 +177,7 @@ type VulnerabilitiesConfig struct {
// updated to set and retrieve the configurations as appropriate. // updated to set and retrieve the configurations as appropriate.
type FleetConfig struct { type FleetConfig struct {
Mysql MysqlConfig Mysql MysqlConfig
MysqlReadReplica MysqlConfig `yaml:"mysql_read_replica"`
Redis RedisConfig Redis RedisConfig
Server ServerConfig Server ServerConfig
Auth AuthConfig Auth AuthConfig
@ -197,32 +198,36 @@ type FleetConfig struct {
// addConfigs adds the configuration keys and default values that will be // addConfigs adds the configuration keys and default values that will be
// filled into the FleetConfig struct // filled into the FleetConfig struct
func (man Manager) addConfigs() { func (man Manager) addConfigs() {
addMysqlConfig := func(prefix, defaultAddr, usageSuffix string) {
man.addConfigString(prefix+".protocol", "tcp",
"MySQL server communication protocol (tcp,unix,...)"+usageSuffix)
man.addConfigString(prefix+".address", defaultAddr,
"MySQL server address (host:port)"+usageSuffix)
man.addConfigString(prefix+".username", "fleet",
"MySQL server username"+usageSuffix)
man.addConfigString(prefix+".password", "",
"MySQL server password (prefer env variable for security)"+usageSuffix)
man.addConfigString(prefix+".password_path", "",
"Path to file containg MySQL server password"+usageSuffix)
man.addConfigString(prefix+".database", "fleet",
"MySQL database name"+usageSuffix)
man.addConfigString(prefix+".tls_cert", "",
"MySQL TLS client certificate path"+usageSuffix)
man.addConfigString(prefix+".tls_key", "",
"MySQL TLS client key path"+usageSuffix)
man.addConfigString(prefix+".tls_ca", "",
"MySQL TLS server CA"+usageSuffix)
man.addConfigString(prefix+".tls_server_name", "",
"MySQL TLS server name"+usageSuffix)
man.addConfigString(prefix+".tls_config", "",
"MySQL TLS config value"+usageSuffix+" Use skip-verify, true, false or custom key.")
man.addConfigInt(prefix+".max_open_conns", 50, "MySQL maximum open connection handles"+usageSuffix)
man.addConfigInt(prefix+".max_idle_conns", 50, "MySQL maximum idle connection handles"+usageSuffix)
man.addConfigInt(prefix+".conn_max_lifetime", 0, "MySQL maximum amount of time a connection may be reused"+usageSuffix)
}
// MySQL // MySQL
man.addConfigString("mysql.protocol", "tcp", addMysqlConfig("mysql", "localhost:3306", ".")
"MySQL server communication protocol (tcp,unix,...)") addMysqlConfig("mysql_read_replica", "", " for the read replica.")
man.addConfigString("mysql.address", "localhost:3306",
"MySQL server address (host:port)")
man.addConfigString("mysql.username", "fleet",
"MySQL server username")
man.addConfigString("mysql.password", "",
"MySQL server password (prefer env variable for security)")
man.addConfigString("mysql.password_path", "",
"Path to file containg MySQL server password")
man.addConfigString("mysql.database", "fleet",
"MySQL database name")
man.addConfigString("mysql.tls_cert", "",
"MySQL TLS client certificate path")
man.addConfigString("mysql.tls_key", "",
"MySQL TLS client key path")
man.addConfigString("mysql.tls_ca", "",
"MySQL TLS server CA")
man.addConfigString("mysql.tls_server_name", "",
"MySQL TLS server name")
man.addConfigString("mysql.tls_config", "",
"MySQL TLS config value. Use skip-verify, true, false or custom key.")
man.addConfigInt("mysql.max_open_conns", 50, "MySQL maximum open connection handles.")
man.addConfigInt("mysql.max_idle_conns", 50, "MySQL maximum idle connection handles.")
man.addConfigInt("mysql.conn_max_lifetime", 0, "MySQL maximum amount of time a connection may be reused.")
// Redis // Redis
man.addConfigString("redis.address", "localhost:6379", man.addConfigString("redis.address", "localhost:6379",
@ -382,23 +387,28 @@ func (man Manager) addConfigs() {
func (man Manager) LoadConfig() FleetConfig { func (man Manager) LoadConfig() FleetConfig {
man.loadConfigFile() man.loadConfigFile()
loadMysqlConfig := func(prefix string) MysqlConfig {
return MysqlConfig{
Protocol: man.getConfigString(prefix + ".protocol"),
Address: man.getConfigString(prefix + ".address"),
Username: man.getConfigString(prefix + ".username"),
Password: man.getConfigString(prefix + ".password"),
PasswordPath: man.getConfigString(prefix + ".password_path"),
Database: man.getConfigString(prefix + ".database"),
TLSCert: man.getConfigString(prefix + ".tls_cert"),
TLSKey: man.getConfigString(prefix + ".tls_key"),
TLSCA: man.getConfigString(prefix + ".tls_ca"),
TLSServerName: man.getConfigString(prefix + ".tls_server_name"),
TLSConfig: man.getConfigString(prefix + ".tls_config"),
MaxOpenConns: man.getConfigInt(prefix + ".max_open_conns"),
MaxIdleConns: man.getConfigInt(prefix + ".max_idle_conns"),
ConnMaxLifetime: man.getConfigInt(prefix + ".conn_max_lifetime"),
}
}
return FleetConfig{ return FleetConfig{
Mysql: MysqlConfig{ Mysql: loadMysqlConfig("mysql"),
Protocol: man.getConfigString("mysql.protocol"), MysqlReadReplica: loadMysqlConfig("mysql_read_replica"),
Address: man.getConfigString("mysql.address"),
Username: man.getConfigString("mysql.username"),
Password: man.getConfigString("mysql.password"),
PasswordPath: man.getConfigString("mysql.password_path"),
Database: man.getConfigString("mysql.database"),
TLSCert: man.getConfigString("mysql.tls_cert"),
TLSKey: man.getConfigString("mysql.tls_key"),
TLSCA: man.getConfigString("mysql.tls_ca"),
TLSServerName: man.getConfigString("mysql.tls_server_name"),
TLSConfig: man.getConfigString("mysql.tls_config"),
MaxOpenConns: man.getConfigInt("mysql.max_open_conns"),
MaxIdleConns: man.getConfigInt("mysql.max_idle_conns"),
ConnMaxLifetime: man.getConfigInt("mysql.conn_max_lifetime"),
},
Redis: RedisConfig{ Redis: RedisConfig{
Address: man.getConfigString("redis.address"), Address: man.getConfigString("redis.address"),
Password: man.getConfigString("redis.password"), Password: man.getConfigString("redis.password"),

View File

@ -58,6 +58,7 @@ func TestConfigRoundtrip(t *testing.T) {
// Marshal the generated config // Marshal the generated config
buf, err := yaml.Marshal(original) buf, err := yaml.Marshal(original)
require.Nil(t, err) require.Nil(t, err)
t.Log(string(buf))
// Manually load the serialized config // Manually load the serialized config
man.viper.SetConfigType("yaml") man.viper.SetConfigType("yaml")

View File

@ -3,6 +3,7 @@ package mysql
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/fleet"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -13,7 +14,7 @@ func (d *Datastore) NewActivity(user *fleet.User, activityType string, details *
if err != nil { if err != nil {
return errors.Wrap(err, "marshaling activity details") return errors.Wrap(err, "marshaling activity details")
} }
_, err = d.db.Exec( _, err = d.writer.Exec(
`INSERT INTO activities (user_id, user_name, activity_type, details) VALUES(?,?,?,?)`, `INSERT INTO activities (user_id, user_name, activity_type, details) VALUES(?,?,?,?)`,
user.ID, user.ID,
user.Name, user.Name,
@ -34,7 +35,7 @@ func (d *Datastore) ListActivities(opt fleet.ListOptions) ([]*fleet.Activity, er
WHERE true` WHERE true`
query = appendListOptionsToSQL(query, opt) query = appendListOptionsToSQL(query, opt)
err := d.db.Select(&activities, query) err := d.reader.Select(&activities, query)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, notFound("Activity") return nil, notFound("Activity")
} else if err != nil { } else if err != nil {

View File

@ -25,7 +25,7 @@ func (d *Datastore) NewAppConfig(info *fleet.AppConfig) (*fleet.AppConfig, error
func (d *Datastore) AppConfig() (*fleet.AppConfig, error) { func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
info := &fleet.AppConfig{} info := &fleet.AppConfig{}
var bytes []byte var bytes []byte
err := d.db.Get(&bytes, `SELECT json_value FROM app_config_json LIMIT 1`) err := d.reader.Get(&bytes, `SELECT json_value FROM app_config_json LIMIT 1`)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "selecting app config") return nil, errors.Wrap(err, "selecting app config")
} }
@ -43,7 +43,7 @@ func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
} }
func (d *Datastore) isEventSchedulerEnabled() (bool, error) { func (d *Datastore) isEventSchedulerEnabled() (bool, error) {
rows, err := d.db.Query("SELECT @@event_scheduler") rows, err := d.writer.Query("SELECT @@event_scheduler")
if err != nil { if err != nil {
return false, err return false, err
} }
@ -136,7 +136,7 @@ func (d *Datastore) SaveAppConfig(info *fleet.AppConfig) error {
func (d *Datastore) VerifyEnrollSecret(secret string) (*fleet.EnrollSecret, error) { func (d *Datastore) VerifyEnrollSecret(secret string) (*fleet.EnrollSecret, error) {
var s fleet.EnrollSecret var s fleet.EnrollSecret
err := d.db.Get(&s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret) err := d.reader.Get(&s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
if err != nil { if err != nil {
return nil, errors.New("no matching secret found") return nil, errors.New("no matching secret found")
} }
@ -184,7 +184,7 @@ func (d *Datastore) GetEnrollSecrets(teamID *uint) ([]*fleet.EnrollSecret, error
args = append(args, teamID) args = append(args, teamID)
} }
var secrets []*fleet.EnrollSecret var secrets []*fleet.EnrollSecret
if err := d.db.Select(&secrets, sql, args...); err != nil { if err := d.reader.Select(&secrets, sql, args...); err != nil {
return nil, errors.Wrap(err, "get secrets") return nil, errors.Wrap(err, "get secrets")
} }
return secrets, nil return secrets, nil

View File

@ -259,7 +259,7 @@ func TestAppConfigDefaults(t *testing.T) {
defer ds.Close() defer ds.Close()
insertAppConfigQuery := `INSERT INTO app_config_json(json_value) VALUES(?) ON DUPLICATE KEY UPDATE json_value = VALUES(json_value)` insertAppConfigQuery := `INSERT INTO app_config_json(json_value) VALUES(?) ON DUPLICATE KEY UPDATE json_value = VALUES(json_value)`
_, err := ds.db.Exec(insertAppConfigQuery, `{}`) _, err := ds.writer.Exec(insertAppConfigQuery, `{}`)
require.NoError(t, err) require.NoError(t, err)
ac, err := ds.AppConfig() ac, err := ds.AppConfig()
@ -269,7 +269,7 @@ func TestAppConfigDefaults(t *testing.T) {
require.True(t, ac.HostSettings.EnableHostUsers) require.True(t, ac.HostSettings.EnableHostUsers)
require.False(t, ac.HostSettings.EnableSoftwareInventory) require.False(t, ac.HostSettings.EnableSoftwareInventory)
_, err = ds.db.Exec( _, err = ds.writer.Exec(
insertAppConfigQuery, insertAppConfigQuery,
`{"webhook_settings": {"interval": "12h"}, "host_settings": {"enable_host_users": false}}`, `{"webhook_settings": {"interval": "12h"}, "host_settings": {"enable_host_users": false}}`,
) )

View File

@ -17,7 +17,7 @@ func (d *Datastore) NewDistributedQueryCampaign(camp *fleet.DistributedQueryCamp
) )
VALUES(?,?,?) VALUES(?,?,?)
` `
result, err := d.db.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID) result, err := d.writer.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "inserting distributed query campaign") return nil, errors.Wrap(err, "inserting distributed query campaign")
} }
@ -32,7 +32,7 @@ func (d *Datastore) DistributedQueryCampaign(id uint) (*fleet.DistributedQueryCa
SELECT * FROM distributed_query_campaigns WHERE id = ? SELECT * FROM distributed_query_campaigns WHERE id = ?
` `
campaign := &fleet.DistributedQueryCampaign{} campaign := &fleet.DistributedQueryCampaign{}
if err := d.db.Get(campaign, sql, id); err != nil { if err := d.reader.Get(campaign, sql, id); err != nil {
return nil, errors.Wrap(err, "selecting distributed query campaign") return nil, errors.Wrap(err, "selecting distributed query campaign")
} }
@ -47,7 +47,7 @@ func (d *Datastore) SaveDistributedQueryCampaign(camp *fleet.DistributedQueryCam
user_id = ? user_id = ?
WHERE id = ? WHERE id = ?
` `
result, err := d.db.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID, camp.ID) result, err := d.writer.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID, camp.ID)
if err != nil { if err != nil {
return errors.Wrap(err, "updating distributed query campaign") return errors.Wrap(err, "updating distributed query campaign")
} }
@ -68,7 +68,7 @@ func (d *Datastore) DistributedQueryCampaignTargetIDs(id uint) (*fleet.HostTarge
` `
targets := []fleet.DistributedQueryCampaignTarget{} targets := []fleet.DistributedQueryCampaignTarget{}
if err := d.db.Select(&targets, sqlStatement, id); err != nil { if err := d.reader.Select(&targets, sqlStatement, id); err != nil {
return nil, errors.Wrap(err, "select distributed campaign target") return nil, errors.Wrap(err, "select distributed campaign target")
} }
@ -100,7 +100,7 @@ func (d *Datastore) NewDistributedQueryCampaignTarget(target *fleet.DistributedQ
) )
VALUES (?,?,?) VALUES (?,?,?)
` `
result, err := d.db.Exec(sqlStatement, target.Type, target.DistributedQueryCampaignID, target.TargetID) result, err := d.writer.Exec(sqlStatement, target.Type, target.DistributedQueryCampaignID, target.TargetID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "insert distributed campaign target") return nil, errors.Wrap(err, "insert distributed campaign target")
} }
@ -118,7 +118,7 @@ func (d *Datastore) CleanupDistributedQueryCampaigns(now time.Time) (expired uin
WHERE (status = ? AND created_at < ?) WHERE (status = ? AND created_at < ?)
OR (status = ? AND created_at < ?) OR (status = ? AND created_at < ?)
` `
result, err := d.db.Exec(sqlStatement, fleet.QueryComplete, result, err := d.writer.Exec(sqlStatement, fleet.QueryComplete,
fleet.QueryWaiting, now.Add(-1*time.Minute), fleet.QueryWaiting, now.Add(-1*time.Minute),
fleet.QueryRunning, now.Add(-24*time.Hour)) fleet.QueryRunning, now.Add(-24*time.Hour))
if err != nil { if err != nil {

View File

@ -33,7 +33,7 @@ func (d *Datastore) NewCarve(metadata *fleet.CarveMetadata) (*fleet.CarveMetadat
? ?
)` )`
result, err := d.db.Exec( result, err := d.writer.Exec(
stmt, stmt,
metadata.HostId, metadata.HostId,
metadata.CreatedAt.Format(mySQLTimestampFormat), metadata.CreatedAt.Format(mySQLTimestampFormat),
@ -64,7 +64,7 @@ func (d *Datastore) UpdateCarve(metadata *fleet.CarveMetadata) error {
expired = ? expired = ?
WHERE id = ? WHERE id = ?
` `
_, err := d.db.Exec( _, err := d.writer.Exec(
stmt, stmt,
metadata.MaxBlock, metadata.MaxBlock,
metadata.Expired, metadata.Expired,
@ -163,7 +163,7 @@ func (d *Datastore) Carve(carveId int64) (*fleet.CarveMetadata, error) {
) )
var metadata fleet.CarveMetadata var metadata fleet.CarveMetadata
if err := d.db.Get(&metadata, stmt, carveId); err != nil { if err := d.reader.Get(&metadata, stmt, carveId); err != nil {
return nil, errors.Wrap(err, "get carve by ID") return nil, errors.Wrap(err, "get carve by ID")
} }
@ -179,7 +179,7 @@ func (d *Datastore) CarveBySessionId(sessionId string) (*fleet.CarveMetadata, er
) )
var metadata fleet.CarveMetadata var metadata fleet.CarveMetadata
if err := d.db.Get(&metadata, stmt, sessionId); err != nil { if err := d.reader.Get(&metadata, stmt, sessionId); err != nil {
return nil, errors.Wrap(err, "get carve by session ID") return nil, errors.Wrap(err, "get carve by session ID")
} }
@ -195,7 +195,7 @@ func (d *Datastore) CarveByName(name string) (*fleet.CarveMetadata, error) {
) )
var metadata fleet.CarveMetadata var metadata fleet.CarveMetadata
if err := d.db.Get(&metadata, stmt, name); err != nil { if err := d.reader.Get(&metadata, stmt, name); err != nil {
return nil, errors.Wrap(err, "get carve by name") return nil, errors.Wrap(err, "get carve by name")
} }
@ -213,7 +213,7 @@ func (d *Datastore) ListCarves(opt fleet.CarveListOptions) ([]*fleet.CarveMetada
} }
stmt = appendListOptionsToSQL(stmt, opt.ListOptions) stmt = appendListOptionsToSQL(stmt, opt.ListOptions)
carves := []*fleet.CarveMetadata{} carves := []*fleet.CarveMetadata{}
if err := d.db.Select(&carves, stmt); err != nil && err != sql.ErrNoRows { if err := d.reader.Select(&carves, stmt); err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "list carves") return nil, errors.Wrap(err, "list carves")
} }
@ -231,13 +231,14 @@ func (d *Datastore) NewBlock(metadata *fleet.CarveMetadata, blockId int64, data
?, ?,
? ?
)` )`
if _, err := d.db.Exec(stmt, metadata.ID, blockId, data); err != nil { if _, err := d.writer.Exec(stmt, metadata.ID, blockId, data); err != nil {
return errors.Wrap(err, "insert carve block") return errors.Wrap(err, "insert carve block")
} }
if metadata.MaxBlock < blockId { if metadata.MaxBlock < blockId {
// Update max_block // Update max_block
metadata.MaxBlock = blockId metadata.MaxBlock = blockId
// TODO: looks like this should all be done in an atomic transaction?
if err := d.UpdateCarve(metadata); err != nil { if err := d.UpdateCarve(metadata); err != nil {
return errors.Wrap(err, "insert carve block") return errors.Wrap(err, "insert carve block")
} }
@ -253,7 +254,7 @@ func (d *Datastore) GetBlock(metadata *fleet.CarveMetadata, blockId int64) ([]by
WHERE metadata_id = ? AND block_id = ? WHERE metadata_id = ? AND block_id = ?
` `
var data []byte var data []byte
if err := d.db.Get(&data, stmt, metadata.ID, blockId); err != nil { if err := d.reader.Get(&data, stmt, metadata.ID, blockId); err != nil {
return nil, errors.Wrap(err, "select data") return nil, errors.Wrap(err, "select data")
} }

View File

@ -1,6 +1,9 @@
package mysql package mysql
import "github.com/go-kit/kit/log" import (
"github.com/fleetdm/fleet/v4/server/config"
"github.com/go-kit/kit/log"
)
const defaultMaxAttempts int = 15 const defaultMaxAttempts int = 15
@ -11,6 +14,7 @@ type dbOptions struct {
// maxAttempts configures the number of retries to connect to the DB // maxAttempts configures the number of retries to connect to the DB
maxAttempts int maxAttempts int
logger log.Logger logger log.Logger
replicaConfig *config.MysqlConfig
} }
// Logger adds a logger to the datastore // Logger adds a logger to the datastore
@ -21,6 +25,14 @@ func Logger(l log.Logger) DBOption {
} }
} }
// Replica sets the configuration of the read replica for the datastore.
func Replica(conf *config.MysqlConfig) DBOption {
return func(o *dbOptions) error {
o.replicaConfig = conf
return nil
}
}
// LimitAttempts sets a the number of attempts // LimitAttempts sets a the number of attempts
// to try establishing a connection to the database backend // to try establishing a connection to the database backend
// the default value is 15 attempts // the default value is 15 attempts

View File

@ -11,7 +11,7 @@ import (
// returning a notFound error if appropriate. // returning a notFound error if appropriate.
func (d *Datastore) deleteEntity(dbTable string, id uint) error { func (d *Datastore) deleteEntity(dbTable string, id uint) error {
deleteStmt := fmt.Sprintf(`DELETE FROM %s WHERE id = ?`, dbTable) deleteStmt := fmt.Sprintf(`DELETE FROM %s WHERE id = ?`, dbTable)
result, err := d.db.Exec(deleteStmt, id) result, err := d.writer.Exec(deleteStmt, id)
if err != nil { if err != nil {
return errors.Wrapf(err, "delete %s", dbTable) return errors.Wrapf(err, "delete %s", dbTable)
} }
@ -26,7 +26,7 @@ func (d *Datastore) deleteEntity(dbTable string, id uint) error {
// table, returning a notFound error if appropriate. // table, returning a notFound error if appropriate.
func (d *Datastore) deleteEntityByName(dbTable string, name string) error { func (d *Datastore) deleteEntityByName(dbTable string, name string) error {
deleteStmt := fmt.Sprintf("DELETE FROM %s WHERE name = ?", dbTable) deleteStmt := fmt.Sprintf("DELETE FROM %s WHERE name = ?", dbTable)
result, err := d.db.Exec(deleteStmt, name) result, err := d.writer.Exec(deleteStmt, name)
if err != nil { if err != nil {
if isMySQLForeignKey(err) { if isMySQLForeignKey(err) {
return foreignKey(dbTable, name) return foreignKey(dbTable, name)
@ -50,7 +50,7 @@ func (d *Datastore) deleteEntities(dbTable string, ids []uint) (uint, error) {
return 0, errors.Wrapf(err, "building delete entities query %s", dbTable) return 0, errors.Wrapf(err, "building delete entities query %s", dbTable)
} }
result, err := d.db.Exec(query, args...) result, err := d.writer.Exec(query, args...)
if err != nil { if err != nil {
return 0, errors.Wrapf(err, "executing delete entities query %s", dbTable) return 0, errors.Wrapf(err, "executing delete entities query %s", dbTable)
} }

View File

@ -15,7 +15,7 @@ func (ds *Datastore) PendingEmailChange(uid uint, newEmail, token string) error
new_email new_email
) VALUES( ?, ?, ? ) ) VALUES( ?, ?, ? )
` `
_, err := ds.db.Exec(sqlStatement, uid, token, newEmail) _, err := ds.writer.Exec(sqlStatement, uid, token, newEmail)
if err != nil { if err != nil {
return errors.Wrap(err, "inserting email change record") return errors.Wrap(err, "inserting email change record")
} }
@ -32,7 +32,7 @@ func (ds *Datastore) ConfirmPendingEmailChange(id uint, token string) (newEmail
Token string Token string
NewEmail string `db:"new_email"` NewEmail string `db:"new_email"`
}{} }{}
err = ds.db.Get(&changeRecord, "SELECT * FROM email_changes WHERE token = ? AND user_id = ?", token, id) err = ds.writer.Get(&changeRecord, "SELECT * FROM email_changes WHERE token = ? AND user_id = ?", token, id)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", notFound("email change with token") return "", notFound("email change with token")

View File

@ -33,7 +33,7 @@ func (d *Datastore) NewHost(host *fleet.Host) (*fleet.Host, error) {
) )
VALUES( ?,?,?,?,?,?,?,?,?,?,?,?,? ) VALUES( ?,?,?,?,?,?,?,?,?,?,?,?,? )
` `
result, err := d.db.Exec( result, err := d.writer.Exec(
sqlStatement, sqlStatement,
host.OsqueryHostID, host.OsqueryHostID,
host.DetailUpdatedAt, host.DetailUpdatedAt,
@ -95,7 +95,7 @@ func (d *Datastore) SaveHost(host *fleet.Host) error {
percent_disk_space_available = ? percent_disk_space_available = ?
WHERE id = ? WHERE id = ?
` `
_, err := d.db.Exec(sqlStatement, _, err := d.writer.Exec(sqlStatement,
host.DetailUpdatedAt, host.DetailUpdatedAt,
host.LabelUpdatedAt, host.LabelUpdatedAt,
host.NodeKey, host.NodeKey,
@ -248,7 +248,7 @@ FROM scheduled_query_stats sqs
WHERE host_id = ? AND p.pack_type IS NULL WHERE host_id = ? AND p.pack_type IS NULL
` `
var stats []fleet.ScheduledQueryStats var stats []fleet.ScheduledQueryStats
if err := d.db.Select(&stats, sql, host.ID); err != nil { if err := d.reader.Select(&stats, sql, host.ID); err != nil {
return errors.Wrap(err, "load pack stats") return errors.Wrap(err, "load pack stats")
} }
@ -268,9 +268,9 @@ WHERE host_id = ? AND p.pack_type IS NULL
return nil return nil
} }
func (d *Datastore) loadHostUsers(host *fleet.Host) error { func (d *Datastore) loadHostUsers(db dbReader, host *fleet.Host) error {
sql := `SELECT id, username, groupname, uid, user_type FROM host_users WHERE host_id = ? and removed_at IS NULL` sql := `SELECT id, username, groupname, uid, user_type FROM host_users WHERE host_id = ? and removed_at IS NULL`
if err := d.db.Select(&host.Users, sql, host.ID); err != nil { if err := db.Select(&host.Users, sql, host.ID); err != nil {
return errors.Wrap(err, "load pack stats") return errors.Wrap(err, "load pack stats")
} }
return nil return nil
@ -292,14 +292,14 @@ func (d *Datastore) Host(id uint) (*fleet.Host, error) {
LIMIT 1 LIMIT 1
` `
host := &fleet.Host{} host := &fleet.Host{}
err := d.db.Get(host, sqlStatement, id) err := d.reader.Get(host, sqlStatement, id)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "get host by id") return nil, errors.Wrap(err, "get host by id")
} }
if err := d.loadHostPackStats(host); err != nil { if err := d.loadHostPackStats(host); err != nil {
return nil, err return nil, err
} }
if err := d.loadHostUsers(host); err != nil { if err := d.loadHostUsers(d.reader, host); err != nil {
return nil, err return nil, err
} }
@ -308,7 +308,9 @@ func (d *Datastore) Host(id uint) (*fleet.Host, error) {
func (d *Datastore) amountEnrolledHosts() (int, error) { func (d *Datastore) amountEnrolledHosts() (int, error) {
var amount int var amount int
err := d.db.Get(&amount, `SELECT count(*) FROM hosts`) // uses the writer as it is called from ShouldSendStatistics, which is a
// write method.
err := d.writer.Get(&amount, `SELECT count(*) FROM hosts`)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -363,7 +365,7 @@ func (d *Datastore) ListHosts(filter fleet.TeamFilter, opt fleet.HostListOptions
sql = appendListOptionsToSQL(sql, opt.ListOptions) sql = appendListOptionsToSQL(sql, opt.ListOptions)
hosts := []*fleet.Host{} hosts := []*fleet.Host{}
if err := d.db.Select(&hosts, sql, params...); err != nil { if err := d.reader.Select(&hosts, sql, params...); err != nil {
return nil, errors.Wrap(err, "list hosts") return nil, errors.Wrap(err, "list hosts")
} }
@ -413,7 +415,7 @@ func (d *Datastore) CleanupIncomingHosts(now time.Time) error {
WHERE hostname = '' AND osquery_version = '' WHERE hostname = '' AND osquery_version = ''
AND created_at < (? - INTERVAL 5 MINUTE) AND created_at < (? - INTERVAL 5 MINUTE)
` `
if _, err := d.db.Exec(sqlStatement, now); err != nil { if _, err := d.writer.Exec(sqlStatement, now); err != nil {
return errors.Wrap(err, "cleanup incoming hosts") return errors.Wrap(err, "cleanup incoming hosts")
} }
@ -442,7 +444,7 @@ func (d *Datastore) GenerateHostStatusStatistics(filter fleet.TeamFilter, now ti
Online uint `db:"online"` Online uint `db:"online"`
New uint `db:"new"` New uint `db:"new"`
}{} }{}
err := d.db.Get(&counts, sqlStatement, now, now, now, now, now) err := d.reader.Get(&counts, sqlStatement, now, now, now, now, now)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
e = errors.Wrap(err, "generating host statistics") e = errors.Wrap(err, "generating host statistics")
return return
@ -582,7 +584,7 @@ func (d *Datastore) AuthenticateHost(nodeKey string) (*fleet.Host, error) {
` `
host := &fleet.Host{} host := &fleet.Host{}
if err := d.db.Get(host, sqlStatement, nodeKey); err != nil { if err := d.reader.Get(host, sqlStatement, nodeKey); err != nil {
switch err { switch err {
case sql.ErrNoRows: case sql.ErrNoRows:
return nil, notFound("Host") return nil, notFound("Host")
@ -601,7 +603,7 @@ func (d *Datastore) MarkHostSeen(host *fleet.Host, t time.Time) error {
WHERE node_key=? WHERE node_key=?
` `
_, err := d.db.Exec(sqlStatement, t, host.NodeKey) _, err := d.writer.Exec(sqlStatement, t, host.NodeKey)
if err != nil { if err != nil {
return errors.Wrap(err, "marking host seen") return errors.Wrap(err, "marking host seen")
} }
@ -659,11 +661,11 @@ func (d *Datastore) searchHostsWithOmits(filter fleet.TeamFilter, query string,
if err != nil { if err != nil {
return nil, errors.Wrap(err, "searching hosts") return nil, errors.Wrap(err, "searching hosts")
} }
sql = d.db.Rebind(sql) sql = d.reader.Rebind(sql)
hosts := []*fleet.Host{} hosts := []*fleet.Host{}
err = d.db.Select(&hosts, sql, args...) err = d.reader.Select(&hosts, sql, args...)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "searching hosts rebound") return nil, errors.Wrap(err, "searching hosts rebound")
} }
@ -695,8 +697,8 @@ func (d *Datastore) searchHostsDefault(filter fleet.TeamFilter, omit ...uint) ([
if err != nil { if err != nil {
return nil, errors.Wrap(err, "searching default hosts") return nil, errors.Wrap(err, "searching default hosts")
} }
sql = d.db.Rebind(sql) sql = d.reader.Rebind(sql)
err = d.db.Select(&hosts, sql, args...) err = d.reader.Select(&hosts, sql, args...)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "searching default hosts rebound") return nil, errors.Wrap(err, "searching default hosts rebound")
} }
@ -730,7 +732,7 @@ func (d *Datastore) SearchHosts(filter fleet.TeamFilter, query string, omit ...u
) )
hosts := []*fleet.Host{} hosts := []*fleet.Host{}
if err := d.db.Select(&hosts, sql, hostQuery, ipQuery); err != nil { if err := d.reader.Select(&hosts, sql, hostQuery, ipQuery); err != nil {
return nil, errors.Wrap(err, "searching hosts") return nil, errors.Wrap(err, "searching hosts")
} }
@ -755,7 +757,7 @@ func (d *Datastore) HostIDsByName(filter fleet.TeamFilter, hostnames []string) (
} }
var hostIDs []uint var hostIDs []uint
if err := d.db.Select(&hostIDs, sql, args...); err != nil { if err := d.reader.Select(&hostIDs, sql, args...); err != nil {
return nil, errors.Wrap(err, "get host IDs") return nil, errors.Wrap(err, "get host IDs")
} }
@ -770,7 +772,7 @@ func (d *Datastore) HostByIdentifier(identifier string) (*fleet.Host, error) {
LIMIT 1 LIMIT 1
` `
host := &fleet.Host{} host := &fleet.Host{}
err := d.db.Get(host, sql, identifier) err := d.reader.Get(host, sql, identifier)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "get host by identifier") return nil, errors.Wrap(err, "get host by identifier")
} }
@ -796,7 +798,7 @@ func (d *Datastore) AddHostsToTeam(teamID *uint, hostIDs []uint) error {
return errors.Wrap(err, "sqlx.In AddHostsToTeam") return errors.Wrap(err, "sqlx.In AddHostsToTeam")
} }
if _, err := d.db.Exec(sql, args...); err != nil { if _, err := d.writer.Exec(sql, args...); err != nil {
return errors.Wrap(err, "exec AddHostsToTeam") return errors.Wrap(err, "exec AddHostsToTeam")
} }
@ -809,7 +811,7 @@ func (d *Datastore) SaveHostAdditional(host *fleet.Host) error {
VALUES (?, ?) VALUES (?, ?)
ON DUPLICATE KEY UPDATE additional = VALUES(additional) ON DUPLICATE KEY UPDATE additional = VALUES(additional)
` `
if _, err := d.db.Exec(sql, host.ID, host.Additional); err != nil { if _, err := d.writer.Exec(sql, host.ID, host.Additional); err != nil {
return errors.Wrap(err, "insert additional") return errors.Wrap(err, "insert additional")
} }
@ -818,7 +820,7 @@ func (d *Datastore) SaveHostAdditional(host *fleet.Host) error {
func (d *Datastore) SaveHostUsers(host *fleet.Host) error { func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
if len(host.Users) == 0 { if len(host.Users) == 0 {
if _, err := d.db.Exec( if _, err := d.writer.Exec(
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE host_id = ?`, `UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE host_id = ?`,
host.ID, host.ID,
); err != nil { ); err != nil {
@ -829,7 +831,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
} }
currentHost := &fleet.Host{ID: host.ID} currentHost := &fleet.Host{ID: host.ID}
if err := d.loadHostUsers(currentHost); err != nil { if err := d.loadHostUsers(d.writer, currentHost); err != nil {
return err return err
} }
@ -837,7 +839,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
for _, u := range host.Users { for _, u := range host.Users {
incomingUsers[u.Uid] = true incomingUsers[u.Uid] = true
if _, err := d.db.Exec( if _, err := d.writer.Exec(
`INSERT IGNORE INTO host_users (host_id, uid, username, user_type, groupname) VALUES (?, ?, ?, ?, ?)`, `INSERT IGNORE INTO host_users (host_id, uid, username, user_type, groupname) VALUES (?, ?, ?, ?, ?)`,
host.ID, u.Uid, u.Username, u.Type, u.GroupName, host.ID, u.Uid, u.Username, u.Type, u.GroupName,
); err != nil { ); err != nil {
@ -860,7 +862,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE id IN (%s)`, `UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE id IN (%s)`,
removedValues, removedValues,
) )
if _, err := d.db.Exec(removedSql, removedArgs...); err != nil { if _, err := d.writer.Exec(removedSql, removedArgs...); err != nil {
return errors.Wrap(err, "mark users as removed") return errors.Wrap(err, "mark users as removed")
} }
@ -869,15 +871,15 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
func (d *Datastore) TotalAndUnseenHostsSince(daysCount int) (int, int, error) { func (d *Datastore) TotalAndUnseenHostsSince(daysCount int) (int, int, error) {
var totalCount, unseenCount int var totalCount, unseenCount int
err := d.db.QueryRow("SELECT count(*) FROM hosts").Scan(&totalCount) err := d.reader.Get(&totalCount, "SELECT count(*) FROM hosts")
if err != nil { if err != nil {
return 0, 0, errors.Wrap(err, "getting total host count") return 0, 0, errors.Wrap(err, "getting total host count")
} }
err = d.db.QueryRow( err = d.reader.Get(&unseenCount,
"SELECT count(*) FROM hosts WHERE DATEDIFF(CURRENT_DATE, seen_time) >= ?", "SELECT count(*) FROM hosts WHERE DATEDIFF(CURRENT_DATE, seen_time) >= ?",
daysCount, daysCount,
).Scan(&unseenCount) )
if err != nil { if err != nil {
return 0, 0, errors.Wrap(err, "getting unseen host count") return 0, 0, errors.Wrap(err, "getting unseen host count")
} }

View File

@ -23,7 +23,7 @@ func (d *Datastore) NewInvite(i *fleet.Invite) (*fleet.Invite, error) {
VALUES ( ?, ?, ?, ?, ?, ?, ?) VALUES ( ?, ?, ?, ?, ?, ?, ?)
` `
result, err := d.db.Exec(sqlStmt, i.InvitedBy, i.Email, result, err := d.writer.Exec(sqlStmt, i.InvitedBy, i.Email,
i.Name, i.Position, i.Token, i.SSOEnabled, i.GlobalRole) i.Name, i.Position, i.Token, i.SSOEnabled, i.GlobalRole)
if err != nil && isDuplicate(err) { if err != nil && isDuplicate(err) {
return nil, alreadyExists("Invite", i.Email) return nil, alreadyExists("Invite", i.Email)
@ -45,10 +45,11 @@ func (d *Datastore) NewInvite(i *fleet.Invite) (*fleet.Invite, error) {
for _, userTeam := range i.Teams { for _, userTeam := range i.Teams {
args = append(args, i.ID, userTeam.Team.ID, userTeam.Role) args = append(args, i.ID, userTeam.Team.ID, userTeam.Role)
} }
// TODO: seems like this should be in a transaction?
sql := "INSERT INTO invite_teams (invite_id, team_id, role) VALUES " + sql := "INSERT INTO invite_teams (invite_id, team_id, role) VALUES " +
strings.Repeat(valueStr, len(i.Teams)) strings.Repeat(valueStr, len(i.Teams))
sql = strings.TrimSuffix(sql, ",") sql = strings.TrimSuffix(sql, ",")
if _, err := d.db.Exec(sql, args...); err != nil { if _, err := d.writer.Exec(sql, args...); err != nil {
return nil, errors.Wrap(err, "insert teams") return nil, errors.Wrap(err, "insert teams")
} }
@ -63,7 +64,7 @@ func (d *Datastore) ListInvites(opt fleet.ListOptions) ([]*fleet.Invite, error)
query, params := searchLike(query, nil, opt.MatchQuery, inviteSearchColumns...) query, params := searchLike(query, nil, opt.MatchQuery, inviteSearchColumns...)
query = appendListOptionsToSQL(query, opt) query = appendListOptionsToSQL(query, opt)
err := d.db.Select(&invites, query, params...) err := d.reader.Select(&invites, query, params...)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, notFound("Invite") return nil, notFound("Invite")
} else if err != nil { } else if err != nil {
@ -80,7 +81,7 @@ func (d *Datastore) ListInvites(opt fleet.ListOptions) ([]*fleet.Invite, error)
// Invite returns Invite identified by id. // Invite returns Invite identified by id.
func (d *Datastore) Invite(id uint) (*fleet.Invite, error) { func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
var invite fleet.Invite var invite fleet.Invite
err := d.db.Get(&invite, "SELECT * FROM invites WHERE id = ?", id) err := d.reader.Get(&invite, "SELECT * FROM invites WHERE id = ?", id)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, notFound("Invite").WithID(id) return nil, notFound("Invite").WithID(id)
} else if err != nil { } else if err != nil {
@ -97,7 +98,7 @@ func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
// InviteByEmail finds an Invite with a particular email, if one exists. // InviteByEmail finds an Invite with a particular email, if one exists.
func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) { func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
var invite fleet.Invite var invite fleet.Invite
err := d.db.Get(&invite, "SELECT * FROM invites WHERE email = ?", email) err := d.reader.Get(&invite, "SELECT * FROM invites WHERE email = ?", email)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, notFound("Invite"). return nil, notFound("Invite").
WithMessage(fmt.Sprintf("with email %s", email)) WithMessage(fmt.Sprintf("with email %s", email))
@ -115,7 +116,7 @@ func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
// InviteByToken finds an Invite with a particular token, if one exists. // InviteByToken finds an Invite with a particular token, if one exists.
func (d *Datastore) InviteByToken(token string) (*fleet.Invite, error) { func (d *Datastore) InviteByToken(token string) (*fleet.Invite, error) {
var invite fleet.Invite var invite fleet.Invite
err := d.db.Get(&invite, "SELECT * FROM invites WHERE token = ?", token) err := d.reader.Get(&invite, "SELECT * FROM invites WHERE token = ?", token)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, notFound("Invite"). return nil, notFound("Invite").
WithMessage(fmt.Sprintf("with token %s", token)) WithMessage(fmt.Sprintf("with token %s", token))
@ -163,7 +164,7 @@ func (d *Datastore) loadTeamsForInvites(invites []*fleet.Invite) error {
fleet.UserTeam fleet.UserTeam
InviteID uint `db:"invite_id"` InviteID uint `db:"invite_id"`
} }
if err := d.db.Select(&rows, sql, args...); err != nil { if err := d.reader.Select(&rows, sql, args...); err != nil {
return errors.Wrap(err, "get loadTeamsForInvites") return errors.Wrap(err, "get loadTeamsForInvites")
} }

View File

@ -115,7 +115,7 @@ func (d *Datastore) GetLabelSpecs() ([]*fleet.LabelSpec, error) {
var specs []*fleet.LabelSpec var specs []*fleet.LabelSpec
// Get basic specs // Get basic specs
query := "SELECT id, name, description, query, platform, label_type, label_membership_type FROM labels" query := "SELECT id, name, description, query, platform, label_type, label_membership_type FROM labels"
if err := d.db.Select(&specs, query); err != nil { if err := d.reader.Select(&specs, query); err != nil {
return nil, errors.Wrap(err, "get labels") return nil, errors.Wrap(err, "get labels")
} }
@ -138,7 +138,7 @@ SELECT name, description, query, platform, label_type, label_membership_type
FROM labels FROM labels
WHERE name = ? WHERE name = ?
` `
if err := d.db.Select(&specs, query, name); err != nil { if err := d.reader.Select(&specs, query, name); err != nil {
return nil, errors.Wrap(err, "get label") return nil, errors.Wrap(err, "get label")
} }
if len(specs) == 0 { if len(specs) == 0 {
@ -171,7 +171,7 @@ func (d *Datastore) getLabelHostnames(label *fleet.LabelSpec) error {
WHERE label_id = (SELECT id FROM labels WHERE name = ?) WHERE label_id = (SELECT id FROM labels WHERE name = ?)
) )
` `
err := d.db.Select(&label.Hosts, sql, label.Name) err := d.reader.Select(&label.Hosts, sql, label.Name)
if err != nil { if err != nil {
return errors.Wrap(err, "get hostnames for label") return errors.Wrap(err, "get hostnames for label")
} }
@ -190,7 +190,7 @@ func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fl
label_membership_type label_membership_type
) VALUES ( ?, ?, ?, ?, ?, ?) ) VALUES ( ?, ?, ?, ?, ?, ?)
` `
result, err := d.db.Exec( result, err := d.writer.Exec(
query, query,
label.Name, label.Name,
label.Description, label.Description,
@ -216,7 +216,7 @@ func (d *Datastore) SaveLabel(label *fleet.Label) (*fleet.Label, error) {
description = ? description = ?
WHERE id = ? WHERE id = ?
` `
_, err := d.db.Exec(query, label.Name, label.Description, label.ID) _, err := d.writer.Exec(query, label.Name, label.Description, label.ID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "saving label") return nil, errors.Wrap(err, "saving label")
} }
@ -236,7 +236,7 @@ func (d *Datastore) Label(lid uint) (*fleet.Label, error) {
` `
label := &fleet.Label{} label := &fleet.Label{}
if err := d.db.Get(label, sql, lid); err != nil { if err := d.reader.Get(label, sql, lid); err != nil {
return nil, errors.Wrap(err, "selecting label") return nil, errors.Wrap(err, "selecting label")
} }
@ -255,7 +255,7 @@ func (d *Datastore) ListLabels(filter fleet.TeamFilter, opt fleet.ListOptions) (
query = appendListOptionsToSQL(query, opt) query = appendListOptionsToSQL(query, opt)
labels := []*fleet.Label{} labels := []*fleet.Label{}
if err := d.db.Select(&labels, query); err != nil { if err := d.reader.Select(&labels, query); err != nil {
// it's ok if no labels exist // it's ok if no labels exist
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return labels, nil return labels, nil
@ -288,7 +288,7 @@ func (d *Datastore) LabelQueriesForHost(host *fleet.Host, cutoff time.Time) (map
WHERE platform = ? OR platform = '' WHERE platform = ? OR platform = ''
AND label_membership_type = ? AND label_membership_type = ?
` `
rows, err = d.db.Query(sql, platform, fleet.LabelMembershipTypeDynamic) rows, err = d.reader.Query(sql, platform, fleet.LabelMembershipTypeDynamic)
} else { } else {
// Retrieve all labels (with matching platform) iff there is a label // Retrieve all labels (with matching platform) iff there is a label
// that has been created since this host last reported label query // that has been created since this host last reported label query
@ -300,7 +300,7 @@ func (d *Datastore) LabelQueriesForHost(host *fleet.Host, cutoff time.Time) (map
AND (platform = ? OR platform = '') AND (platform = ? OR platform = '')
AND label_membership_type = ? AND label_membership_type = ?
` `
rows, err = d.db.Query( rows, err = d.reader.Query(
sql, sql,
platform, platform,
host.LabelUpdatedAt, host.LabelUpdatedAt,
@ -369,7 +369,7 @@ func (d *Datastore) RecordLabelQueryExecutions(host *fleet.Host, results map[uin
updated_at = VALUES(updated_at) updated_at = VALUES(updated_at)
` `
_, err := d.db.Exec(sql, vals...) _, err := d.writer.Exec(sql, vals...)
if err != nil { if err != nil {
return errors.Wrapf(err, "insert label query executions (%v)", vals) return errors.Wrapf(err, "insert label query executions (%v)", vals)
} }
@ -384,8 +384,8 @@ func (d *Datastore) RecordLabelQueryExecutions(host *fleet.Host, results map[uin
if err != nil { if err != nil {
return errors.Wrap(err, "IN for DELETE FROM label_membership") return errors.Wrap(err, "IN for DELETE FROM label_membership")
} }
query = d.db.Rebind(query) query = d.writer.Rebind(query)
_, err = d.db.Exec(query, args...) _, err = d.writer.Exec(query, args...)
if err != nil { if err != nil {
return errors.Wrap(err, "delete label query executions") return errors.Wrap(err, "delete label query executions")
} }
@ -403,7 +403,7 @@ func (d *Datastore) ListLabelsForHost(hid uint) ([]*fleet.Label, error) {
` `
labels := []*fleet.Label{} labels := []*fleet.Label{}
err := d.db.Select(&labels, sqlStatement, hid) err := d.reader.Select(&labels, sqlStatement, hid)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "selecting host labels") return nil, errors.Wrap(err, "selecting host labels")
} }
@ -432,7 +432,7 @@ func (d *Datastore) ListHostsInLabel(filter fleet.TeamFilter, lid uint, opt flee
sql = appendListOptionsToSQL(sql, opt.ListOptions) sql = appendListOptionsToSQL(sql, opt.ListOptions)
hosts := []*fleet.Host{} hosts := []*fleet.Host{}
err := d.db.Select(&hosts, sql, params...) err := d.reader.Select(&hosts, sql, params...)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "selecting label query executions") return nil, errors.Wrap(err, "selecting label query executions")
} }
@ -458,9 +458,9 @@ func (d *Datastore) ListUniqueHostsInLabels(filter fleet.TeamFilter, labels []ui
return nil, errors.Wrap(err, "building query listing unique hosts in labels") return nil, errors.Wrap(err, "building query listing unique hosts in labels")
} }
query = d.db.Rebind(query) query = d.reader.Rebind(query)
hosts := []*fleet.Host{} hosts := []*fleet.Host{}
err = d.db.Select(&hosts, query, args...) err = d.reader.Select(&hosts, query, args...)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "listing unique hosts in labels") return nil, errors.Wrap(err, "listing unique hosts in labels")
} }
@ -492,10 +492,10 @@ func (d *Datastore) searchLabelsWithOmits(filter fleet.TeamFilter, query string,
return nil, errors.Wrap(err, "building query for labels with omits") return nil, errors.Wrap(err, "building query for labels with omits")
} }
sql = d.db.Rebind(sql) sql = d.reader.Rebind(sql)
matches := []*fleet.Label{} matches := []*fleet.Label{}
err = d.db.Select(&matches, sql, args...) err = d.reader.Select(&matches, sql, args...)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "selecting labels with omits") return nil, errors.Wrap(err, "selecting labels with omits")
} }
@ -527,7 +527,7 @@ func (d *Datastore) addAllHostsLabelToList(filter fleet.TeamFilter, labels []*fl
) )
var allHosts fleet.Label var allHosts fleet.Label
if err := d.db.Get(&allHosts, sql, fleet.LabelTypeBuiltIn); err != nil { if err := d.reader.Get(&allHosts, sql, fleet.LabelTypeBuiltIn); err != nil {
return nil, errors.Wrap(err, "get all hosts label") return nil, errors.Wrap(err, "get all hosts label")
} }
@ -575,8 +575,8 @@ func (d *Datastore) searchLabelsDefault(filter fleet.TeamFilter, omit ...uint) (
if err != nil { if err != nil {
return nil, errors.Wrap(err, "searching default labels") return nil, errors.Wrap(err, "searching default labels")
} }
sql = d.db.Rebind(sql) sql = d.reader.Rebind(sql)
if err := d.db.Select(&labels, sql, args...); err != nil { if err := d.reader.Select(&labels, sql, args...); err != nil {
return nil, errors.Wrap(err, "searching default labels rebound") return nil, errors.Wrap(err, "searching default labels rebound")
} }
@ -617,7 +617,7 @@ func (d *Datastore) SearchLabels(filter fleet.TeamFilter, query string, omit ...
) )
matches := []*fleet.Label{} matches := []*fleet.Label{}
if err := d.db.Select(&matches, sql, transformedQuery); err != nil { if err := d.reader.Select(&matches, sql, transformedQuery); err != nil {
return nil, errors.Wrap(err, "selecting labels for search") return nil, errors.Wrap(err, "selecting labels for search")
} }
@ -645,7 +645,7 @@ func (d *Datastore) LabelIDsByName(labels []string) ([]uint, error) {
} }
var labelIDs []uint var labelIDs []uint
if err := d.db.Select(&labelIDs, sql, args...); err != nil { if err := d.reader.Select(&labelIDs, sql, args...); err != nil {
return nil, errors.Wrap(err, "get label IDs") return nil, errors.Wrap(err, "get label IDs")
} }

View File

@ -29,27 +29,27 @@ func (d *Datastore) Lock(name string, owner string, expiration time.Duration) (b
} }
func (d *Datastore) createLock(name string, owner string, expiration time.Duration) (sql.Result, error) { func (d *Datastore) createLock(name string, owner string, expiration time.Duration) (sql.Result, error) {
return d.db.Exec( return d.writer.Exec(
`INSERT IGNORE INTO locks (name, owner, expires_at) VALUES (?, ?, ?)`, `INSERT IGNORE INTO locks (name, owner, expires_at) VALUES (?, ?, ?)`,
name, owner, time.Now().Add(expiration), name, owner, time.Now().Add(expiration),
) )
} }
func (d *Datastore) extendLockIfAlreadyAcquired(name string, owner string, expiration time.Duration) (sql.Result, error) { func (d *Datastore) extendLockIfAlreadyAcquired(name string, owner string, expiration time.Duration) (sql.Result, error) {
return d.db.Exec( return d.writer.Exec(
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE name = ? and owner = ?`, `UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE name = ? and owner = ?`,
name, owner, time.Now().Add(expiration), name, owner, name, owner, time.Now().Add(expiration), name, owner,
) )
} }
func (d *Datastore) overwriteLockIfExpired(name string, owner string, expiration time.Duration) (sql.Result, error) { func (d *Datastore) overwriteLockIfExpired(name string, owner string, expiration time.Duration) (sql.Result, error) {
return d.db.Exec( return d.writer.Exec(
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE expires_at < CURRENT_TIMESTAMP and name = ?`, `UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE expires_at < CURRENT_TIMESTAMP and name = ?`,
name, owner, time.Now().Add(expiration), name, name, owner, time.Now().Add(expiration), name,
) )
} }
func (d *Datastore) Unlock(name string, owner string) error { func (d *Datastore) Unlock(name string, owner string) error {
_, err := d.db.Exec(`DELETE FROM locks WHERE name = ? and owner = ?`, name, owner) _, err := d.writer.Exec(`DELETE FROM locks WHERE name = ? and owner = ?`, name, owner)
return err return err
} }

View File

@ -85,14 +85,14 @@ func Test20210819131107_AddCascadeToHostSoftware(t *testing.T) {
defer ds.Close() defer ds.Close()
for { for {
version, err := tables.MigrationClient.GetDBVersion(ds.db.DB) version, err := tables.MigrationClient.GetDBVersion(ds.writer.DB)
require.NoError(t, err) require.NoError(t, err)
// break right before the the constraint migration // break right before the the constraint migration
if version == 20210818182258 { if version == 20210818182258 {
break break
} }
require.NoError(t, tables.MigrationClient.UpByOne(ds.db.DB, "")) require.NoError(t, tables.MigrationClient.UpByOne(ds.writer.DB, ""))
} }
host1 := test.NewHost(t, ds, "host1", "", "host1key", "host1uuid", time.Now()) host1 := test.NewHost(t, ds, "host1", "", "host1key", "host1uuid", time.Now())
@ -122,7 +122,7 @@ func Test20210819131107_AddCascadeToHostSoftware(t *testing.T) {
require.NoError(t, ds.DeleteHost(host1.ID)) require.NoError(t, ds.DeleteHost(host1.ID))
require.NoError(t, tables.MigrationClient.UpByOne(ds.db.DB, "")) require.NoError(t, tables.MigrationClient.UpByOne(ds.writer.DB, ""))
// Make sure we don't delete more than we need // Make sure we don't delete more than we need
hostCheck, err := ds.Host(host2.ID) hostCheck, err := ds.Host(host2.ID)

View File

@ -37,13 +37,28 @@ var (
columnCharsRegexp = regexp.MustCompile(`[^\w-]`) columnCharsRegexp = regexp.MustCompile(`[^\w-]`)
) )
// dbReader is an interface that defines the methods required for reads.
type dbReader interface {
sqlx.Queryer
Close() error
Rebind(string) string
Select(interface{}, string, ...interface{}) error
Get(interface{}, string, ...interface{}) error
}
// Datastore is an implementation of fleet.Datastore interface backed by // Datastore is an implementation of fleet.Datastore interface backed by
// MySQL // MySQL
type Datastore struct { type Datastore struct {
db *sqlx.DB reader dbReader // so it cannot be used to perform writes
writer *sqlx.DB
logger log.Logger logger log.Logger
clock clock.Clock clock clock.Clock
config config.MysqlConfig config config.MysqlConfig
// nil if no read replica
readReplicaConfig *config.MysqlConfig
} }
type txFn func(*sqlx.Tx) error type txFn func(*sqlx.Tx) error
@ -67,7 +82,7 @@ func retryableError(err error) bool {
// withRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff // withRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff
func (d *Datastore) withRetryTxx(fn txFn) (err error) { func (d *Datastore) withRetryTxx(fn txFn) (err error) {
operation := func() error { operation := func() error {
tx, err := d.db.Beginx() tx, err := d.writer.Beginx()
if err != nil { if err != nil {
return errors.Wrap(err, "create transaction") return errors.Wrap(err, "create transaction")
} }
@ -116,7 +131,7 @@ func (d *Datastore) withRetryTxx(fn txFn) (err error) {
// withTx provides a common way to commit/rollback a txFn // withTx provides a common way to commit/rollback a txFn
func (d *Datastore) withTx(fn txFn) (err error) { func (d *Datastore) withTx(fn txFn) (err error) {
tx, err := d.db.Beginx() tx, err := d.writer.Beginx()
if err != nil { if err != nil {
return errors.Wrap(err, "create transaction") return errors.Wrap(err, "create transaction")
} }
@ -153,51 +168,64 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
} }
for _, setOpt := range opts { for _, setOpt := range opts {
if setOpt != nil {
setOpt(options) setOpt(options)
} }
if config.PasswordPath != "" && config.Password != "" {
return nil, errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
} }
// Check to see if the flag is populated if err := checkConfig(&config); err != nil {
// Check if file exists on disk return nil, err
// If file exists read contents }
if config.PasswordPath != "" { if options.replicaConfig != nil {
fileContents, err := ioutil.ReadFile(config.PasswordPath) if err := checkConfig(options.replicaConfig); err != nil {
return nil, errors.Wrap(err, "replica")
}
}
dbWriter, err := newDB(&config, options)
if err != nil { if err != nil {
return nil, err return nil, err
} }
config.Password = strings.TrimSpace(string(fileContents)) dbReader := dbWriter
} if options.replicaConfig != nil {
dbReader, err = newDB(options.replicaConfig, options)
if config.TLSCA != "" {
config.TLSConfig = "custom"
err := registerTLS(config)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "register TLS config for mysql") return nil, err
} }
} }
dsn := generateMysqlConnectionString(config) ds := &Datastore{
writer: dbWriter,
reader: dbReader,
logger: options.logger,
clock: c,
config: config,
readReplicaConfig: options.replicaConfig,
}
return ds, nil
}
func newDB(conf *config.MysqlConfig, opts *dbOptions) (*sqlx.DB, error) {
dsn := generateMysqlConnectionString(*conf)
db, err := sqlx.Open("mysql", dsn) db, err := sqlx.Open("mysql", dsn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
db.SetMaxIdleConns(config.MaxIdleConns) db.SetMaxIdleConns(conf.MaxIdleConns)
db.SetMaxOpenConns(config.MaxOpenConns) db.SetMaxOpenConns(conf.MaxOpenConns)
db.SetConnMaxLifetime(time.Second * time.Duration(config.ConnMaxLifetime)) db.SetConnMaxLifetime(time.Second * time.Duration(conf.ConnMaxLifetime))
var dbError error var dbError error
for attempt := 0; attempt < options.maxAttempts; attempt++ { for attempt := 0; attempt < opts.maxAttempts; attempt++ {
dbError = db.Ping() dbError = db.Ping()
if dbError == nil { if dbError == nil {
// we're connected! // we're connected!
break break
} }
interval := time.Duration(attempt) * time.Second interval := time.Duration(attempt) * time.Second
options.logger.Log("mysql", fmt.Sprintf( opts.logger.Log("mysql", fmt.Sprintf(
"could not connect to db: %v, sleeping %v", dbError, interval)) "could not connect to db: %v, sleeping %v", dbError, interval))
time.Sleep(interval) time.Sleep(interval)
} }
@ -205,20 +233,37 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
if dbError != nil { if dbError != nil {
return nil, dbError return nil, dbError
} }
return db, nil
}
ds := &Datastore{ func checkConfig(conf *config.MysqlConfig) error {
db: db, if conf.PasswordPath != "" && conf.Password != "" {
logger: options.logger, return errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
clock: c,
config: config,
} }
return ds, nil // Check to see if the flag is populated
// Check if file exists on disk
// If file exists read contents
if conf.PasswordPath != "" {
fileContents, err := ioutil.ReadFile(conf.PasswordPath)
if err != nil {
return err
}
conf.Password = strings.TrimSpace(string(fileContents))
}
if conf.TLSCA != "" {
conf.TLSConfig = "custom"
err := registerTLS(*conf)
if err != nil {
return errors.Wrap(err, "register TLS config for mysql")
}
}
return nil
} }
func (d *Datastore) Begin() (fleet.Transaction, error) { func (d *Datastore) Begin() (fleet.Transaction, error) {
return d.db.Beginx() return d.writer.Beginx()
} }
func (d *Datastore) Name() string { func (d *Datastore) Name() string {
@ -226,11 +271,11 @@ func (d *Datastore) Name() string {
} }
func (d *Datastore) MigrateTables() error { func (d *Datastore) MigrateTables() error {
return tables.MigrationClient.Up(d.db.DB, "") return tables.MigrationClient.Up(d.writer.DB, "")
} }
func (d *Datastore) MigrateData() error { func (d *Datastore) MigrateData() error {
return data.MigrationClient.Up(d.db.DB, "") return data.MigrationClient.Up(d.writer.DB, "")
} }
func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) { func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
@ -243,7 +288,7 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
return 0, errors.Wrap(err, "missing tables migrations") return 0, errors.Wrap(err, "missing tables migrations")
} }
currentTablesVersion, err := tables.MigrationClient.GetDBVersion(d.db.DB) currentTablesVersion, err := tables.MigrationClient.GetDBVersion(d.writer.DB)
if err != nil { if err != nil {
return 0, errors.Wrap(err, "cannot get table migration status") return 0, errors.Wrap(err, "cannot get table migration status")
} }
@ -253,7 +298,7 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
return 0, errors.Wrap(err, "missing data migrations") return 0, errors.Wrap(err, "missing data migrations")
} }
currentDataVersion, err := data.MigrationClient.GetDBVersion(d.db.DB) currentDataVersion, err := data.MigrationClient.GetDBVersion(d.writer.DB)
if err != nil { if err != nil {
return 0, errors.Wrap(err, "cannot get data migration status") return 0, errors.Wrap(err, "cannot get data migration status")
} }
@ -283,11 +328,11 @@ func (d *Datastore) Drop() error {
WHERE TABLE_SCHEMA = ?; WHERE TABLE_SCHEMA = ?;
` `
if err := d.db.Select(&tables, sql, d.config.Database); err != nil { if err := d.writer.Select(&tables, sql, d.config.Database); err != nil {
return err return err
} }
tx, err := d.db.Begin() tx, err := d.writer.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -312,13 +357,28 @@ func (d *Datastore) Drop() error {
// HealthCheck returns an error if the MySQL backend is not healthy. // HealthCheck returns an error if the MySQL backend is not healthy.
func (d *Datastore) HealthCheck() error { func (d *Datastore) HealthCheck() error {
_, err := d.db.Exec("select 1") if _, err := d.writer.Exec("select 1"); err != nil {
return err return err
}
if d.readReplicaConfig != nil {
var dst int
if err := d.reader.Get(&dst, "select 1"); err != nil {
return err
}
}
return nil
} }
// Close frees resources associated with underlying mysql connection // Close frees resources associated with underlying mysql connection
func (d *Datastore) Close() error { func (d *Datastore) Close() error {
return d.db.Close() err := d.writer.Close()
if d.readReplicaConfig != nil {
errRead := d.reader.Close()
if err == nil {
err = errRead
}
}
return err
} }
func sanitizeColumn(col string) string { func sanitizeColumn(col string) string {

View File

@ -1,7 +1,9 @@
package mysql package mysql
import ( import (
"database/sql"
"testing" "testing"
"time"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/VividCortex/mysqlerr" "github.com/VividCortex/mysqlerr"
@ -15,6 +17,51 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestDatastoreReplica(t *testing.T) {
// a bit unfortunate to create temp databases just for this - could be mixed
// with other tests when/if we move to subtests to minimize the number of
// databases created for tests (see #1805).
t.Run("noreplica", func(t *testing.T) {
ds := CreateMySQLDSWithOptions(t, nil)
defer ds.Close()
require.Equal(t, ds.reader, ds.writer)
})
t.Run("replica", func(t *testing.T) {
opts := &DatastoreTestOptions{Replica: true}
ds := CreateMySQLDSWithOptions(t, opts)
defer ds.Close()
require.NotEqual(t, ds.reader, ds.writer)
// create a new host
host, err := ds.NewHost(&fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: "1",
UUID: "1",
Hostname: "foo.local",
PrimaryIP: "192.168.1.1",
PrimaryMac: "30-65-EC-6F-C4-58",
})
require.NoError(t, err)
require.NotNil(t, host)
// trying to read it fails, not replicated yet
_, err = ds.Host(host.ID)
require.Error(t, err)
require.True(t, errors.Is(err, sql.ErrNoRows))
opts.RunReplication()
// now it can read it
host2, err := ds.Host(host.ID)
require.NoError(t, err)
require.Equal(t, host.ID, host2.ID)
})
}
func TestSanitizeColumn(t *testing.T) { func TestSanitizeColumn(t *testing.T) {
t.Parallel() t.Parallel()
@ -121,8 +168,10 @@ func TestSearchLike(t *testing.T) {
func mockDatastore(t *testing.T) (sqlmock.Sqlmock, *Datastore) { func mockDatastore(t *testing.T) (sqlmock.Sqlmock, *Datastore) {
db, mock, err := sqlmock.New() db, mock, err := sqlmock.New()
require.NoError(t, err) require.NoError(t, err)
dbmock := sqlx.NewDb(db, "sqlmock")
ds := &Datastore{ ds := &Datastore{
db: sqlx.NewDb(db, "sqlmock"), writer: dbmock,
reader: dbmock,
logger: log.NewNopLogger(), logger: log.NewNopLogger(),
} }

View File

@ -205,7 +205,7 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
WHERE name = ? WHERE name = ?
` `
var pack fleet.Pack var pack fleet.Pack
err := d.db.Get(&pack, sqlStatement, name) err := d.reader.Get(&pack, sqlStatement, name)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, false, nil return nil, false, nil
@ -213,7 +213,7 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
return nil, false, errors.Wrap(err, "fetch pack by name") return nil, false, errors.Wrap(err, "fetch pack by name")
} }
if err := d.loadPackTargets(&pack); err != nil { if err := d.loadPackTargets(d.reader, &pack); err != nil {
return nil, false, err return nil, false, err
} }
@ -313,10 +313,10 @@ func (d *Datastore) replacePackTargets(tx *sqlx.Tx, pack *fleet.Pack) error {
return nil return nil
} }
func (d *Datastore) loadPackTargets(pack *fleet.Pack) error { func (d *Datastore) loadPackTargets(db dbReader, pack *fleet.Pack) error {
var targets []fleet.PackTarget var targets []fleet.PackTarget
sql := `SELECT * FROM pack_targets WHERE pack_id = ?` sql := `SELECT * FROM pack_targets WHERE pack_id = ?`
if err := d.db.Select(&targets, sql, pack.ID); err != nil { if err := db.Select(&targets, sql, pack.ID); err != nil {
return errors.Wrap(err, "select pack targets") return errors.Wrap(err, "select pack targets")
} }
@ -371,14 +371,14 @@ func (d *Datastore) DeletePack(name string) error {
func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) { func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
query := `SELECT * FROM packs WHERE id = ?` query := `SELECT * FROM packs WHERE id = ?`
pack := &fleet.Pack{} pack := &fleet.Pack{}
err := d.db.Get(pack, query, pid) err := d.reader.Get(pack, query, pid)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, notFound("Pack").WithID(pid) return nil, notFound("Pack").WithID(pid)
} else if err != nil { } else if err != nil {
return nil, errors.Wrap(err, "get pack") return nil, errors.Wrap(err, "get pack")
} }
if err := d.loadPackTargets(pack); err != nil { if err := d.loadPackTargets(d.reader, pack); err != nil {
return nil, err return nil, err
} }
@ -388,14 +388,15 @@ func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
// EnsureGlobalPack gets or inserts a pack with type global // EnsureGlobalPack gets or inserts a pack with type global
func (d *Datastore) EnsureGlobalPack() (*fleet.Pack, error) { func (d *Datastore) EnsureGlobalPack() (*fleet.Pack, error) {
pack := &fleet.Pack{} pack := &fleet.Pack{}
err := d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = 'global'`) // read from primary as we will create the pack if it doesn't exist
err := d.writer.Get(pack, `SELECT * FROM packs WHERE pack_type = 'global'`)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return d.insertNewGlobalPack() return d.insertNewGlobalPack()
} else if err != nil { } else if err != nil {
return nil, errors.Wrap(err, "get pack") return nil, errors.Wrap(err, "get pack")
} }
if err := d.loadPackTargets(pack); err != nil { if err := d.loadPackTargets(d.writer, pack); err != nil {
return nil, err return nil, err
} }
@ -439,14 +440,15 @@ func (d *Datastore) EnsureTeamPack(teamID uint) (*fleet.Pack, error) {
} }
teamType := fmt.Sprintf("team-%d", teamID) teamType := fmt.Sprintf("team-%d", teamID)
err = d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType) // read from primary as we will create the team pack if it doesn't exist
err = d.writer.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return d.insertNewTeamPack(t) return d.insertNewTeamPack(t)
} else if err != nil { } else if err != nil {
return nil, errors.Wrap(err, "get pack") return nil, errors.Wrap(err, "get pack")
} }
if err := d.loadPackTargets(pack); err != nil { if err := d.loadPackTargets(d.writer, pack); err != nil {
return nil, err return nil, err
} }
@ -499,13 +501,13 @@ func (d *Datastore) ListPacks(opt fleet.PackListOptions) ([]*fleet.Pack, error)
query = `SELECT * FROM packs` query = `SELECT * FROM packs`
} }
var packs []*fleet.Pack var packs []*fleet.Pack
err := d.db.Select(&packs, appendListOptionsToSQL(query, opt.ListOptions)) err := d.reader.Select(&packs, appendListOptionsToSQL(query, opt.ListOptions))
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing packs") return nil, errors.Wrap(err, "listing packs")
} }
for _, pack := range packs { for _, pack := range packs {
if err := d.loadPackTargets(pack); err != nil { if err := d.loadPackTargets(d.reader, pack); err != nil {
return nil, err return nil, err
} }
} }
@ -541,7 +543,7 @@ func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) {
` `
packs := []*fleet.Pack{} packs := []*fleet.Pack{}
if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows { if err := d.reader.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing hosts in pack") return nil, errors.Wrap(err, "listing hosts in pack")
} }
return packs, nil return packs, nil

View File

@ -524,7 +524,7 @@ func TestTeamScheduleNamesMigrateToNewFormat(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// insert team pack by hand with the old naming scheme // insert team pack by hand with the old naming scheme
_, err = ds.db.Exec( _, err = ds.writer.Exec(
"INSERT INTO packs(name, description, platform, disabled, pack_type) VALUES (?, ?, ?, ?, ?)", "INSERT INTO packs(name, description, platform, disabled, pack_type) VALUES (?, ?, ?, ?, ?)",
teamSchedulePackType(team1), "desc", "windows", false, teamSchedulePackType(team1), teamSchedulePackType(team1), "desc", "windows", false, teamSchedulePackType(team1),
) )

View File

@ -11,7 +11,7 @@ func (d *Datastore) NewPasswordResetRequest(req *fleet.PasswordResetRequest) (*f
( user_id, token, expires_at) ( user_id, token, expires_at)
VALUES (?,?, NOW()) VALUES (?,?, NOW())
` `
response, err := d.db.Exec(sqlStatement, req.UserID, req.Token) response, err := d.writer.Exec(sqlStatement, req.UserID, req.Token)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "inserting password reset requests") return nil, errors.Wrap(err, "inserting password reset requests")
} }
@ -30,7 +30,7 @@ func (d *Datastore) SavePasswordResetRequest(req *fleet.PasswordResetRequest) er
token = ? token = ?
WHERE id = ? WHERE id = ?
` `
result, err := d.db.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID) result, err := d.writer.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID)
if err != nil { if err != nil {
return errors.Wrap(err, "updating password reset requests") return errors.Wrap(err, "updating password reset requests")
} }
@ -58,7 +58,7 @@ func (d *Datastore) DeletePasswordResetRequestsForUser(userID uint) error {
sqlStatement := ` sqlStatement := `
DELETE FROM password_reset_requests WHERE user_id = ? DELETE FROM password_reset_requests WHERE user_id = ?
` `
_, err := d.db.Exec(sqlStatement, userID) _, err := d.writer.Exec(sqlStatement, userID)
if err != nil { if err != nil {
return errors.Wrap(err, "deleting password reset request by user") return errors.Wrap(err, "deleting password reset request by user")
} }
@ -72,7 +72,7 @@ func (d *Datastore) FindPassswordResetByID(id uint) (*fleet.PasswordResetRequest
WHERE id = ? LIMIT 1 WHERE id = ? LIMIT 1
` `
passwordResetRequest := &fleet.PasswordResetRequest{} passwordResetRequest := &fleet.PasswordResetRequest{}
err := d.db.Get(&passwordResetRequest, sqlStatement, id) err := d.reader.Get(&passwordResetRequest, sqlStatement, id)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "selecting password reset by id") return nil, errors.Wrap(err, "selecting password reset by id")
} }
@ -87,7 +87,7 @@ func (d *Datastore) FindPassswordResetsByUserID(id uint) ([]*fleet.PasswordReset
` `
passwordResetRequests := []*fleet.PasswordResetRequest{} passwordResetRequests := []*fleet.PasswordResetRequest{}
err := d.db.Select(&passwordResetRequests, sqlStatement, id) err := d.reader.Select(&passwordResetRequests, sqlStatement, id)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "finding password resets by user id") return nil, errors.Wrap(err, "finding password resets by user id")
} }
@ -102,7 +102,7 @@ func (d *Datastore) FindPassswordResetByToken(token string) (*fleet.PasswordRese
WHERE token = ? LIMIT 1 WHERE token = ? LIMIT 1
` `
passwordResetRequest := &fleet.PasswordResetRequest{} passwordResetRequest := &fleet.PasswordResetRequest{}
err := d.db.Get(passwordResetRequest, sqlStatement, token) err := d.reader.Get(passwordResetRequest, sqlStatement, token)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "selecting password reset requests") return nil, errors.Wrap(err, "selecting password reset requests")
} }
@ -118,7 +118,7 @@ func (d *Datastore) FindPassswordResetByTokenAndUserID(token string, id uint) (*
LIMIT 1 LIMIT 1
` `
passwordResetRequest := &fleet.PasswordResetRequest{} passwordResetRequest := &fleet.PasswordResetRequest{}
err := d.db.Get(passwordResetRequest, sqlStatement, id, token) err := d.reader.Get(passwordResetRequest, sqlStatement, id, token)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "selecting password reset by token and user id") return nil, errors.Wrap(err, "selecting password reset by token and user id")
} }

View File

@ -12,7 +12,7 @@ import (
) )
func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) { func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
res, err := ds.db.Exec(`INSERT INTO policies (query_id) VALUES (?)`, queryID) res, err := ds.writer.Exec(`INSERT INTO policies (query_id) VALUES (?)`, queryID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "inserting new policy") return nil, errors.Wrap(err, "inserting new policy")
} }
@ -26,7 +26,7 @@ func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
func (ds *Datastore) Policy(id uint) (*fleet.Policy, error) { func (ds *Datastore) Policy(id uint) (*fleet.Policy, error) {
var policy fleet.Policy var policy fleet.Policy
err := ds.db.Get( err := ds.reader.Get(
&policy, &policy,
`SELECT `SELECT
p.*, p.*,
@ -66,7 +66,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(host *fleet.Host, results map[u
strings.Join(bindvars, ","), strings.Join(bindvars, ","),
) )
_, err := ds.db.Exec(query, vals...) _, err := ds.writer.Exec(query, vals...)
if err != nil { if err != nil {
return errors.Wrapf(err, "insert policy_membership (%v)", vals) return errors.Wrapf(err, "insert policy_membership (%v)", vals)
} }
@ -76,7 +76,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(host *fleet.Host, results map[u
func (ds *Datastore) ListGlobalPolicies() ([]*fleet.Policy, error) { func (ds *Datastore) ListGlobalPolicies() ([]*fleet.Policy, error) {
var policies []*fleet.Policy var policies []*fleet.Policy
err := ds.db.Select( err := ds.reader.Select(
&policies, &policies,
`SELECT `SELECT
p.*, p.*,
@ -97,8 +97,8 @@ func (ds *Datastore) DeleteGlobalPolicies(ids []uint) ([]uint, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "IN for DELETE FROM policies") return nil, errors.Wrap(err, "IN for DELETE FROM policies")
} }
stmt = ds.db.Rebind(stmt) stmt = ds.writer.Rebind(stmt)
if _, err := ds.db.Exec(stmt, args...); err != nil { if _, err := ds.writer.Exec(stmt, args...); err != nil {
return nil, errors.Wrap(err, "delete policies") return nil, errors.Wrap(err, "delete policies")
} }
return ids, nil return ids, nil
@ -109,7 +109,7 @@ func (ds *Datastore) PolicyQueriesForHost(_ *fleet.Host) (map[string]string, err
Id string `db:"id"` Id string `db:"id"`
Query string `db:"query"` Query string `db:"query"`
} }
err := ds.db.Select(&rows, `SELECT p.id, q.query FROM policies p JOIN queries q ON (p.query_id=q.id)`) err := ds.reader.Select(&rows, `SELECT p.id, q.query FROM policies p JOIN queries q ON (p.query_id=q.id)`)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "selecting policies for host") return nil, errors.Wrap(err, "selecting policies for host")
} }

View File

@ -10,7 +10,7 @@ import (
) )
func (d *Datastore) ApplyQueries(authorID uint, queries []*fleet.Query) (err error) { func (d *Datastore) ApplyQueries(authorID uint, queries []*fleet.Query) (err error) {
tx, err := d.db.Begin() tx, err := d.writer.Begin()
if err != nil { if err != nil {
return errors.Wrap(err, "begin ApplyQueries transaction") return errors.Wrap(err, "begin ApplyQueries transaction")
} }
@ -72,7 +72,7 @@ func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet.
WHERE name = ? WHERE name = ?
` `
var query fleet.Query var query fleet.Query
err := d.db.Get(&query, sqlStatement, name) err := d.reader.Get(&query, sqlStatement, name)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, notFound("Query").WithName(name) return nil, notFound("Query").WithName(name)
@ -99,7 +99,7 @@ func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fl
observer_can_run observer_can_run
) VALUES ( ?, ?, ?, ?, ?, ? ) ) VALUES ( ?, ?, ?, ?, ?, ? )
` `
result, err := d.db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun) result, err := d.writer.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
if err != nil && isDuplicate(err) { if err != nil && isDuplicate(err) {
return nil, alreadyExists("Query", 0) return nil, alreadyExists("Query", 0)
@ -120,7 +120,7 @@ func (d *Datastore) SaveQuery(q *fleet.Query) error {
SET name = ?, description = ?, query = ?, author_id = ?, saved = ?, observer_can_run = ? SET name = ?, description = ?, query = ?, author_id = ?, saved = ?, observer_can_run = ?
WHERE id = ? WHERE id = ?
` `
result, err := d.db.Exec(sql, q.Name, q.Description, q.Query, q.AuthorID, q.Saved, q.ObserverCanRun, q.ID) result, err := d.writer.Exec(sql, q.Name, q.Description, q.Query, q.AuthorID, q.Saved, q.ObserverCanRun, q.ID)
if err != nil { if err != nil {
return errors.Wrap(err, "updating query") return errors.Wrap(err, "updating query")
} }
@ -156,7 +156,7 @@ func (d *Datastore) Query(id uint) (*fleet.Query, error) {
WHERE q.id = ? WHERE q.id = ?
` `
query := &fleet.Query{} query := &fleet.Query{}
if err := d.db.Get(query, sql, id); err != nil { if err := d.reader.Get(query, sql, id); err != nil {
return nil, errors.Wrap(err, "selecting query") return nil, errors.Wrap(err, "selecting query")
} }
@ -180,7 +180,7 @@ func (d *Datastore) ListQueries(opt fleet.ListOptions) ([]*fleet.Query, error) {
sql = appendListOptionsToSQL(sql, opt) sql = appendListOptionsToSQL(sql, opt)
results := []*fleet.Query{} results := []*fleet.Query{}
if err := d.db.Select(&results, sql); err != nil { if err := d.reader.Select(&results, sql); err != nil {
return nil, errors.Wrap(err, "listing queries") return nil, errors.Wrap(err, "listing queries")
} }
@ -225,7 +225,7 @@ func (d *Datastore) loadPacksForQueries(queries []*fleet.Query) error {
fleet.Pack fleet.Pack
}{} }{}
err = d.db.Select(&rows, query, args...) err = d.reader.Select(&rows, query, args...)
if err != nil { if err != nil {
return errors.Wrap(err, "selecting load packs for queries") return errors.Wrap(err, "selecting load packs for queries")
} }

View File

@ -34,7 +34,7 @@ func (d *Datastore) ListScheduledQueriesInPack(id uint, opts fleet.ListOptions)
query = appendListOptionsToSQL(query, opts) query = appendListOptionsToSQL(query, opts)
results := []*fleet.ScheduledQuery{} results := []*fleet.ScheduledQuery{}
if err := d.db.Select(&results, query, id); err != nil { if err := d.reader.Select(&results, query, id); err != nil {
return nil, errors.Wrap(err, "listing scheduled queries") return nil, errors.Wrap(err, "listing scheduled queries")
} }
@ -46,8 +46,8 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op
} }
func (d *Datastore) insertScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) { func (d *Datastore) insertScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
selectFunc := d.db.Select selectFunc := d.writer.Select
execFunc := d.db.Exec execFunc := d.writer.Exec
if tx != nil { if tx != nil {
selectFunc = tx.Select selectFunc = tx.Select
execFunc = tx.Exec execFunc = tx.Exec
@ -108,7 +108,7 @@ func (d *Datastore) SaveScheduledQuery(sq *fleet.ScheduledQuery) (*fleet.Schedul
} }
func (d *Datastore) saveScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) { func (d *Datastore) saveScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
updateFunc := d.db.Exec updateFunc := d.writer.Exec
if tx != nil { if tx != nil {
updateFunc = tx.Exec updateFunc = tx.Exec
} }
@ -160,7 +160,7 @@ func (d *Datastore) ScheduledQuery(id uint) (*fleet.ScheduledQuery, error) {
WHERE sq.id = ? WHERE sq.id = ?
` `
sq := &fleet.ScheduledQuery{} sq := &fleet.ScheduledQuery{}
if err := d.db.Get(sq, query, id); err != nil { if err := d.reader.Get(sq, query, id); err != nil {
return nil, errors.Wrap(err, "select scheduled query") return nil, errors.Wrap(err, "select scheduled query")
} }
@ -168,11 +168,11 @@ func (d *Datastore) ScheduledQuery(id uint) (*fleet.ScheduledQuery, error) {
} }
func (d *Datastore) CleanupOrphanScheduledQueryStats() error { func (d *Datastore) CleanupOrphanScheduledQueryStats() error {
_, err := d.db.Exec(`DELETE FROM scheduled_query_stats where scheduled_query_id not in (select id from scheduled_queries where id=scheduled_query_id)`) _, err := d.writer.Exec(`DELETE FROM scheduled_query_stats where scheduled_query_id not in (select id from scheduled_queries where id=scheduled_query_id)`)
if err != nil { if err != nil {
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by scheduled_query") return errors.Wrap(err, "cleaning orphan scheduled_query_stats by scheduled_query")
} }
_, err = d.db.Exec(`DELETE FROM scheduled_query_stats where host_id not in (select id from hosts where id=host_id)`) _, err = d.writer.Exec(`DELETE FROM scheduled_query_stats where host_id not in (select id from hosts where id=host_id)`)
if err != nil { if err != nil {
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by host") return errors.Wrap(err, "cleaning orphan scheduled_query_stats by host")
} }

View File

@ -209,7 +209,7 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "1") test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "1")
sq1 := test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "2") sq1 := test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "2")
_, err := ds.db.Exec(`INSERT INTO scheduled_query_stats ( _, err := ds.writer.Exec(`INSERT INTO scheduled_query_stats (
host_id, scheduled_query_id, average_memory, denylisted, host_id, scheduled_query_id, average_memory, denylisted,
executions, schedule_interval, output_size, system_time, executions, schedule_interval, output_size, system_time,
user_time, wall_time user_time, wall_time
@ -224,12 +224,12 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
require.Len(t, h1.PackStats, 1) require.Len(t, h1.PackStats, 1)
// now we insert a bogus stat // now we insert a bogus stat
_, err = ds.db.Exec(`INSERT INTO scheduled_query_stats ( _, err = ds.writer.Exec(`INSERT INTO scheduled_query_stats (
host_id, scheduled_query_id, average_memory, denylisted, executions host_id, scheduled_query_id, average_memory, denylisted, executions
) VALUES (?, 999, 32, false, 2);`, h1.ID) ) VALUES (?, 999, 32, false, 2);`, h1.ID)
require.NoError(t, err) require.NoError(t, err)
// and also for an unknown host // and also for an unknown host
_, err = ds.db.Exec(`INSERT INTO scheduled_query_stats ( _, err = ds.writer.Exec(`INSERT INTO scheduled_query_stats (
host_id, scheduled_query_id, average_memory, denylisted, executions host_id, scheduled_query_id, average_memory, denylisted, executions
) VALUES (888, 999, 32, true, 4);`) ) VALUES (888, 999, 32, true, 4);`)
require.NoError(t, err) require.NoError(t, err)
@ -241,14 +241,14 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
// but there are definitely there // but there are definitely there
var count int var count int
err = ds.db.Get(&count, `SELECT count(*) FROM scheduled_query_stats`) err = ds.writer.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 3, count) assert.Equal(t, 3, count)
// now we clean it up // now we clean it up
require.NoError(t, ds.CleanupOrphanScheduledQueryStats()) require.NoError(t, ds.CleanupOrphanScheduledQueryStats())
err = ds.db.Get(&count, `SELECT count(*) FROM scheduled_query_stats`) err = ds.writer.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, count) assert.Equal(t, 1, count)

View File

@ -11,7 +11,7 @@ func (d *Datastore) SessionByKey(key string) (*fleet.Session, error) {
WHERE ` + "`key`" + ` = ? LIMIT 1 WHERE ` + "`key`" + ` = ? LIMIT 1
` `
session := &fleet.Session{} session := &fleet.Session{}
err := d.db.Get(session, sqlStatement, key) err := d.reader.Get(session, sqlStatement, key)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "selecting sessions") return nil, errors.Wrap(err, "selecting sessions")
} }
@ -26,7 +26,7 @@ func (d *Datastore) SessionByID(id uint) (*fleet.Session, error) {
LIMIT 1 LIMIT 1
` `
session := &fleet.Session{} session := &fleet.Session{}
err := d.db.Get(session, sqlStatement, id) err := d.reader.Get(session, sqlStatement, id)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "selecting session by id") return nil, errors.Wrap(err, "selecting session by id")
} }
@ -40,7 +40,7 @@ func (d *Datastore) ListSessionsForUser(id uint) ([]*fleet.Session, error) {
WHERE user_id = ? WHERE user_id = ?
` `
sessions := []*fleet.Session{} sessions := []*fleet.Session{}
err := d.db.Select(&sessions, sqlStatement, id) err := d.reader.Select(&sessions, sqlStatement, id)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "selecting sessions for user") return nil, errors.Wrap(err, "selecting sessions for user")
} }
@ -57,7 +57,7 @@ func (d *Datastore) NewSession(session *fleet.Session) (*fleet.Session, error) {
) )
VALUES(?,?) VALUES(?,?)
` `
result, err := d.db.Exec(sqlStatement, session.UserID, session.Key) result, err := d.writer.Exec(sqlStatement, session.UserID, session.Key)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "inserting session") return nil, errors.Wrap(err, "inserting session")
} }
@ -80,7 +80,7 @@ func (d *Datastore) DestroyAllSessionsForUser(id uint) error {
sqlStatement := ` sqlStatement := `
DELETE FROM sessions WHERE user_id = ? DELETE FROM sessions WHERE user_id = ?
` `
_, err := d.db.Exec(sqlStatement, id) _, err := d.writer.Exec(sqlStatement, id)
if err != nil { if err != nil {
return errors.Wrap(err, "deleting sessions for user") return errors.Wrap(err, "deleting sessions for user")
} }
@ -94,7 +94,7 @@ func (d *Datastore) MarkSessionAccessed(session *fleet.Session) error {
accessed_at = ? accessed_at = ?
WHERE id = ? WHERE id = ?
` `
results, err := d.db.Exec(sqlStatement, d.clock.Now(), session.ID) results, err := d.writer.Exec(sqlStatement, d.clock.Now(), session.ID)
if err != nil { if err != nil {
return errors.Wrap(err, "updating mark session as accessed") return errors.Wrap(err, "updating mark session as accessed")
} }

View File

@ -206,7 +206,7 @@ func (d *Datastore) insertNewInstalledHostSoftware(
} }
func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Software, error) { func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Software, error) {
selectFunc := d.db.Select selectFunc := d.reader.Select
if tx != nil { if tx != nil {
selectFunc = tx.Select selectFunc = tx.Select
} }
@ -231,7 +231,7 @@ func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Softwa
WHERE s.id IN WHERE s.id IN
(SELECT software_id FROM host_software WHERE host_id = ?) (SELECT software_id FROM host_software WHERE host_id = ?)
` `
queryFunc := d.db.Queryx queryFunc := d.reader.Queryx
if tx != nil { if tx != nil {
queryFunc = tx.Queryx queryFunc = tx.Queryx
} }
@ -305,7 +305,7 @@ func (d *Datastore) AllSoftwareWithoutCPEIterator() (fleet.SoftwareIterator, err
sql := `SELECT s.* FROM software s LEFT JOIN software_cpe sc on (s.id=sc.software_id) WHERE sc.id is null` sql := `SELECT s.* FROM software s LEFT JOIN software_cpe sc on (s.id=sc.software_id) WHERE sc.id is null`
// The rows.Close call is done by the caller once iteration using the // The rows.Close call is done by the caller once iteration using the
// returned fleet.SoftwareIterator is done. // returned fleet.SoftwareIterator is done.
rows, err := d.db.Queryx(sql) //nolint:sqlclosecheck rows, err := d.reader.Queryx(sql) //nolint:sqlclosecheck
if err != nil { if err != nil {
return nil, errors.Wrap(err, "load host software") return nil, errors.Wrap(err, "load host software")
} }
@ -314,7 +314,7 @@ func (d *Datastore) AllSoftwareWithoutCPEIterator() (fleet.SoftwareIterator, err
func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error { func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error {
sql := `INSERT INTO software_cpe (software_id, cpe) VALUES (?, ?)` sql := `INSERT INTO software_cpe (software_id, cpe) VALUES (?, ?)`
if _, err := d.db.Exec(sql, software.ID, cpe); err != nil { if _, err := d.writer.Exec(sql, software.ID, cpe); err != nil {
return errors.Wrap(err, "insert software cpe") return errors.Wrap(err, "insert software cpe")
} }
return nil return nil
@ -323,7 +323,7 @@ func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error
func (d *Datastore) AllCPEs() ([]string, error) { func (d *Datastore) AllCPEs() ([]string, error) {
sql := `SELECT cpe FROM software_cpe` sql := `SELECT cpe FROM software_cpe`
var cpes []string var cpes []string
err := d.db.Select(&cpes, sql) err := d.reader.Select(&cpes, sql)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "loads cpes") return nil, errors.Wrap(err, "loads cpes")
} }
@ -337,7 +337,7 @@ func (d *Datastore) InsertCVEForCPE(cve string, cpes []string) error {
for _, cpe := range cpes { for _, cpe := range cpes {
args = append(args, cpe, cve) args = append(args, cpe, cve)
} }
_, err := d.db.Exec(sql, args...) _, err := d.writer.Exec(sql, args...)
if err != nil { if err != nil {
return errors.Wrap(err, "insert software cve") return errors.Wrap(err, "insert software cve")
} }

View File

@ -207,7 +207,7 @@ func TestHostSoftwareDuplicates(t *testing.T) {
}) })
incoming[soft2Key] = true incoming[soft2Key] = true
tx, err := ds.db.Beginx() tx, err := ds.writer.Beginx()
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming)) require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
require.NoError(t, tx.Commit()) require.NoError(t, tx.Commit())
@ -220,7 +220,7 @@ func TestHostSoftwareDuplicates(t *testing.T) {
}) })
incoming[soft3Key] = true incoming[soft3Key] = true
tx, err = ds.db.Beginx() tx, err = ds.writer.Beginx()
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming)) require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
require.NoError(t, tx.Commit()) require.NoError(t, tx.Commit())

View File

@ -21,14 +21,14 @@ func (d *Datastore) ShouldSendStatistics(frequency time.Duration) (fleet.Statist
} }
dest := statistics{} dest := statistics{}
err = d.db.Get(&dest, `SELECT created_at, updated_at, anonymous_identifier FROM statistics LIMIT 1`) err = d.writer.Get(&dest, `SELECT created_at, updated_at, anonymous_identifier FROM statistics LIMIT 1`)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
anonIdentifier, err := server.GenerateRandomText(64) anonIdentifier, err := server.GenerateRandomText(64)
if err != nil { if err != nil {
return fleet.StatisticsPayload{}, false, err return fleet.StatisticsPayload{}, false, err
} }
_, err = d.db.Exec(`INSERT INTO statistics(anonymous_identifier) VALUES (?)`, anonIdentifier) _, err = d.writer.Exec(`INSERT INTO statistics(anonymous_identifier) VALUES (?)`, anonIdentifier)
if err != nil { if err != nil {
return fleet.StatisticsPayload{}, false, err return fleet.StatisticsPayload{}, false, err
} }
@ -55,6 +55,6 @@ func (d *Datastore) ShouldSendStatistics(frequency time.Duration) (fleet.Statist
} }
func (d *Datastore) RecordStatisticsSent() error { func (d *Datastore) RecordStatisticsSent() error {
_, err := d.db.Exec(`UPDATE statistics SET updated_at = CURRENT_TIMESTAMP LIMIT 1`) _, err := d.writer.Exec(`UPDATE statistics SET updated_at = CURRENT_TIMESTAMP LIMIT 1`)
return err return err
} }

View File

@ -52,7 +52,7 @@ func (d *Datastore) CountHostsInTargets(filter fleet.TeamFilter, targets fleet.H
} }
res := fleet.TargetMetrics{} res := fleet.TargetMetrics{}
err = d.db.Get(&res, query, args...) err = d.reader.Get(&res, query, args...)
if err != nil { if err != nil {
return fleet.TargetMetrics{}, errors.Wrap(err, "sqlx.Get CountHostsInTargets") return fleet.TargetMetrics{}, errors.Wrap(err, "sqlx.Get CountHostsInTargets")
} }
@ -98,7 +98,7 @@ func (d *Datastore) HostIDsInTargets(filter fleet.TeamFilter, targets fleet.Host
} }
var res []uint var res []uint
err = d.db.Select(&res, query, args...) err = d.reader.Select(&res, query, args...)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sqlx.Get HostIDsInTargets") return nil, errors.Wrap(err, "sqlx.Get HostIDsInTargets")
} }

View File

@ -20,7 +20,7 @@ func (d *Datastore) NewTeam(team *fleet.Team) (*fleet.Team, error) {
description description
) VALUES ( ?, ?, ? ) ) VALUES ( ?, ?, ? )
` `
result, err := d.db.Exec( result, err := d.writer.Exec(
query, query,
team.Name, team.Name,
team.AgentOptions, team.AgentOptions,
@ -47,7 +47,7 @@ func (d *Datastore) Team(tid uint) (*fleet.Team, error) {
` `
team := &fleet.Team{} team := &fleet.Team{}
if err := d.db.Get(team, sql, tid); err != nil { if err := d.reader.Get(team, sql, tid); err != nil {
return nil, errors.Wrap(err, "select team") return nil, errors.Wrap(err, "select team")
} }
@ -84,7 +84,7 @@ func (d *Datastore) TeamByName(name string) (*fleet.Team, error) {
` `
team := &fleet.Team{} team := &fleet.Team{}
if err := d.db.Get(team, sql, name); err != nil { if err := d.reader.Get(team, sql, name); err != nil {
return nil, errors.Wrap(err, "select team") return nil, errors.Wrap(err, "select team")
} }
@ -106,7 +106,7 @@ func (d *Datastore) loadUsersForTeam(team *fleet.Team) error {
WHERE ut.team_id = ? WHERE ut.team_id = ?
` `
rows := []fleet.TeamUser{} rows := []fleet.TeamUser{}
if err := d.db.Select(&rows, sql, team.ID); err != nil { if err := d.reader.Select(&rows, sql, team.ID); err != nil {
return errors.Wrap(err, "load users for team") return errors.Wrap(err, "load users for team")
} }
@ -157,7 +157,7 @@ func (d *Datastore) SaveTeam(team *fleet.Team) (*fleet.Team, error) {
description = ? description = ?
WHERE id = ? WHERE id = ?
` `
_, err := d.db.Exec(query, team.Name, team.AgentOptions, team.Description, team.ID) _, err := d.writer.Exec(query, team.Name, team.AgentOptions, team.Description, team.ID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "saving team") return nil, errors.Wrap(err, "saving team")
} }
@ -174,7 +174,7 @@ func (d *Datastore) SaveTeam(team *fleet.Team) (*fleet.Team, error) {
} }
func (d *Datastore) updateTeamSchedule(team *fleet.Team) error { func (d *Datastore) updateTeamSchedule(team *fleet.Team) error {
_, err := d.db.Exec( _, err := d.writer.Exec(
`UPDATE packs SET name = ? WHERE pack_type = ?`, teamScheduleName(team), teamSchedulePackType(team), `UPDATE packs SET name = ? WHERE pack_type = ?`, teamScheduleName(team), teamSchedulePackType(team),
) )
return err return err
@ -195,7 +195,7 @@ func (d *Datastore) ListTeams(filter fleet.TeamFilter, opt fleet.ListOptions) ([
query, params := searchLike(query, nil, opt.MatchQuery, teamSearchColumns...) query, params := searchLike(query, nil, opt.MatchQuery, teamSearchColumns...)
query = appendListOptionsToSQL(query, opt) query = appendListOptionsToSQL(query, opt)
teams := []*fleet.Team{} teams := []*fleet.Team{}
if err := d.db.Select(&teams, query, params...); err != nil { if err := d.reader.Select(&teams, query, params...); err != nil {
return nil, errors.Wrap(err, "list teams") return nil, errors.Wrap(err, "list teams")
} }
if err := d.loadSecretsForTeams(teams); err != nil { if err := d.loadSecretsForTeams(teams); err != nil {
@ -229,7 +229,7 @@ func (d *Datastore) SearchTeams(filter fleet.TeamFilter, matchQuery string, omit
sql, params := searchLike(sql, nil, matchQuery, teamSearchColumns...) sql, params := searchLike(sql, nil, matchQuery, teamSearchColumns...)
sql += "\nLIMIT 5" sql += "\nLIMIT 5"
teams := []*fleet.Team{} teams := []*fleet.Team{}
if err := d.db.Select(&teams, sql, params...); err != nil { if err := d.reader.Select(&teams, sql, params...); err != nil {
return nil, errors.Wrap(err, "search teams") return nil, errors.Wrap(err, "search teams")
} }
if err := d.loadSecretsForTeams(teams); err != nil { if err := d.loadSecretsForTeams(teams); err != nil {
@ -244,7 +244,7 @@ func (d *Datastore) TeamEnrollSecrets(teamID uint) ([]*fleet.EnrollSecret, error
WHERE team_id = ? WHERE team_id = ?
` `
var secrets []*fleet.EnrollSecret var secrets []*fleet.EnrollSecret
if err := d.db.Select(&secrets, sql, teamID); err != nil { if err := d.reader.Select(&secrets, sql, teamID); err != nil {
return nil, errors.Wrap(err, "get secrets") return nil, errors.Wrap(err, "get secrets")
} }
return secrets, nil return secrets, nil

View File

@ -1,6 +1,7 @@
package mysql package mysql
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -9,10 +10,12 @@ import (
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
"time"
"github.com/WatchBeam/clock" "github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config" "github.com/fleetdm/fleet/v4/server/config"
"github.com/go-kit/kit/log" "github.com/go-kit/kit/log"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -20,9 +23,10 @@ const (
testUsername = "root" testUsername = "root"
testPassword = "toor" testPassword = "toor"
testAddress = "localhost:3307" testAddress = "localhost:3307"
testReplicaDatabaseSuffix = "_replica"
) )
func connectMySQL(t *testing.T, testName string) *Datastore { func connectMySQL(t *testing.T, testName string, opts *DatastoreTestOptions) *Datastore {
config := config.MysqlConfig{ config := config.MysqlConfig{
Username: testUsername, Username: testUsername,
Password: testPassword, Password: testPassword,
@ -31,15 +35,134 @@ func connectMySQL(t *testing.T, testName string) *Datastore {
} }
// Create datastore client // Create datastore client
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1)) var replicaOpt DBOption
if opts.Replica {
replicaConf := config
replicaConf.Database += testReplicaDatabaseSuffix
replicaOpt = Replica(&replicaConf)
}
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1), replicaOpt)
require.Nil(t, err) require.Nil(t, err)
if opts.Replica {
setupReadReplica(t, testName, ds, opts)
}
return ds return ds
} }
func setupReadReplica(t *testing.T, testName string, ds *Datastore, opts *DatastoreTestOptions) {
// create the context that will cancel the replication goroutine on test exit
var cancel func()
ctx := context.Background()
if dl, ok := t.Deadline(); ok {
ctx, cancel = context.WithDeadline(ctx, dl)
} else {
ctx, cancel = context.WithCancel(ctx)
}
t.Cleanup(cancel)
// start the replication goroutine that runs when signalled through a
// channel, the replication runs in lock-step - the test is in control of
// when the replication happens, by calling opts.RunReplication(), and when
// that call returns, the replication is guaranteed to be done. This supports
// simulating all kinds of replica lag.
ch := make(chan chan struct{})
go func() {
// if it exits because of a panic/failed replication, cancel the context
// immediately so that RunReplication is unblocked too.
defer cancel()
primary := ds.writer
replica := ds.reader.(*sqlx.DB)
replicaDB := testName + testReplicaDatabaseSuffix
last := time.Now().Add(-time.Minute)
// drop all foreign keys in the replica, as that causes issues even with
// FOREIGN_KEY_CHECKS=0
var fks []struct {
TableName string `db:"table_name"`
ConstraintName string `db:"constraint_name"`
}
err := primary.SelectContext(ctx, &fks, `
SELECT
table_name, constraint_name
FROM
information_schema.key_column_usage
WHERE
table_schema = ? AND
referenced_table_name IS NOT NULL`, testName)
require.NoError(t, err)
for _, fk := range fks {
stmt := fmt.Sprintf(`ALTER TABLE %s.%s DROP FOREIGN KEY %s`, replicaDB, fk.TableName, fk.ConstraintName)
_, err := replica.ExecContext(ctx, stmt)
require.NoError(t, err)
}
for {
select {
case out := <-ch:
// identify tables with changes since the last call
var tables []string
err := primary.SelectContext(ctx, &tables, `
SELECT
table_name
FROM
information_schema.tables
WHERE
table_schema = ? AND
table_type = 'BASE TABLE' AND
update_time >= ?`, testName, last)
require.NoError(t, err)
err = primary.GetContext(ctx, &last, `
SELECT
MAX(update_time)
FROM
information_schema.tables
WHERE
table_schema = ? AND
table_type = 'BASE TABLE'`, testName)
require.NoError(t, err)
// replicate by dropping the existing table and re-creating it from
// the primary.
for _, tbl := range tables {
stmt := fmt.Sprintf(`DROP TABLE IF EXISTS %s.%s`, replicaDB, tbl)
t.Log(stmt)
_, err = replica.ExecContext(ctx, stmt)
require.NoError(t, err)
stmt = fmt.Sprintf(`CREATE TABLE %s.%s SELECT * FROM %s.%s`, replicaDB, tbl, testName, tbl)
t.Log(stmt)
_, err = replica.ExecContext(ctx, stmt)
require.NoError(t, err)
}
out <- struct{}{}
t.Logf("replication step executed, next will consider updates since %s", last)
case <-ctx.Done():
return
}
}
}()
// set RunReplication to a function that triggers the replication and waits
// for it to complete.
opts.RunReplication = func() {
done := make(chan struct{})
ch <- done
select {
case <-done:
case <-ctx.Done():
}
}
}
// initializeDatabase loads the dumped schema into a newly created database in // initializeDatabase loads the dumped schema into a newly created database in
// MySQL. This is much faster than running the full set of migrations on each // MySQL. This is much faster than running the full set of migrations on each
// test. // test.
func initializeDatabase(t *testing.T, testName string) *Datastore { func initializeDatabase(t *testing.T, testName string, opts *DatastoreTestOptions) *Datastore {
_, filename, _, _ := runtime.Caller(0) _, filename, _, _ := runtime.Caller(0)
base := path.Dir(filename) base := path.Dir(filename)
schema, err := ioutil.ReadFile(path.Join(base, "schema.sql")) schema, err := ioutil.ReadFile(path.Join(base, "schema.sql"))
@ -47,6 +170,14 @@ func initializeDatabase(t *testing.T, testName string) *Datastore {
t.Error(err) t.Error(err)
t.FailNow() t.FailNow()
} }
// execute the schema for the test db, and once more for the replica db if
// that option is set.
dbs := []string{testName}
if opts.Replica {
dbs = append(dbs, testName+testReplicaDatabaseSuffix)
}
for _, dbName := range dbs {
// Load schema from dumpfile // Load schema from dumpfile
if out, err := exec.Command( if out, err := exec.Command(
"docker-compose", "exec", "-T", "mysql_test", "docker-compose", "exec", "-T", "mysql_test",
@ -56,24 +187,43 @@ func initializeDatabase(t *testing.T, testName string) *Datastore {
"-e", "-e",
fmt.Sprintf( fmt.Sprintf(
"DROP DATABASE IF EXISTS %s; CREATE DATABASE %s; USE %s; SET FOREIGN_KEY_CHECKS=0; %s;", "DROP DATABASE IF EXISTS %s; CREATE DATABASE %s; USE %s; SET FOREIGN_KEY_CHECKS=0; %s;",
testName, testName, testName, schema, dbName, dbName, dbName, schema,
), ),
).CombinedOutput(); err != nil { ).CombinedOutput(); err != nil {
t.Error(err) t.Error(err)
t.Error(string(out)) t.Error(string(out))
t.FailNow() t.FailNow()
} }
return connectMySQL(t, testName) }
return connectMySQL(t, testName, opts)
} }
func CreateMySQLDS(t *testing.T) *Datastore { // DatastoreTestOptions configures how the test datastore is created
// by CreateMySQLDSWithOptions.
type DatastoreTestOptions struct {
// Replica indicates that a read replica test database should be created.
Replica bool
// RunReplication is the function to call to execute the replication of all
// missing changes from the primary to the replica. The function is created
// and set automatically by CreateMySQLDSWithOptions. The test is in full
// control of when the replication is executed.
RunReplication func()
}
func createMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datastore {
if _, ok := os.LookupEnv("MYSQL_TEST"); !ok { if _, ok := os.LookupEnv("MYSQL_TEST"); !ok {
t.Skip("MySQL tests are disabled") t.Skip("MySQL tests are disabled")
} }
t.Parallel() t.Parallel()
pc, _, _, ok := runtime.Caller(1) if opts == nil {
// so it is never nil in internal helper functions
opts = new(DatastoreTestOptions)
}
pc, _, _, ok := runtime.Caller(2)
details := runtime.FuncForPC(pc) details := runtime.FuncForPC(pc)
if !ok || details == nil { if !ok || details == nil {
t.FailNow() t.FailNow()
@ -83,5 +233,13 @@ func CreateMySQLDS(t *testing.T) *Datastore {
strings.TrimPrefix(details.Name(), "github.com/fleetdm/fleet/v4/"), "/", "_", strings.TrimPrefix(details.Name(), "github.com/fleetdm/fleet/v4/"), "/", "_",
) )
cleanName = strings.ReplaceAll(cleanName, ".", "_") cleanName = strings.ReplaceAll(cleanName, ".", "_")
return initializeDatabase(t, cleanName) return initializeDatabase(t, cleanName, opts)
}
func CreateMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datastore {
return createMySQLDSWithOptions(t, opts)
}
func CreateMySQLDS(t *testing.T) *Datastore {
return createMySQLDSWithOptions(t, nil)
} }

View File

@ -72,7 +72,7 @@ func (d *Datastore) findUser(searchCol string, searchVal interface{}) (*fleet.Us
user := &fleet.User{} user := &fleet.User{}
err := d.db.Get(user, sqlStatement, searchVal) err := d.reader.Get(user, sqlStatement, searchVal)
if err != nil && err == sql.ErrNoRows { if err != nil && err == sql.ErrNoRows {
return nil, notFound("User"). return nil, notFound("User").
WithMessage(fmt.Sprintf("with %s=%v", searchCol, searchVal)) WithMessage(fmt.Sprintf("with %s=%v", searchCol, searchVal))
@ -104,7 +104,7 @@ func (d *Datastore) ListUsers(opt fleet.UserListOptions) ([]*fleet.User, error)
sqlStatement = appendListOptionsToSQL(sqlStatement, opt.ListOptions) sqlStatement = appendListOptionsToSQL(sqlStatement, opt.ListOptions)
users := []*fleet.User{} users := []*fleet.User{}
if err := d.db.Select(&users, sqlStatement, params...); err != nil { if err := d.reader.Select(&users, sqlStatement, params...); err != nil {
return nil, errors.Wrap(err, "list users") return nil, errors.Wrap(err, "list users")
} }
@ -160,7 +160,7 @@ func (d *Datastore) saveUser(tx *sqlx.Tx, user *fleet.User) error {
global_role = ? global_role = ?
WHERE id = ? WHERE id = ?
` `
result, err := d.db.Exec(sqlStatement, result, err := tx.Exec(sqlStatement,
user.Password, user.Password,
user.Salt, user.Salt,
user.Name, user.Name,
@ -221,7 +221,7 @@ func (d *Datastore) loadTeamsForUsers(users []*fleet.User) error {
fleet.UserTeam fleet.UserTeam
UserID uint `db:"user_id"` UserID uint `db:"user_id"`
} }
if err := d.db.Select(&rows, sql, args...); err != nil { if err := d.reader.Select(&rows, sql, args...); err != nil {
return errors.Wrap(err, "get loadTeamsForUsers") return errors.Wrap(err, "get loadTeamsForUsers")
} }