Add support for MySQL read replicas (#1786)

Adds configuration options to use a read-only MySQL replica,
and uses it instead of the primary for reads.
This commit is contained in:
Martin Angers 2021-09-01 15:50:52 -04:00 committed by GitHub
parent 2116217b6e
commit 3755a58070
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 630 additions and 319 deletions

View File

@ -0,0 +1 @@
* Add support for MySQL read replica.

View File

@ -149,7 +149,11 @@ the way that the Fleet server works.
var carveStore fleet.CarveStore
mailService := mail.NewService()
ds, err = mysql.New(config.Mysql, clock.C, mysql.Logger(logger))
var replicaOpt mysql.DBOption
if config.MysqlReadReplica.Address != "" {
replicaOpt = mysql.Replica(&config.MysqlReadReplica)
}
ds, err = mysql.New(config.Mysql, clock.C, mysql.Logger(logger), replicaOpt)
if err != nil {
initFatal(err, "initializing datastore")
}

View File

@ -25,7 +25,7 @@ Your Fleet server's two main purposes are:
- To serve as your [osquery TLS server](https://osquery.readthedocs.io/en/stable/deployment/remote/)
- To serve the Fleet web UI, which allows you to manage osquery configuration, query hosts, etc.
The Fleet server allows you persist configuration, manage users, etc. Thus, it needs a database. Fleet uses MySQL and requires you to supply configurations to connect to a MySQL server. Fleet also uses Redis to perform some more high-speed data access action throughout the lifecycle of the application (for example, distributed query result ingestion). Thus, Fleet also requires that you supply Redis connection configurations.
The Fleet server allows you persist configuration, manage users, etc. Thus, it needs a database. Fleet uses MySQL and requires you to supply configurations to connect to a MySQL server. It is also possible to configure connection to a MySQL replica in addition to the primary, to be used for reading only. Fleet also uses Redis to perform some more high-speed data access action throughout the lifecycle of the application (for example, distributed query result ingestion). Thus, Fleet also requires that you supply Redis connection configurations.
> Fleet does not support Redis Cluster or Redis Sentinel. Fleet can scale to hundreds of thousands of devices with a single Redis instance.
@ -118,12 +118,21 @@ mysql:
address: 127.0.0.1:3306
```
And `mysql_read_replica_address` would be:
```
mysql_read_replica:
address: 127.0.0.1:3307
```
Basically, just capitalize the option and prepend `FLEET_` to it in order to get the environment variable. The conversion works the same the opposite way.
All duration-based settings accept valid time units of `s`, `m`, `h`.
##### MySQL
This section describes the configuration options for the primary - if you also want to setup a read replica, the options are the same, except that the yaml section is `mysql_read_replica`, and the flags have the `mysql_read_replica_` prefix instead of `mysql_` (the corresponding environment variables follow the same transformation). Note that there is no default value for `mysql_read_replica_address`, it must be set explicitly for fleet to use a read replica.
###### `mysql_address`
The address of the MySQL server which Fleet should connect to. Include the hostname and port.

View File

@ -177,6 +177,7 @@ type VulnerabilitiesConfig struct {
// updated to set and retrieve the configurations as appropriate.
type FleetConfig struct {
Mysql MysqlConfig
MysqlReadReplica MysqlConfig `yaml:"mysql_read_replica"`
Redis RedisConfig
Server ServerConfig
Auth AuthConfig
@ -197,32 +198,36 @@ type FleetConfig struct {
// addConfigs adds the configuration keys and default values that will be
// filled into the FleetConfig struct
func (man Manager) addConfigs() {
addMysqlConfig := func(prefix, defaultAddr, usageSuffix string) {
man.addConfigString(prefix+".protocol", "tcp",
"MySQL server communication protocol (tcp,unix,...)"+usageSuffix)
man.addConfigString(prefix+".address", defaultAddr,
"MySQL server address (host:port)"+usageSuffix)
man.addConfigString(prefix+".username", "fleet",
"MySQL server username"+usageSuffix)
man.addConfigString(prefix+".password", "",
"MySQL server password (prefer env variable for security)"+usageSuffix)
man.addConfigString(prefix+".password_path", "",
"Path to file containg MySQL server password"+usageSuffix)
man.addConfigString(prefix+".database", "fleet",
"MySQL database name"+usageSuffix)
man.addConfigString(prefix+".tls_cert", "",
"MySQL TLS client certificate path"+usageSuffix)
man.addConfigString(prefix+".tls_key", "",
"MySQL TLS client key path"+usageSuffix)
man.addConfigString(prefix+".tls_ca", "",
"MySQL TLS server CA"+usageSuffix)
man.addConfigString(prefix+".tls_server_name", "",
"MySQL TLS server name"+usageSuffix)
man.addConfigString(prefix+".tls_config", "",
"MySQL TLS config value"+usageSuffix+" Use skip-verify, true, false or custom key.")
man.addConfigInt(prefix+".max_open_conns", 50, "MySQL maximum open connection handles"+usageSuffix)
man.addConfigInt(prefix+".max_idle_conns", 50, "MySQL maximum idle connection handles"+usageSuffix)
man.addConfigInt(prefix+".conn_max_lifetime", 0, "MySQL maximum amount of time a connection may be reused"+usageSuffix)
}
// MySQL
man.addConfigString("mysql.protocol", "tcp",
"MySQL server communication protocol (tcp,unix,...)")
man.addConfigString("mysql.address", "localhost:3306",
"MySQL server address (host:port)")
man.addConfigString("mysql.username", "fleet",
"MySQL server username")
man.addConfigString("mysql.password", "",
"MySQL server password (prefer env variable for security)")
man.addConfigString("mysql.password_path", "",
"Path to file containg MySQL server password")
man.addConfigString("mysql.database", "fleet",
"MySQL database name")
man.addConfigString("mysql.tls_cert", "",
"MySQL TLS client certificate path")
man.addConfigString("mysql.tls_key", "",
"MySQL TLS client key path")
man.addConfigString("mysql.tls_ca", "",
"MySQL TLS server CA")
man.addConfigString("mysql.tls_server_name", "",
"MySQL TLS server name")
man.addConfigString("mysql.tls_config", "",
"MySQL TLS config value. Use skip-verify, true, false or custom key.")
man.addConfigInt("mysql.max_open_conns", 50, "MySQL maximum open connection handles.")
man.addConfigInt("mysql.max_idle_conns", 50, "MySQL maximum idle connection handles.")
man.addConfigInt("mysql.conn_max_lifetime", 0, "MySQL maximum amount of time a connection may be reused.")
addMysqlConfig("mysql", "localhost:3306", ".")
addMysqlConfig("mysql_read_replica", "", " for the read replica.")
// Redis
man.addConfigString("redis.address", "localhost:6379",
@ -382,23 +387,28 @@ func (man Manager) addConfigs() {
func (man Manager) LoadConfig() FleetConfig {
man.loadConfigFile()
loadMysqlConfig := func(prefix string) MysqlConfig {
return MysqlConfig{
Protocol: man.getConfigString(prefix + ".protocol"),
Address: man.getConfigString(prefix + ".address"),
Username: man.getConfigString(prefix + ".username"),
Password: man.getConfigString(prefix + ".password"),
PasswordPath: man.getConfigString(prefix + ".password_path"),
Database: man.getConfigString(prefix + ".database"),
TLSCert: man.getConfigString(prefix + ".tls_cert"),
TLSKey: man.getConfigString(prefix + ".tls_key"),
TLSCA: man.getConfigString(prefix + ".tls_ca"),
TLSServerName: man.getConfigString(prefix + ".tls_server_name"),
TLSConfig: man.getConfigString(prefix + ".tls_config"),
MaxOpenConns: man.getConfigInt(prefix + ".max_open_conns"),
MaxIdleConns: man.getConfigInt(prefix + ".max_idle_conns"),
ConnMaxLifetime: man.getConfigInt(prefix + ".conn_max_lifetime"),
}
}
return FleetConfig{
Mysql: MysqlConfig{
Protocol: man.getConfigString("mysql.protocol"),
Address: man.getConfigString("mysql.address"),
Username: man.getConfigString("mysql.username"),
Password: man.getConfigString("mysql.password"),
PasswordPath: man.getConfigString("mysql.password_path"),
Database: man.getConfigString("mysql.database"),
TLSCert: man.getConfigString("mysql.tls_cert"),
TLSKey: man.getConfigString("mysql.tls_key"),
TLSCA: man.getConfigString("mysql.tls_ca"),
TLSServerName: man.getConfigString("mysql.tls_server_name"),
TLSConfig: man.getConfigString("mysql.tls_config"),
MaxOpenConns: man.getConfigInt("mysql.max_open_conns"),
MaxIdleConns: man.getConfigInt("mysql.max_idle_conns"),
ConnMaxLifetime: man.getConfigInt("mysql.conn_max_lifetime"),
},
Mysql: loadMysqlConfig("mysql"),
MysqlReadReplica: loadMysqlConfig("mysql_read_replica"),
Redis: RedisConfig{
Address: man.getConfigString("redis.address"),
Password: man.getConfigString("redis.password"),

View File

@ -58,6 +58,7 @@ func TestConfigRoundtrip(t *testing.T) {
// Marshal the generated config
buf, err := yaml.Marshal(original)
require.Nil(t, err)
t.Log(string(buf))
// Manually load the serialized config
man.viper.SetConfigType("yaml")

View File

@ -3,6 +3,7 @@ package mysql
import (
"database/sql"
"encoding/json"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/pkg/errors"
)
@ -13,7 +14,7 @@ func (d *Datastore) NewActivity(user *fleet.User, activityType string, details *
if err != nil {
return errors.Wrap(err, "marshaling activity details")
}
_, err = d.db.Exec(
_, err = d.writer.Exec(
`INSERT INTO activities (user_id, user_name, activity_type, details) VALUES(?,?,?,?)`,
user.ID,
user.Name,
@ -34,7 +35,7 @@ func (d *Datastore) ListActivities(opt fleet.ListOptions) ([]*fleet.Activity, er
WHERE true`
query = appendListOptionsToSQL(query, opt)
err := d.db.Select(&activities, query)
err := d.reader.Select(&activities, query)
if err == sql.ErrNoRows {
return nil, notFound("Activity")
} else if err != nil {

View File

@ -25,7 +25,7 @@ func (d *Datastore) NewAppConfig(info *fleet.AppConfig) (*fleet.AppConfig, error
func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
info := &fleet.AppConfig{}
var bytes []byte
err := d.db.Get(&bytes, `SELECT json_value FROM app_config_json LIMIT 1`)
err := d.reader.Get(&bytes, `SELECT json_value FROM app_config_json LIMIT 1`)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "selecting app config")
}
@ -43,7 +43,7 @@ func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
}
func (d *Datastore) isEventSchedulerEnabled() (bool, error) {
rows, err := d.db.Query("SELECT @@event_scheduler")
rows, err := d.writer.Query("SELECT @@event_scheduler")
if err != nil {
return false, err
}
@ -136,7 +136,7 @@ func (d *Datastore) SaveAppConfig(info *fleet.AppConfig) error {
func (d *Datastore) VerifyEnrollSecret(secret string) (*fleet.EnrollSecret, error) {
var s fleet.EnrollSecret
err := d.db.Get(&s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
err := d.reader.Get(&s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
if err != nil {
return nil, errors.New("no matching secret found")
}
@ -184,7 +184,7 @@ func (d *Datastore) GetEnrollSecrets(teamID *uint) ([]*fleet.EnrollSecret, error
args = append(args, teamID)
}
var secrets []*fleet.EnrollSecret
if err := d.db.Select(&secrets, sql, args...); err != nil {
if err := d.reader.Select(&secrets, sql, args...); err != nil {
return nil, errors.Wrap(err, "get secrets")
}
return secrets, nil

View File

@ -259,7 +259,7 @@ func TestAppConfigDefaults(t *testing.T) {
defer ds.Close()
insertAppConfigQuery := `INSERT INTO app_config_json(json_value) VALUES(?) ON DUPLICATE KEY UPDATE json_value = VALUES(json_value)`
_, err := ds.db.Exec(insertAppConfigQuery, `{}`)
_, err := ds.writer.Exec(insertAppConfigQuery, `{}`)
require.NoError(t, err)
ac, err := ds.AppConfig()
@ -269,7 +269,7 @@ func TestAppConfigDefaults(t *testing.T) {
require.True(t, ac.HostSettings.EnableHostUsers)
require.False(t, ac.HostSettings.EnableSoftwareInventory)
_, err = ds.db.Exec(
_, err = ds.writer.Exec(
insertAppConfigQuery,
`{"webhook_settings": {"interval": "12h"}, "host_settings": {"enable_host_users": false}}`,
)

View File

@ -17,7 +17,7 @@ func (d *Datastore) NewDistributedQueryCampaign(camp *fleet.DistributedQueryCamp
)
VALUES(?,?,?)
`
result, err := d.db.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID)
result, err := d.writer.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID)
if err != nil {
return nil, errors.Wrap(err, "inserting distributed query campaign")
}
@ -32,7 +32,7 @@ func (d *Datastore) DistributedQueryCampaign(id uint) (*fleet.DistributedQueryCa
SELECT * FROM distributed_query_campaigns WHERE id = ?
`
campaign := &fleet.DistributedQueryCampaign{}
if err := d.db.Get(campaign, sql, id); err != nil {
if err := d.reader.Get(campaign, sql, id); err != nil {
return nil, errors.Wrap(err, "selecting distributed query campaign")
}
@ -47,7 +47,7 @@ func (d *Datastore) SaveDistributedQueryCampaign(camp *fleet.DistributedQueryCam
user_id = ?
WHERE id = ?
`
result, err := d.db.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID, camp.ID)
result, err := d.writer.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID, camp.ID)
if err != nil {
return errors.Wrap(err, "updating distributed query campaign")
}
@ -68,7 +68,7 @@ func (d *Datastore) DistributedQueryCampaignTargetIDs(id uint) (*fleet.HostTarge
`
targets := []fleet.DistributedQueryCampaignTarget{}
if err := d.db.Select(&targets, sqlStatement, id); err != nil {
if err := d.reader.Select(&targets, sqlStatement, id); err != nil {
return nil, errors.Wrap(err, "select distributed campaign target")
}
@ -100,7 +100,7 @@ func (d *Datastore) NewDistributedQueryCampaignTarget(target *fleet.DistributedQ
)
VALUES (?,?,?)
`
result, err := d.db.Exec(sqlStatement, target.Type, target.DistributedQueryCampaignID, target.TargetID)
result, err := d.writer.Exec(sqlStatement, target.Type, target.DistributedQueryCampaignID, target.TargetID)
if err != nil {
return nil, errors.Wrap(err, "insert distributed campaign target")
}
@ -118,7 +118,7 @@ func (d *Datastore) CleanupDistributedQueryCampaigns(now time.Time) (expired uin
WHERE (status = ? AND created_at < ?)
OR (status = ? AND created_at < ?)
`
result, err := d.db.Exec(sqlStatement, fleet.QueryComplete,
result, err := d.writer.Exec(sqlStatement, fleet.QueryComplete,
fleet.QueryWaiting, now.Add(-1*time.Minute),
fleet.QueryRunning, now.Add(-24*time.Hour))
if err != nil {

View File

@ -33,7 +33,7 @@ func (d *Datastore) NewCarve(metadata *fleet.CarveMetadata) (*fleet.CarveMetadat
?
)`
result, err := d.db.Exec(
result, err := d.writer.Exec(
stmt,
metadata.HostId,
metadata.CreatedAt.Format(mySQLTimestampFormat),
@ -64,7 +64,7 @@ func (d *Datastore) UpdateCarve(metadata *fleet.CarveMetadata) error {
expired = ?
WHERE id = ?
`
_, err := d.db.Exec(
_, err := d.writer.Exec(
stmt,
metadata.MaxBlock,
metadata.Expired,
@ -163,7 +163,7 @@ func (d *Datastore) Carve(carveId int64) (*fleet.CarveMetadata, error) {
)
var metadata fleet.CarveMetadata
if err := d.db.Get(&metadata, stmt, carveId); err != nil {
if err := d.reader.Get(&metadata, stmt, carveId); err != nil {
return nil, errors.Wrap(err, "get carve by ID")
}
@ -179,7 +179,7 @@ func (d *Datastore) CarveBySessionId(sessionId string) (*fleet.CarveMetadata, er
)
var metadata fleet.CarveMetadata
if err := d.db.Get(&metadata, stmt, sessionId); err != nil {
if err := d.reader.Get(&metadata, stmt, sessionId); err != nil {
return nil, errors.Wrap(err, "get carve by session ID")
}
@ -195,7 +195,7 @@ func (d *Datastore) CarveByName(name string) (*fleet.CarveMetadata, error) {
)
var metadata fleet.CarveMetadata
if err := d.db.Get(&metadata, stmt, name); err != nil {
if err := d.reader.Get(&metadata, stmt, name); err != nil {
return nil, errors.Wrap(err, "get carve by name")
}
@ -213,7 +213,7 @@ func (d *Datastore) ListCarves(opt fleet.CarveListOptions) ([]*fleet.CarveMetada
}
stmt = appendListOptionsToSQL(stmt, opt.ListOptions)
carves := []*fleet.CarveMetadata{}
if err := d.db.Select(&carves, stmt); err != nil && err != sql.ErrNoRows {
if err := d.reader.Select(&carves, stmt); err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "list carves")
}
@ -231,13 +231,14 @@ func (d *Datastore) NewBlock(metadata *fleet.CarveMetadata, blockId int64, data
?,
?
)`
if _, err := d.db.Exec(stmt, metadata.ID, blockId, data); err != nil {
if _, err := d.writer.Exec(stmt, metadata.ID, blockId, data); err != nil {
return errors.Wrap(err, "insert carve block")
}
if metadata.MaxBlock < blockId {
// Update max_block
metadata.MaxBlock = blockId
// TODO: looks like this should all be done in an atomic transaction?
if err := d.UpdateCarve(metadata); err != nil {
return errors.Wrap(err, "insert carve block")
}
@ -253,7 +254,7 @@ func (d *Datastore) GetBlock(metadata *fleet.CarveMetadata, blockId int64) ([]by
WHERE metadata_id = ? AND block_id = ?
`
var data []byte
if err := d.db.Get(&data, stmt, metadata.ID, blockId); err != nil {
if err := d.reader.Get(&data, stmt, metadata.ID, blockId); err != nil {
return nil, errors.Wrap(err, "select data")
}

View File

@ -1,6 +1,9 @@
package mysql
import "github.com/go-kit/kit/log"
import (
"github.com/fleetdm/fleet/v4/server/config"
"github.com/go-kit/kit/log"
)
const defaultMaxAttempts int = 15
@ -11,6 +14,7 @@ type dbOptions struct {
// maxAttempts configures the number of retries to connect to the DB
maxAttempts int
logger log.Logger
replicaConfig *config.MysqlConfig
}
// Logger adds a logger to the datastore
@ -21,6 +25,14 @@ func Logger(l log.Logger) DBOption {
}
}
// Replica sets the configuration of the read replica for the datastore.
func Replica(conf *config.MysqlConfig) DBOption {
return func(o *dbOptions) error {
o.replicaConfig = conf
return nil
}
}
// LimitAttempts sets a the number of attempts
// to try establishing a connection to the database backend
// the default value is 15 attempts

View File

@ -11,7 +11,7 @@ import (
// returning a notFound error if appropriate.
func (d *Datastore) deleteEntity(dbTable string, id uint) error {
deleteStmt := fmt.Sprintf(`DELETE FROM %s WHERE id = ?`, dbTable)
result, err := d.db.Exec(deleteStmt, id)
result, err := d.writer.Exec(deleteStmt, id)
if err != nil {
return errors.Wrapf(err, "delete %s", dbTable)
}
@ -26,7 +26,7 @@ func (d *Datastore) deleteEntity(dbTable string, id uint) error {
// table, returning a notFound error if appropriate.
func (d *Datastore) deleteEntityByName(dbTable string, name string) error {
deleteStmt := fmt.Sprintf("DELETE FROM %s WHERE name = ?", dbTable)
result, err := d.db.Exec(deleteStmt, name)
result, err := d.writer.Exec(deleteStmt, name)
if err != nil {
if isMySQLForeignKey(err) {
return foreignKey(dbTable, name)
@ -50,7 +50,7 @@ func (d *Datastore) deleteEntities(dbTable string, ids []uint) (uint, error) {
return 0, errors.Wrapf(err, "building delete entities query %s", dbTable)
}
result, err := d.db.Exec(query, args...)
result, err := d.writer.Exec(query, args...)
if err != nil {
return 0, errors.Wrapf(err, "executing delete entities query %s", dbTable)
}

View File

@ -15,7 +15,7 @@ func (ds *Datastore) PendingEmailChange(uid uint, newEmail, token string) error
new_email
) VALUES( ?, ?, ? )
`
_, err := ds.db.Exec(sqlStatement, uid, token, newEmail)
_, err := ds.writer.Exec(sqlStatement, uid, token, newEmail)
if err != nil {
return errors.Wrap(err, "inserting email change record")
}
@ -32,7 +32,7 @@ func (ds *Datastore) ConfirmPendingEmailChange(id uint, token string) (newEmail
Token string
NewEmail string `db:"new_email"`
}{}
err = ds.db.Get(&changeRecord, "SELECT * FROM email_changes WHERE token = ? AND user_id = ?", token, id)
err = ds.writer.Get(&changeRecord, "SELECT * FROM email_changes WHERE token = ? AND user_id = ?", token, id)
if err != nil {
if err == sql.ErrNoRows {
return "", notFound("email change with token")

View File

@ -33,7 +33,7 @@ func (d *Datastore) NewHost(host *fleet.Host) (*fleet.Host, error) {
)
VALUES( ?,?,?,?,?,?,?,?,?,?,?,?,? )
`
result, err := d.db.Exec(
result, err := d.writer.Exec(
sqlStatement,
host.OsqueryHostID,
host.DetailUpdatedAt,
@ -95,7 +95,7 @@ func (d *Datastore) SaveHost(host *fleet.Host) error {
percent_disk_space_available = ?
WHERE id = ?
`
_, err := d.db.Exec(sqlStatement,
_, err := d.writer.Exec(sqlStatement,
host.DetailUpdatedAt,
host.LabelUpdatedAt,
host.NodeKey,
@ -248,7 +248,7 @@ FROM scheduled_query_stats sqs
WHERE host_id = ? AND p.pack_type IS NULL
`
var stats []fleet.ScheduledQueryStats
if err := d.db.Select(&stats, sql, host.ID); err != nil {
if err := d.reader.Select(&stats, sql, host.ID); err != nil {
return errors.Wrap(err, "load pack stats")
}
@ -268,9 +268,9 @@ WHERE host_id = ? AND p.pack_type IS NULL
return nil
}
func (d *Datastore) loadHostUsers(host *fleet.Host) error {
func (d *Datastore) loadHostUsers(db dbReader, host *fleet.Host) error {
sql := `SELECT id, username, groupname, uid, user_type FROM host_users WHERE host_id = ? and removed_at IS NULL`
if err := d.db.Select(&host.Users, sql, host.ID); err != nil {
if err := db.Select(&host.Users, sql, host.ID); err != nil {
return errors.Wrap(err, "load pack stats")
}
return nil
@ -292,14 +292,14 @@ func (d *Datastore) Host(id uint) (*fleet.Host, error) {
LIMIT 1
`
host := &fleet.Host{}
err := d.db.Get(host, sqlStatement, id)
err := d.reader.Get(host, sqlStatement, id)
if err != nil {
return nil, errors.Wrap(err, "get host by id")
}
if err := d.loadHostPackStats(host); err != nil {
return nil, err
}
if err := d.loadHostUsers(host); err != nil {
if err := d.loadHostUsers(d.reader, host); err != nil {
return nil, err
}
@ -308,7 +308,9 @@ func (d *Datastore) Host(id uint) (*fleet.Host, error) {
func (d *Datastore) amountEnrolledHosts() (int, error) {
var amount int
err := d.db.Get(&amount, `SELECT count(*) FROM hosts`)
// uses the writer as it is called from ShouldSendStatistics, which is a
// write method.
err := d.writer.Get(&amount, `SELECT count(*) FROM hosts`)
if err != nil {
return 0, err
}
@ -363,7 +365,7 @@ func (d *Datastore) ListHosts(filter fleet.TeamFilter, opt fleet.HostListOptions
sql = appendListOptionsToSQL(sql, opt.ListOptions)
hosts := []*fleet.Host{}
if err := d.db.Select(&hosts, sql, params...); err != nil {
if err := d.reader.Select(&hosts, sql, params...); err != nil {
return nil, errors.Wrap(err, "list hosts")
}
@ -413,7 +415,7 @@ func (d *Datastore) CleanupIncomingHosts(now time.Time) error {
WHERE hostname = '' AND osquery_version = ''
AND created_at < (? - INTERVAL 5 MINUTE)
`
if _, err := d.db.Exec(sqlStatement, now); err != nil {
if _, err := d.writer.Exec(sqlStatement, now); err != nil {
return errors.Wrap(err, "cleanup incoming hosts")
}
@ -442,7 +444,7 @@ func (d *Datastore) GenerateHostStatusStatistics(filter fleet.TeamFilter, now ti
Online uint `db:"online"`
New uint `db:"new"`
}{}
err := d.db.Get(&counts, sqlStatement, now, now, now, now, now)
err := d.reader.Get(&counts, sqlStatement, now, now, now, now, now)
if err != nil && err != sql.ErrNoRows {
e = errors.Wrap(err, "generating host statistics")
return
@ -582,7 +584,7 @@ func (d *Datastore) AuthenticateHost(nodeKey string) (*fleet.Host, error) {
`
host := &fleet.Host{}
if err := d.db.Get(host, sqlStatement, nodeKey); err != nil {
if err := d.reader.Get(host, sqlStatement, nodeKey); err != nil {
switch err {
case sql.ErrNoRows:
return nil, notFound("Host")
@ -601,7 +603,7 @@ func (d *Datastore) MarkHostSeen(host *fleet.Host, t time.Time) error {
WHERE node_key=?
`
_, err := d.db.Exec(sqlStatement, t, host.NodeKey)
_, err := d.writer.Exec(sqlStatement, t, host.NodeKey)
if err != nil {
return errors.Wrap(err, "marking host seen")
}
@ -659,11 +661,11 @@ func (d *Datastore) searchHostsWithOmits(filter fleet.TeamFilter, query string,
if err != nil {
return nil, errors.Wrap(err, "searching hosts")
}
sql = d.db.Rebind(sql)
sql = d.reader.Rebind(sql)
hosts := []*fleet.Host{}
err = d.db.Select(&hosts, sql, args...)
err = d.reader.Select(&hosts, sql, args...)
if err != nil {
return nil, errors.Wrap(err, "searching hosts rebound")
}
@ -695,8 +697,8 @@ func (d *Datastore) searchHostsDefault(filter fleet.TeamFilter, omit ...uint) ([
if err != nil {
return nil, errors.Wrap(err, "searching default hosts")
}
sql = d.db.Rebind(sql)
err = d.db.Select(&hosts, sql, args...)
sql = d.reader.Rebind(sql)
err = d.reader.Select(&hosts, sql, args...)
if err != nil {
return nil, errors.Wrap(err, "searching default hosts rebound")
}
@ -730,7 +732,7 @@ func (d *Datastore) SearchHosts(filter fleet.TeamFilter, query string, omit ...u
)
hosts := []*fleet.Host{}
if err := d.db.Select(&hosts, sql, hostQuery, ipQuery); err != nil {
if err := d.reader.Select(&hosts, sql, hostQuery, ipQuery); err != nil {
return nil, errors.Wrap(err, "searching hosts")
}
@ -755,7 +757,7 @@ func (d *Datastore) HostIDsByName(filter fleet.TeamFilter, hostnames []string) (
}
var hostIDs []uint
if err := d.db.Select(&hostIDs, sql, args...); err != nil {
if err := d.reader.Select(&hostIDs, sql, args...); err != nil {
return nil, errors.Wrap(err, "get host IDs")
}
@ -770,7 +772,7 @@ func (d *Datastore) HostByIdentifier(identifier string) (*fleet.Host, error) {
LIMIT 1
`
host := &fleet.Host{}
err := d.db.Get(host, sql, identifier)
err := d.reader.Get(host, sql, identifier)
if err != nil {
return nil, errors.Wrap(err, "get host by identifier")
}
@ -796,7 +798,7 @@ func (d *Datastore) AddHostsToTeam(teamID *uint, hostIDs []uint) error {
return errors.Wrap(err, "sqlx.In AddHostsToTeam")
}
if _, err := d.db.Exec(sql, args...); err != nil {
if _, err := d.writer.Exec(sql, args...); err != nil {
return errors.Wrap(err, "exec AddHostsToTeam")
}
@ -809,7 +811,7 @@ func (d *Datastore) SaveHostAdditional(host *fleet.Host) error {
VALUES (?, ?)
ON DUPLICATE KEY UPDATE additional = VALUES(additional)
`
if _, err := d.db.Exec(sql, host.ID, host.Additional); err != nil {
if _, err := d.writer.Exec(sql, host.ID, host.Additional); err != nil {
return errors.Wrap(err, "insert additional")
}
@ -818,7 +820,7 @@ func (d *Datastore) SaveHostAdditional(host *fleet.Host) error {
func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
if len(host.Users) == 0 {
if _, err := d.db.Exec(
if _, err := d.writer.Exec(
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE host_id = ?`,
host.ID,
); err != nil {
@ -829,7 +831,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
}
currentHost := &fleet.Host{ID: host.ID}
if err := d.loadHostUsers(currentHost); err != nil {
if err := d.loadHostUsers(d.writer, currentHost); err != nil {
return err
}
@ -837,7 +839,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
for _, u := range host.Users {
incomingUsers[u.Uid] = true
if _, err := d.db.Exec(
if _, err := d.writer.Exec(
`INSERT IGNORE INTO host_users (host_id, uid, username, user_type, groupname) VALUES (?, ?, ?, ?, ?)`,
host.ID, u.Uid, u.Username, u.Type, u.GroupName,
); err != nil {
@ -860,7 +862,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE id IN (%s)`,
removedValues,
)
if _, err := d.db.Exec(removedSql, removedArgs...); err != nil {
if _, err := d.writer.Exec(removedSql, removedArgs...); err != nil {
return errors.Wrap(err, "mark users as removed")
}
@ -869,15 +871,15 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
func (d *Datastore) TotalAndUnseenHostsSince(daysCount int) (int, int, error) {
var totalCount, unseenCount int
err := d.db.QueryRow("SELECT count(*) FROM hosts").Scan(&totalCount)
err := d.reader.Get(&totalCount, "SELECT count(*) FROM hosts")
if err != nil {
return 0, 0, errors.Wrap(err, "getting total host count")
}
err = d.db.QueryRow(
err = d.reader.Get(&unseenCount,
"SELECT count(*) FROM hosts WHERE DATEDIFF(CURRENT_DATE, seen_time) >= ?",
daysCount,
).Scan(&unseenCount)
)
if err != nil {
return 0, 0, errors.Wrap(err, "getting unseen host count")
}

View File

@ -23,7 +23,7 @@ func (d *Datastore) NewInvite(i *fleet.Invite) (*fleet.Invite, error) {
VALUES ( ?, ?, ?, ?, ?, ?, ?)
`
result, err := d.db.Exec(sqlStmt, i.InvitedBy, i.Email,
result, err := d.writer.Exec(sqlStmt, i.InvitedBy, i.Email,
i.Name, i.Position, i.Token, i.SSOEnabled, i.GlobalRole)
if err != nil && isDuplicate(err) {
return nil, alreadyExists("Invite", i.Email)
@ -45,10 +45,11 @@ func (d *Datastore) NewInvite(i *fleet.Invite) (*fleet.Invite, error) {
for _, userTeam := range i.Teams {
args = append(args, i.ID, userTeam.Team.ID, userTeam.Role)
}
// TODO: seems like this should be in a transaction?
sql := "INSERT INTO invite_teams (invite_id, team_id, role) VALUES " +
strings.Repeat(valueStr, len(i.Teams))
sql = strings.TrimSuffix(sql, ",")
if _, err := d.db.Exec(sql, args...); err != nil {
if _, err := d.writer.Exec(sql, args...); err != nil {
return nil, errors.Wrap(err, "insert teams")
}
@ -63,7 +64,7 @@ func (d *Datastore) ListInvites(opt fleet.ListOptions) ([]*fleet.Invite, error)
query, params := searchLike(query, nil, opt.MatchQuery, inviteSearchColumns...)
query = appendListOptionsToSQL(query, opt)
err := d.db.Select(&invites, query, params...)
err := d.reader.Select(&invites, query, params...)
if err == sql.ErrNoRows {
return nil, notFound("Invite")
} else if err != nil {
@ -80,7 +81,7 @@ func (d *Datastore) ListInvites(opt fleet.ListOptions) ([]*fleet.Invite, error)
// Invite returns Invite identified by id.
func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
var invite fleet.Invite
err := d.db.Get(&invite, "SELECT * FROM invites WHERE id = ?", id)
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE id = ?", id)
if err == sql.ErrNoRows {
return nil, notFound("Invite").WithID(id)
} else if err != nil {
@ -97,7 +98,7 @@ func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
// InviteByEmail finds an Invite with a particular email, if one exists.
func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
var invite fleet.Invite
err := d.db.Get(&invite, "SELECT * FROM invites WHERE email = ?", email)
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE email = ?", email)
if err == sql.ErrNoRows {
return nil, notFound("Invite").
WithMessage(fmt.Sprintf("with email %s", email))
@ -115,7 +116,7 @@ func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
// InviteByToken finds an Invite with a particular token, if one exists.
func (d *Datastore) InviteByToken(token string) (*fleet.Invite, error) {
var invite fleet.Invite
err := d.db.Get(&invite, "SELECT * FROM invites WHERE token = ?", token)
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE token = ?", token)
if err == sql.ErrNoRows {
return nil, notFound("Invite").
WithMessage(fmt.Sprintf("with token %s", token))
@ -163,7 +164,7 @@ func (d *Datastore) loadTeamsForInvites(invites []*fleet.Invite) error {
fleet.UserTeam
InviteID uint `db:"invite_id"`
}
if err := d.db.Select(&rows, sql, args...); err != nil {
if err := d.reader.Select(&rows, sql, args...); err != nil {
return errors.Wrap(err, "get loadTeamsForInvites")
}

View File

@ -115,7 +115,7 @@ func (d *Datastore) GetLabelSpecs() ([]*fleet.LabelSpec, error) {
var specs []*fleet.LabelSpec
// Get basic specs
query := "SELECT id, name, description, query, platform, label_type, label_membership_type FROM labels"
if err := d.db.Select(&specs, query); err != nil {
if err := d.reader.Select(&specs, query); err != nil {
return nil, errors.Wrap(err, "get labels")
}
@ -138,7 +138,7 @@ SELECT name, description, query, platform, label_type, label_membership_type
FROM labels
WHERE name = ?
`
if err := d.db.Select(&specs, query, name); err != nil {
if err := d.reader.Select(&specs, query, name); err != nil {
return nil, errors.Wrap(err, "get label")
}
if len(specs) == 0 {
@ -171,7 +171,7 @@ func (d *Datastore) getLabelHostnames(label *fleet.LabelSpec) error {
WHERE label_id = (SELECT id FROM labels WHERE name = ?)
)
`
err := d.db.Select(&label.Hosts, sql, label.Name)
err := d.reader.Select(&label.Hosts, sql, label.Name)
if err != nil {
return errors.Wrap(err, "get hostnames for label")
}
@ -190,7 +190,7 @@ func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fl
label_membership_type
) VALUES ( ?, ?, ?, ?, ?, ?)
`
result, err := d.db.Exec(
result, err := d.writer.Exec(
query,
label.Name,
label.Description,
@ -216,7 +216,7 @@ func (d *Datastore) SaveLabel(label *fleet.Label) (*fleet.Label, error) {
description = ?
WHERE id = ?
`
_, err := d.db.Exec(query, label.Name, label.Description, label.ID)
_, err := d.writer.Exec(query, label.Name, label.Description, label.ID)
if err != nil {
return nil, errors.Wrap(err, "saving label")
}
@ -236,7 +236,7 @@ func (d *Datastore) Label(lid uint) (*fleet.Label, error) {
`
label := &fleet.Label{}
if err := d.db.Get(label, sql, lid); err != nil {
if err := d.reader.Get(label, sql, lid); err != nil {
return nil, errors.Wrap(err, "selecting label")
}
@ -255,7 +255,7 @@ func (d *Datastore) ListLabels(filter fleet.TeamFilter, opt fleet.ListOptions) (
query = appendListOptionsToSQL(query, opt)
labels := []*fleet.Label{}
if err := d.db.Select(&labels, query); err != nil {
if err := d.reader.Select(&labels, query); err != nil {
// it's ok if no labels exist
if err == sql.ErrNoRows {
return labels, nil
@ -288,7 +288,7 @@ func (d *Datastore) LabelQueriesForHost(host *fleet.Host, cutoff time.Time) (map
WHERE platform = ? OR platform = ''
AND label_membership_type = ?
`
rows, err = d.db.Query(sql, platform, fleet.LabelMembershipTypeDynamic)
rows, err = d.reader.Query(sql, platform, fleet.LabelMembershipTypeDynamic)
} else {
// Retrieve all labels (with matching platform) iff there is a label
// that has been created since this host last reported label query
@ -300,7 +300,7 @@ func (d *Datastore) LabelQueriesForHost(host *fleet.Host, cutoff time.Time) (map
AND (platform = ? OR platform = '')
AND label_membership_type = ?
`
rows, err = d.db.Query(
rows, err = d.reader.Query(
sql,
platform,
host.LabelUpdatedAt,
@ -369,7 +369,7 @@ func (d *Datastore) RecordLabelQueryExecutions(host *fleet.Host, results map[uin
updated_at = VALUES(updated_at)
`
_, err := d.db.Exec(sql, vals...)
_, err := d.writer.Exec(sql, vals...)
if err != nil {
return errors.Wrapf(err, "insert label query executions (%v)", vals)
}
@ -384,8 +384,8 @@ func (d *Datastore) RecordLabelQueryExecutions(host *fleet.Host, results map[uin
if err != nil {
return errors.Wrap(err, "IN for DELETE FROM label_membership")
}
query = d.db.Rebind(query)
_, err = d.db.Exec(query, args...)
query = d.writer.Rebind(query)
_, err = d.writer.Exec(query, args...)
if err != nil {
return errors.Wrap(err, "delete label query executions")
}
@ -403,7 +403,7 @@ func (d *Datastore) ListLabelsForHost(hid uint) ([]*fleet.Label, error) {
`
labels := []*fleet.Label{}
err := d.db.Select(&labels, sqlStatement, hid)
err := d.reader.Select(&labels, sqlStatement, hid)
if err != nil {
return nil, errors.Wrap(err, "selecting host labels")
}
@ -432,7 +432,7 @@ func (d *Datastore) ListHostsInLabel(filter fleet.TeamFilter, lid uint, opt flee
sql = appendListOptionsToSQL(sql, opt.ListOptions)
hosts := []*fleet.Host{}
err := d.db.Select(&hosts, sql, params...)
err := d.reader.Select(&hosts, sql, params...)
if err != nil {
return nil, errors.Wrap(err, "selecting label query executions")
}
@ -458,9 +458,9 @@ func (d *Datastore) ListUniqueHostsInLabels(filter fleet.TeamFilter, labels []ui
return nil, errors.Wrap(err, "building query listing unique hosts in labels")
}
query = d.db.Rebind(query)
query = d.reader.Rebind(query)
hosts := []*fleet.Host{}
err = d.db.Select(&hosts, query, args...)
err = d.reader.Select(&hosts, query, args...)
if err != nil {
return nil, errors.Wrap(err, "listing unique hosts in labels")
}
@ -492,10 +492,10 @@ func (d *Datastore) searchLabelsWithOmits(filter fleet.TeamFilter, query string,
return nil, errors.Wrap(err, "building query for labels with omits")
}
sql = d.db.Rebind(sql)
sql = d.reader.Rebind(sql)
matches := []*fleet.Label{}
err = d.db.Select(&matches, sql, args...)
err = d.reader.Select(&matches, sql, args...)
if err != nil {
return nil, errors.Wrap(err, "selecting labels with omits")
}
@ -527,7 +527,7 @@ func (d *Datastore) addAllHostsLabelToList(filter fleet.TeamFilter, labels []*fl
)
var allHosts fleet.Label
if err := d.db.Get(&allHosts, sql, fleet.LabelTypeBuiltIn); err != nil {
if err := d.reader.Get(&allHosts, sql, fleet.LabelTypeBuiltIn); err != nil {
return nil, errors.Wrap(err, "get all hosts label")
}
@ -575,8 +575,8 @@ func (d *Datastore) searchLabelsDefault(filter fleet.TeamFilter, omit ...uint) (
if err != nil {
return nil, errors.Wrap(err, "searching default labels")
}
sql = d.db.Rebind(sql)
if err := d.db.Select(&labels, sql, args...); err != nil {
sql = d.reader.Rebind(sql)
if err := d.reader.Select(&labels, sql, args...); err != nil {
return nil, errors.Wrap(err, "searching default labels rebound")
}
@ -617,7 +617,7 @@ func (d *Datastore) SearchLabels(filter fleet.TeamFilter, query string, omit ...
)
matches := []*fleet.Label{}
if err := d.db.Select(&matches, sql, transformedQuery); err != nil {
if err := d.reader.Select(&matches, sql, transformedQuery); err != nil {
return nil, errors.Wrap(err, "selecting labels for search")
}
@ -645,7 +645,7 @@ func (d *Datastore) LabelIDsByName(labels []string) ([]uint, error) {
}
var labelIDs []uint
if err := d.db.Select(&labelIDs, sql, args...); err != nil {
if err := d.reader.Select(&labelIDs, sql, args...); err != nil {
return nil, errors.Wrap(err, "get label IDs")
}

View File

@ -29,27 +29,27 @@ func (d *Datastore) Lock(name string, owner string, expiration time.Duration) (b
}
func (d *Datastore) createLock(name string, owner string, expiration time.Duration) (sql.Result, error) {
return d.db.Exec(
return d.writer.Exec(
`INSERT IGNORE INTO locks (name, owner, expires_at) VALUES (?, ?, ?)`,
name, owner, time.Now().Add(expiration),
)
}
func (d *Datastore) extendLockIfAlreadyAcquired(name string, owner string, expiration time.Duration) (sql.Result, error) {
return d.db.Exec(
return d.writer.Exec(
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE name = ? and owner = ?`,
name, owner, time.Now().Add(expiration), name, owner,
)
}
func (d *Datastore) overwriteLockIfExpired(name string, owner string, expiration time.Duration) (sql.Result, error) {
return d.db.Exec(
return d.writer.Exec(
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE expires_at < CURRENT_TIMESTAMP and name = ?`,
name, owner, time.Now().Add(expiration), name,
)
}
func (d *Datastore) Unlock(name string, owner string) error {
_, err := d.db.Exec(`DELETE FROM locks WHERE name = ? and owner = ?`, name, owner)
_, err := d.writer.Exec(`DELETE FROM locks WHERE name = ? and owner = ?`, name, owner)
return err
}

View File

@ -85,14 +85,14 @@ func Test20210819131107_AddCascadeToHostSoftware(t *testing.T) {
defer ds.Close()
for {
version, err := tables.MigrationClient.GetDBVersion(ds.db.DB)
version, err := tables.MigrationClient.GetDBVersion(ds.writer.DB)
require.NoError(t, err)
// break right before the the constraint migration
if version == 20210818182258 {
break
}
require.NoError(t, tables.MigrationClient.UpByOne(ds.db.DB, ""))
require.NoError(t, tables.MigrationClient.UpByOne(ds.writer.DB, ""))
}
host1 := test.NewHost(t, ds, "host1", "", "host1key", "host1uuid", time.Now())
@ -122,7 +122,7 @@ func Test20210819131107_AddCascadeToHostSoftware(t *testing.T) {
require.NoError(t, ds.DeleteHost(host1.ID))
require.NoError(t, tables.MigrationClient.UpByOne(ds.db.DB, ""))
require.NoError(t, tables.MigrationClient.UpByOne(ds.writer.DB, ""))
// Make sure we don't delete more than we need
hostCheck, err := ds.Host(host2.ID)

View File

@ -37,13 +37,28 @@ var (
columnCharsRegexp = regexp.MustCompile(`[^\w-]`)
)
// dbReader is an interface that defines the methods required for reads.
type dbReader interface {
sqlx.Queryer
Close() error
Rebind(string) string
Select(interface{}, string, ...interface{}) error
Get(interface{}, string, ...interface{}) error
}
// Datastore is an implementation of fleet.Datastore interface backed by
// MySQL
type Datastore struct {
db *sqlx.DB
reader dbReader // so it cannot be used to perform writes
writer *sqlx.DB
logger log.Logger
clock clock.Clock
config config.MysqlConfig
// nil if no read replica
readReplicaConfig *config.MysqlConfig
}
type txFn func(*sqlx.Tx) error
@ -67,7 +82,7 @@ func retryableError(err error) bool {
// withRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff
func (d *Datastore) withRetryTxx(fn txFn) (err error) {
operation := func() error {
tx, err := d.db.Beginx()
tx, err := d.writer.Beginx()
if err != nil {
return errors.Wrap(err, "create transaction")
}
@ -116,7 +131,7 @@ func (d *Datastore) withRetryTxx(fn txFn) (err error) {
// withTx provides a common way to commit/rollback a txFn
func (d *Datastore) withTx(fn txFn) (err error) {
tx, err := d.db.Beginx()
tx, err := d.writer.Beginx()
if err != nil {
return errors.Wrap(err, "create transaction")
}
@ -153,51 +168,64 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
}
for _, setOpt := range opts {
if setOpt != nil {
setOpt(options)
}
if config.PasswordPath != "" && config.Password != "" {
return nil, errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
}
// Check to see if the flag is populated
// Check if file exists on disk
// If file exists read contents
if config.PasswordPath != "" {
fileContents, err := ioutil.ReadFile(config.PasswordPath)
if err := checkConfig(&config); err != nil {
return nil, err
}
if options.replicaConfig != nil {
if err := checkConfig(options.replicaConfig); err != nil {
return nil, errors.Wrap(err, "replica")
}
}
dbWriter, err := newDB(&config, options)
if err != nil {
return nil, err
}
config.Password = strings.TrimSpace(string(fileContents))
}
if config.TLSCA != "" {
config.TLSConfig = "custom"
err := registerTLS(config)
dbReader := dbWriter
if options.replicaConfig != nil {
dbReader, err = newDB(options.replicaConfig, options)
if err != nil {
return nil, errors.Wrap(err, "register TLS config for mysql")
return nil, err
}
}
dsn := generateMysqlConnectionString(config)
ds := &Datastore{
writer: dbWriter,
reader: dbReader,
logger: options.logger,
clock: c,
config: config,
readReplicaConfig: options.replicaConfig,
}
return ds, nil
}
func newDB(conf *config.MysqlConfig, opts *dbOptions) (*sqlx.DB, error) {
dsn := generateMysqlConnectionString(*conf)
db, err := sqlx.Open("mysql", dsn)
if err != nil {
return nil, err
}
db.SetMaxIdleConns(config.MaxIdleConns)
db.SetMaxOpenConns(config.MaxOpenConns)
db.SetConnMaxLifetime(time.Second * time.Duration(config.ConnMaxLifetime))
db.SetMaxIdleConns(conf.MaxIdleConns)
db.SetMaxOpenConns(conf.MaxOpenConns)
db.SetConnMaxLifetime(time.Second * time.Duration(conf.ConnMaxLifetime))
var dbError error
for attempt := 0; attempt < options.maxAttempts; attempt++ {
for attempt := 0; attempt < opts.maxAttempts; attempt++ {
dbError = db.Ping()
if dbError == nil {
// we're connected!
break
}
interval := time.Duration(attempt) * time.Second
options.logger.Log("mysql", fmt.Sprintf(
opts.logger.Log("mysql", fmt.Sprintf(
"could not connect to db: %v, sleeping %v", dbError, interval))
time.Sleep(interval)
}
@ -205,20 +233,37 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
if dbError != nil {
return nil, dbError
}
ds := &Datastore{
db: db,
logger: options.logger,
clock: c,
config: config,
return db, nil
}
return ds, nil
func checkConfig(conf *config.MysqlConfig) error {
if conf.PasswordPath != "" && conf.Password != "" {
return errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
}
// Check to see if the flag is populated
// Check if file exists on disk
// If file exists read contents
if conf.PasswordPath != "" {
fileContents, err := ioutil.ReadFile(conf.PasswordPath)
if err != nil {
return err
}
conf.Password = strings.TrimSpace(string(fileContents))
}
if conf.TLSCA != "" {
conf.TLSConfig = "custom"
err := registerTLS(*conf)
if err != nil {
return errors.Wrap(err, "register TLS config for mysql")
}
}
return nil
}
func (d *Datastore) Begin() (fleet.Transaction, error) {
return d.db.Beginx()
return d.writer.Beginx()
}
func (d *Datastore) Name() string {
@ -226,11 +271,11 @@ func (d *Datastore) Name() string {
}
func (d *Datastore) MigrateTables() error {
return tables.MigrationClient.Up(d.db.DB, "")
return tables.MigrationClient.Up(d.writer.DB, "")
}
func (d *Datastore) MigrateData() error {
return data.MigrationClient.Up(d.db.DB, "")
return data.MigrationClient.Up(d.writer.DB, "")
}
func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
@ -243,7 +288,7 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
return 0, errors.Wrap(err, "missing tables migrations")
}
currentTablesVersion, err := tables.MigrationClient.GetDBVersion(d.db.DB)
currentTablesVersion, err := tables.MigrationClient.GetDBVersion(d.writer.DB)
if err != nil {
return 0, errors.Wrap(err, "cannot get table migration status")
}
@ -253,7 +298,7 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
return 0, errors.Wrap(err, "missing data migrations")
}
currentDataVersion, err := data.MigrationClient.GetDBVersion(d.db.DB)
currentDataVersion, err := data.MigrationClient.GetDBVersion(d.writer.DB)
if err != nil {
return 0, errors.Wrap(err, "cannot get data migration status")
}
@ -283,11 +328,11 @@ func (d *Datastore) Drop() error {
WHERE TABLE_SCHEMA = ?;
`
if err := d.db.Select(&tables, sql, d.config.Database); err != nil {
if err := d.writer.Select(&tables, sql, d.config.Database); err != nil {
return err
}
tx, err := d.db.Begin()
tx, err := d.writer.Begin()
if err != nil {
return err
}
@ -312,13 +357,28 @@ func (d *Datastore) Drop() error {
// HealthCheck returns an error if the MySQL backend is not healthy.
func (d *Datastore) HealthCheck() error {
_, err := d.db.Exec("select 1")
if _, err := d.writer.Exec("select 1"); err != nil {
return err
}
if d.readReplicaConfig != nil {
var dst int
if err := d.reader.Get(&dst, "select 1"); err != nil {
return err
}
}
return nil
}
// Close frees resources associated with underlying mysql connection
func (d *Datastore) Close() error {
return d.db.Close()
err := d.writer.Close()
if d.readReplicaConfig != nil {
errRead := d.reader.Close()
if err == nil {
err = errRead
}
}
return err
}
func sanitizeColumn(col string) string {

View File

@ -1,7 +1,9 @@
package mysql
import (
"database/sql"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/VividCortex/mysqlerr"
@ -15,6 +17,51 @@ import (
"github.com/stretchr/testify/require"
)
func TestDatastoreReplica(t *testing.T) {
// a bit unfortunate to create temp databases just for this - could be mixed
// with other tests when/if we move to subtests to minimize the number of
// databases created for tests (see #1805).
t.Run("noreplica", func(t *testing.T) {
ds := CreateMySQLDSWithOptions(t, nil)
defer ds.Close()
require.Equal(t, ds.reader, ds.writer)
})
t.Run("replica", func(t *testing.T) {
opts := &DatastoreTestOptions{Replica: true}
ds := CreateMySQLDSWithOptions(t, opts)
defer ds.Close()
require.NotEqual(t, ds.reader, ds.writer)
// create a new host
host, err := ds.NewHost(&fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: "1",
UUID: "1",
Hostname: "foo.local",
PrimaryIP: "192.168.1.1",
PrimaryMac: "30-65-EC-6F-C4-58",
})
require.NoError(t, err)
require.NotNil(t, host)
// trying to read it fails, not replicated yet
_, err = ds.Host(host.ID)
require.Error(t, err)
require.True(t, errors.Is(err, sql.ErrNoRows))
opts.RunReplication()
// now it can read it
host2, err := ds.Host(host.ID)
require.NoError(t, err)
require.Equal(t, host.ID, host2.ID)
})
}
func TestSanitizeColumn(t *testing.T) {
t.Parallel()
@ -121,8 +168,10 @@ func TestSearchLike(t *testing.T) {
func mockDatastore(t *testing.T) (sqlmock.Sqlmock, *Datastore) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
dbmock := sqlx.NewDb(db, "sqlmock")
ds := &Datastore{
db: sqlx.NewDb(db, "sqlmock"),
writer: dbmock,
reader: dbmock,
logger: log.NewNopLogger(),
}

View File

@ -205,7 +205,7 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
WHERE name = ?
`
var pack fleet.Pack
err := d.db.Get(&pack, sqlStatement, name)
err := d.reader.Get(&pack, sqlStatement, name)
if err != nil {
if err == sql.ErrNoRows {
return nil, false, nil
@ -213,7 +213,7 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
return nil, false, errors.Wrap(err, "fetch pack by name")
}
if err := d.loadPackTargets(&pack); err != nil {
if err := d.loadPackTargets(d.reader, &pack); err != nil {
return nil, false, err
}
@ -313,10 +313,10 @@ func (d *Datastore) replacePackTargets(tx *sqlx.Tx, pack *fleet.Pack) error {
return nil
}
func (d *Datastore) loadPackTargets(pack *fleet.Pack) error {
func (d *Datastore) loadPackTargets(db dbReader, pack *fleet.Pack) error {
var targets []fleet.PackTarget
sql := `SELECT * FROM pack_targets WHERE pack_id = ?`
if err := d.db.Select(&targets, sql, pack.ID); err != nil {
if err := db.Select(&targets, sql, pack.ID); err != nil {
return errors.Wrap(err, "select pack targets")
}
@ -371,14 +371,14 @@ func (d *Datastore) DeletePack(name string) error {
func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
query := `SELECT * FROM packs WHERE id = ?`
pack := &fleet.Pack{}
err := d.db.Get(pack, query, pid)
err := d.reader.Get(pack, query, pid)
if err == sql.ErrNoRows {
return nil, notFound("Pack").WithID(pid)
} else if err != nil {
return nil, errors.Wrap(err, "get pack")
}
if err := d.loadPackTargets(pack); err != nil {
if err := d.loadPackTargets(d.reader, pack); err != nil {
return nil, err
}
@ -388,14 +388,15 @@ func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
// EnsureGlobalPack gets or inserts a pack with type global
func (d *Datastore) EnsureGlobalPack() (*fleet.Pack, error) {
pack := &fleet.Pack{}
err := d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = 'global'`)
// read from primary as we will create the pack if it doesn't exist
err := d.writer.Get(pack, `SELECT * FROM packs WHERE pack_type = 'global'`)
if err == sql.ErrNoRows {
return d.insertNewGlobalPack()
} else if err != nil {
return nil, errors.Wrap(err, "get pack")
}
if err := d.loadPackTargets(pack); err != nil {
if err := d.loadPackTargets(d.writer, pack); err != nil {
return nil, err
}
@ -439,14 +440,15 @@ func (d *Datastore) EnsureTeamPack(teamID uint) (*fleet.Pack, error) {
}
teamType := fmt.Sprintf("team-%d", teamID)
err = d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType)
// read from primary as we will create the team pack if it doesn't exist
err = d.writer.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType)
if err == sql.ErrNoRows {
return d.insertNewTeamPack(t)
} else if err != nil {
return nil, errors.Wrap(err, "get pack")
}
if err := d.loadPackTargets(pack); err != nil {
if err := d.loadPackTargets(d.writer, pack); err != nil {
return nil, err
}
@ -499,13 +501,13 @@ func (d *Datastore) ListPacks(opt fleet.PackListOptions) ([]*fleet.Pack, error)
query = `SELECT * FROM packs`
}
var packs []*fleet.Pack
err := d.db.Select(&packs, appendListOptionsToSQL(query, opt.ListOptions))
err := d.reader.Select(&packs, appendListOptionsToSQL(query, opt.ListOptions))
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing packs")
}
for _, pack := range packs {
if err := d.loadPackTargets(pack); err != nil {
if err := d.loadPackTargets(d.reader, pack); err != nil {
return nil, err
}
}
@ -541,7 +543,7 @@ func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) {
`
packs := []*fleet.Pack{}
if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
if err := d.reader.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing hosts in pack")
}
return packs, nil

View File

@ -524,7 +524,7 @@ func TestTeamScheduleNamesMigrateToNewFormat(t *testing.T) {
require.NoError(t, err)
// insert team pack by hand with the old naming scheme
_, err = ds.db.Exec(
_, err = ds.writer.Exec(
"INSERT INTO packs(name, description, platform, disabled, pack_type) VALUES (?, ?, ?, ?, ?)",
teamSchedulePackType(team1), "desc", "windows", false, teamSchedulePackType(team1),
)

View File

@ -11,7 +11,7 @@ func (d *Datastore) NewPasswordResetRequest(req *fleet.PasswordResetRequest) (*f
( user_id, token, expires_at)
VALUES (?,?, NOW())
`
response, err := d.db.Exec(sqlStatement, req.UserID, req.Token)
response, err := d.writer.Exec(sqlStatement, req.UserID, req.Token)
if err != nil {
return nil, errors.Wrap(err, "inserting password reset requests")
}
@ -30,7 +30,7 @@ func (d *Datastore) SavePasswordResetRequest(req *fleet.PasswordResetRequest) er
token = ?
WHERE id = ?
`
result, err := d.db.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID)
result, err := d.writer.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID)
if err != nil {
return errors.Wrap(err, "updating password reset requests")
}
@ -58,7 +58,7 @@ func (d *Datastore) DeletePasswordResetRequestsForUser(userID uint) error {
sqlStatement := `
DELETE FROM password_reset_requests WHERE user_id = ?
`
_, err := d.db.Exec(sqlStatement, userID)
_, err := d.writer.Exec(sqlStatement, userID)
if err != nil {
return errors.Wrap(err, "deleting password reset request by user")
}
@ -72,7 +72,7 @@ func (d *Datastore) FindPassswordResetByID(id uint) (*fleet.PasswordResetRequest
WHERE id = ? LIMIT 1
`
passwordResetRequest := &fleet.PasswordResetRequest{}
err := d.db.Get(&passwordResetRequest, sqlStatement, id)
err := d.reader.Get(&passwordResetRequest, sqlStatement, id)
if err != nil {
return nil, errors.Wrap(err, "selecting password reset by id")
}
@ -87,7 +87,7 @@ func (d *Datastore) FindPassswordResetsByUserID(id uint) ([]*fleet.PasswordReset
`
passwordResetRequests := []*fleet.PasswordResetRequest{}
err := d.db.Select(&passwordResetRequests, sqlStatement, id)
err := d.reader.Select(&passwordResetRequests, sqlStatement, id)
if err != nil {
return nil, errors.Wrap(err, "finding password resets by user id")
}
@ -102,7 +102,7 @@ func (d *Datastore) FindPassswordResetByToken(token string) (*fleet.PasswordRese
WHERE token = ? LIMIT 1
`
passwordResetRequest := &fleet.PasswordResetRequest{}
err := d.db.Get(passwordResetRequest, sqlStatement, token)
err := d.reader.Get(passwordResetRequest, sqlStatement, token)
if err != nil {
return nil, errors.Wrap(err, "selecting password reset requests")
}
@ -118,7 +118,7 @@ func (d *Datastore) FindPassswordResetByTokenAndUserID(token string, id uint) (*
LIMIT 1
`
passwordResetRequest := &fleet.PasswordResetRequest{}
err := d.db.Get(passwordResetRequest, sqlStatement, id, token)
err := d.reader.Get(passwordResetRequest, sqlStatement, id, token)
if err != nil {
return nil, errors.Wrap(err, "selecting password reset by token and user id")
}

View File

@ -12,7 +12,7 @@ import (
)
func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
res, err := ds.db.Exec(`INSERT INTO policies (query_id) VALUES (?)`, queryID)
res, err := ds.writer.Exec(`INSERT INTO policies (query_id) VALUES (?)`, queryID)
if err != nil {
return nil, errors.Wrap(err, "inserting new policy")
}
@ -26,7 +26,7 @@ func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
func (ds *Datastore) Policy(id uint) (*fleet.Policy, error) {
var policy fleet.Policy
err := ds.db.Get(
err := ds.reader.Get(
&policy,
`SELECT
p.*,
@ -66,7 +66,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(host *fleet.Host, results map[u
strings.Join(bindvars, ","),
)
_, err := ds.db.Exec(query, vals...)
_, err := ds.writer.Exec(query, vals...)
if err != nil {
return errors.Wrapf(err, "insert policy_membership (%v)", vals)
}
@ -76,7 +76,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(host *fleet.Host, results map[u
func (ds *Datastore) ListGlobalPolicies() ([]*fleet.Policy, error) {
var policies []*fleet.Policy
err := ds.db.Select(
err := ds.reader.Select(
&policies,
`SELECT
p.*,
@ -97,8 +97,8 @@ func (ds *Datastore) DeleteGlobalPolicies(ids []uint) ([]uint, error) {
if err != nil {
return nil, errors.Wrap(err, "IN for DELETE FROM policies")
}
stmt = ds.db.Rebind(stmt)
if _, err := ds.db.Exec(stmt, args...); err != nil {
stmt = ds.writer.Rebind(stmt)
if _, err := ds.writer.Exec(stmt, args...); err != nil {
return nil, errors.Wrap(err, "delete policies")
}
return ids, nil
@ -109,7 +109,7 @@ func (ds *Datastore) PolicyQueriesForHost(_ *fleet.Host) (map[string]string, err
Id string `db:"id"`
Query string `db:"query"`
}
err := ds.db.Select(&rows, `SELECT p.id, q.query FROM policies p JOIN queries q ON (p.query_id=q.id)`)
err := ds.reader.Select(&rows, `SELECT p.id, q.query FROM policies p JOIN queries q ON (p.query_id=q.id)`)
if err != nil {
return nil, errors.Wrap(err, "selecting policies for host")
}

View File

@ -10,7 +10,7 @@ import (
)
func (d *Datastore) ApplyQueries(authorID uint, queries []*fleet.Query) (err error) {
tx, err := d.db.Begin()
tx, err := d.writer.Begin()
if err != nil {
return errors.Wrap(err, "begin ApplyQueries transaction")
}
@ -72,7 +72,7 @@ func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet.
WHERE name = ?
`
var query fleet.Query
err := d.db.Get(&query, sqlStatement, name)
err := d.reader.Get(&query, sqlStatement, name)
if err != nil {
if err == sql.ErrNoRows {
return nil, notFound("Query").WithName(name)
@ -99,7 +99,7 @@ func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fl
observer_can_run
) VALUES ( ?, ?, ?, ?, ?, ? )
`
result, err := d.db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
result, err := d.writer.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
if err != nil && isDuplicate(err) {
return nil, alreadyExists("Query", 0)
@ -120,7 +120,7 @@ func (d *Datastore) SaveQuery(q *fleet.Query) error {
SET name = ?, description = ?, query = ?, author_id = ?, saved = ?, observer_can_run = ?
WHERE id = ?
`
result, err := d.db.Exec(sql, q.Name, q.Description, q.Query, q.AuthorID, q.Saved, q.ObserverCanRun, q.ID)
result, err := d.writer.Exec(sql, q.Name, q.Description, q.Query, q.AuthorID, q.Saved, q.ObserverCanRun, q.ID)
if err != nil {
return errors.Wrap(err, "updating query")
}
@ -156,7 +156,7 @@ func (d *Datastore) Query(id uint) (*fleet.Query, error) {
WHERE q.id = ?
`
query := &fleet.Query{}
if err := d.db.Get(query, sql, id); err != nil {
if err := d.reader.Get(query, sql, id); err != nil {
return nil, errors.Wrap(err, "selecting query")
}
@ -180,7 +180,7 @@ func (d *Datastore) ListQueries(opt fleet.ListOptions) ([]*fleet.Query, error) {
sql = appendListOptionsToSQL(sql, opt)
results := []*fleet.Query{}
if err := d.db.Select(&results, sql); err != nil {
if err := d.reader.Select(&results, sql); err != nil {
return nil, errors.Wrap(err, "listing queries")
}
@ -225,7 +225,7 @@ func (d *Datastore) loadPacksForQueries(queries []*fleet.Query) error {
fleet.Pack
}{}
err = d.db.Select(&rows, query, args...)
err = d.reader.Select(&rows, query, args...)
if err != nil {
return errors.Wrap(err, "selecting load packs for queries")
}

View File

@ -34,7 +34,7 @@ func (d *Datastore) ListScheduledQueriesInPack(id uint, opts fleet.ListOptions)
query = appendListOptionsToSQL(query, opts)
results := []*fleet.ScheduledQuery{}
if err := d.db.Select(&results, query, id); err != nil {
if err := d.reader.Select(&results, query, id); err != nil {
return nil, errors.Wrap(err, "listing scheduled queries")
}
@ -46,8 +46,8 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op
}
func (d *Datastore) insertScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
selectFunc := d.db.Select
execFunc := d.db.Exec
selectFunc := d.writer.Select
execFunc := d.writer.Exec
if tx != nil {
selectFunc = tx.Select
execFunc = tx.Exec
@ -108,7 +108,7 @@ func (d *Datastore) SaveScheduledQuery(sq *fleet.ScheduledQuery) (*fleet.Schedul
}
func (d *Datastore) saveScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
updateFunc := d.db.Exec
updateFunc := d.writer.Exec
if tx != nil {
updateFunc = tx.Exec
}
@ -160,7 +160,7 @@ func (d *Datastore) ScheduledQuery(id uint) (*fleet.ScheduledQuery, error) {
WHERE sq.id = ?
`
sq := &fleet.ScheduledQuery{}
if err := d.db.Get(sq, query, id); err != nil {
if err := d.reader.Get(sq, query, id); err != nil {
return nil, errors.Wrap(err, "select scheduled query")
}
@ -168,11 +168,11 @@ func (d *Datastore) ScheduledQuery(id uint) (*fleet.ScheduledQuery, error) {
}
func (d *Datastore) CleanupOrphanScheduledQueryStats() error {
_, err := d.db.Exec(`DELETE FROM scheduled_query_stats where scheduled_query_id not in (select id from scheduled_queries where id=scheduled_query_id)`)
_, err := d.writer.Exec(`DELETE FROM scheduled_query_stats where scheduled_query_id not in (select id from scheduled_queries where id=scheduled_query_id)`)
if err != nil {
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by scheduled_query")
}
_, err = d.db.Exec(`DELETE FROM scheduled_query_stats where host_id not in (select id from hosts where id=host_id)`)
_, err = d.writer.Exec(`DELETE FROM scheduled_query_stats where host_id not in (select id from hosts where id=host_id)`)
if err != nil {
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by host")
}

View File

@ -209,7 +209,7 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "1")
sq1 := test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "2")
_, err := ds.db.Exec(`INSERT INTO scheduled_query_stats (
_, err := ds.writer.Exec(`INSERT INTO scheduled_query_stats (
host_id, scheduled_query_id, average_memory, denylisted,
executions, schedule_interval, output_size, system_time,
user_time, wall_time
@ -224,12 +224,12 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
require.Len(t, h1.PackStats, 1)
// now we insert a bogus stat
_, err = ds.db.Exec(`INSERT INTO scheduled_query_stats (
_, err = ds.writer.Exec(`INSERT INTO scheduled_query_stats (
host_id, scheduled_query_id, average_memory, denylisted, executions
) VALUES (?, 999, 32, false, 2);`, h1.ID)
require.NoError(t, err)
// and also for an unknown host
_, err = ds.db.Exec(`INSERT INTO scheduled_query_stats (
_, err = ds.writer.Exec(`INSERT INTO scheduled_query_stats (
host_id, scheduled_query_id, average_memory, denylisted, executions
) VALUES (888, 999, 32, true, 4);`)
require.NoError(t, err)
@ -241,14 +241,14 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
// but there are definitely there
var count int
err = ds.db.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
err = ds.writer.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
require.NoError(t, err)
assert.Equal(t, 3, count)
// now we clean it up
require.NoError(t, ds.CleanupOrphanScheduledQueryStats())
err = ds.db.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
err = ds.writer.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
require.NoError(t, err)
assert.Equal(t, 1, count)

View File

@ -11,7 +11,7 @@ func (d *Datastore) SessionByKey(key string) (*fleet.Session, error) {
WHERE ` + "`key`" + ` = ? LIMIT 1
`
session := &fleet.Session{}
err := d.db.Get(session, sqlStatement, key)
err := d.reader.Get(session, sqlStatement, key)
if err != nil {
return nil, errors.Wrap(err, "selecting sessions")
}
@ -26,7 +26,7 @@ func (d *Datastore) SessionByID(id uint) (*fleet.Session, error) {
LIMIT 1
`
session := &fleet.Session{}
err := d.db.Get(session, sqlStatement, id)
err := d.reader.Get(session, sqlStatement, id)
if err != nil {
return nil, errors.Wrap(err, "selecting session by id")
}
@ -40,7 +40,7 @@ func (d *Datastore) ListSessionsForUser(id uint) ([]*fleet.Session, error) {
WHERE user_id = ?
`
sessions := []*fleet.Session{}
err := d.db.Select(&sessions, sqlStatement, id)
err := d.reader.Select(&sessions, sqlStatement, id)
if err != nil {
return nil, errors.Wrap(err, "selecting sessions for user")
}
@ -57,7 +57,7 @@ func (d *Datastore) NewSession(session *fleet.Session) (*fleet.Session, error) {
)
VALUES(?,?)
`
result, err := d.db.Exec(sqlStatement, session.UserID, session.Key)
result, err := d.writer.Exec(sqlStatement, session.UserID, session.Key)
if err != nil {
return nil, errors.Wrap(err, "inserting session")
}
@ -80,7 +80,7 @@ func (d *Datastore) DestroyAllSessionsForUser(id uint) error {
sqlStatement := `
DELETE FROM sessions WHERE user_id = ?
`
_, err := d.db.Exec(sqlStatement, id)
_, err := d.writer.Exec(sqlStatement, id)
if err != nil {
return errors.Wrap(err, "deleting sessions for user")
}
@ -94,7 +94,7 @@ func (d *Datastore) MarkSessionAccessed(session *fleet.Session) error {
accessed_at = ?
WHERE id = ?
`
results, err := d.db.Exec(sqlStatement, d.clock.Now(), session.ID)
results, err := d.writer.Exec(sqlStatement, d.clock.Now(), session.ID)
if err != nil {
return errors.Wrap(err, "updating mark session as accessed")
}

View File

@ -206,7 +206,7 @@ func (d *Datastore) insertNewInstalledHostSoftware(
}
func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Software, error) {
selectFunc := d.db.Select
selectFunc := d.reader.Select
if tx != nil {
selectFunc = tx.Select
}
@ -231,7 +231,7 @@ func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Softwa
WHERE s.id IN
(SELECT software_id FROM host_software WHERE host_id = ?)
`
queryFunc := d.db.Queryx
queryFunc := d.reader.Queryx
if tx != nil {
queryFunc = tx.Queryx
}
@ -305,7 +305,7 @@ func (d *Datastore) AllSoftwareWithoutCPEIterator() (fleet.SoftwareIterator, err
sql := `SELECT s.* FROM software s LEFT JOIN software_cpe sc on (s.id=sc.software_id) WHERE sc.id is null`
// The rows.Close call is done by the caller once iteration using the
// returned fleet.SoftwareIterator is done.
rows, err := d.db.Queryx(sql) //nolint:sqlclosecheck
rows, err := d.reader.Queryx(sql) //nolint:sqlclosecheck
if err != nil {
return nil, errors.Wrap(err, "load host software")
}
@ -314,7 +314,7 @@ func (d *Datastore) AllSoftwareWithoutCPEIterator() (fleet.SoftwareIterator, err
func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error {
sql := `INSERT INTO software_cpe (software_id, cpe) VALUES (?, ?)`
if _, err := d.db.Exec(sql, software.ID, cpe); err != nil {
if _, err := d.writer.Exec(sql, software.ID, cpe); err != nil {
return errors.Wrap(err, "insert software cpe")
}
return nil
@ -323,7 +323,7 @@ func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error
func (d *Datastore) AllCPEs() ([]string, error) {
sql := `SELECT cpe FROM software_cpe`
var cpes []string
err := d.db.Select(&cpes, sql)
err := d.reader.Select(&cpes, sql)
if err != nil {
return nil, errors.Wrap(err, "loads cpes")
}
@ -337,7 +337,7 @@ func (d *Datastore) InsertCVEForCPE(cve string, cpes []string) error {
for _, cpe := range cpes {
args = append(args, cpe, cve)
}
_, err := d.db.Exec(sql, args...)
_, err := d.writer.Exec(sql, args...)
if err != nil {
return errors.Wrap(err, "insert software cve")
}

View File

@ -207,7 +207,7 @@ func TestHostSoftwareDuplicates(t *testing.T) {
})
incoming[soft2Key] = true
tx, err := ds.db.Beginx()
tx, err := ds.writer.Beginx()
require.NoError(t, err)
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
require.NoError(t, tx.Commit())
@ -220,7 +220,7 @@ func TestHostSoftwareDuplicates(t *testing.T) {
})
incoming[soft3Key] = true
tx, err = ds.db.Beginx()
tx, err = ds.writer.Beginx()
require.NoError(t, err)
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
require.NoError(t, tx.Commit())

View File

@ -21,14 +21,14 @@ func (d *Datastore) ShouldSendStatistics(frequency time.Duration) (fleet.Statist
}
dest := statistics{}
err = d.db.Get(&dest, `SELECT created_at, updated_at, anonymous_identifier FROM statistics LIMIT 1`)
err = d.writer.Get(&dest, `SELECT created_at, updated_at, anonymous_identifier FROM statistics LIMIT 1`)
if err != nil {
if err == sql.ErrNoRows {
anonIdentifier, err := server.GenerateRandomText(64)
if err != nil {
return fleet.StatisticsPayload{}, false, err
}
_, err = d.db.Exec(`INSERT INTO statistics(anonymous_identifier) VALUES (?)`, anonIdentifier)
_, err = d.writer.Exec(`INSERT INTO statistics(anonymous_identifier) VALUES (?)`, anonIdentifier)
if err != nil {
return fleet.StatisticsPayload{}, false, err
}
@ -55,6 +55,6 @@ func (d *Datastore) ShouldSendStatistics(frequency time.Duration) (fleet.Statist
}
func (d *Datastore) RecordStatisticsSent() error {
_, err := d.db.Exec(`UPDATE statistics SET updated_at = CURRENT_TIMESTAMP LIMIT 1`)
_, err := d.writer.Exec(`UPDATE statistics SET updated_at = CURRENT_TIMESTAMP LIMIT 1`)
return err
}

View File

@ -52,7 +52,7 @@ func (d *Datastore) CountHostsInTargets(filter fleet.TeamFilter, targets fleet.H
}
res := fleet.TargetMetrics{}
err = d.db.Get(&res, query, args...)
err = d.reader.Get(&res, query, args...)
if err != nil {
return fleet.TargetMetrics{}, errors.Wrap(err, "sqlx.Get CountHostsInTargets")
}
@ -98,7 +98,7 @@ func (d *Datastore) HostIDsInTargets(filter fleet.TeamFilter, targets fleet.Host
}
var res []uint
err = d.db.Select(&res, query, args...)
err = d.reader.Select(&res, query, args...)
if err != nil {
return nil, errors.Wrap(err, "sqlx.Get HostIDsInTargets")
}

View File

@ -20,7 +20,7 @@ func (d *Datastore) NewTeam(team *fleet.Team) (*fleet.Team, error) {
description
) VALUES ( ?, ?, ? )
`
result, err := d.db.Exec(
result, err := d.writer.Exec(
query,
team.Name,
team.AgentOptions,
@ -47,7 +47,7 @@ func (d *Datastore) Team(tid uint) (*fleet.Team, error) {
`
team := &fleet.Team{}
if err := d.db.Get(team, sql, tid); err != nil {
if err := d.reader.Get(team, sql, tid); err != nil {
return nil, errors.Wrap(err, "select team")
}
@ -84,7 +84,7 @@ func (d *Datastore) TeamByName(name string) (*fleet.Team, error) {
`
team := &fleet.Team{}
if err := d.db.Get(team, sql, name); err != nil {
if err := d.reader.Get(team, sql, name); err != nil {
return nil, errors.Wrap(err, "select team")
}
@ -106,7 +106,7 @@ func (d *Datastore) loadUsersForTeam(team *fleet.Team) error {
WHERE ut.team_id = ?
`
rows := []fleet.TeamUser{}
if err := d.db.Select(&rows, sql, team.ID); err != nil {
if err := d.reader.Select(&rows, sql, team.ID); err != nil {
return errors.Wrap(err, "load users for team")
}
@ -157,7 +157,7 @@ func (d *Datastore) SaveTeam(team *fleet.Team) (*fleet.Team, error) {
description = ?
WHERE id = ?
`
_, err := d.db.Exec(query, team.Name, team.AgentOptions, team.Description, team.ID)
_, err := d.writer.Exec(query, team.Name, team.AgentOptions, team.Description, team.ID)
if err != nil {
return nil, errors.Wrap(err, "saving team")
}
@ -174,7 +174,7 @@ func (d *Datastore) SaveTeam(team *fleet.Team) (*fleet.Team, error) {
}
func (d *Datastore) updateTeamSchedule(team *fleet.Team) error {
_, err := d.db.Exec(
_, err := d.writer.Exec(
`UPDATE packs SET name = ? WHERE pack_type = ?`, teamScheduleName(team), teamSchedulePackType(team),
)
return err
@ -195,7 +195,7 @@ func (d *Datastore) ListTeams(filter fleet.TeamFilter, opt fleet.ListOptions) ([
query, params := searchLike(query, nil, opt.MatchQuery, teamSearchColumns...)
query = appendListOptionsToSQL(query, opt)
teams := []*fleet.Team{}
if err := d.db.Select(&teams, query, params...); err != nil {
if err := d.reader.Select(&teams, query, params...); err != nil {
return nil, errors.Wrap(err, "list teams")
}
if err := d.loadSecretsForTeams(teams); err != nil {
@ -229,7 +229,7 @@ func (d *Datastore) SearchTeams(filter fleet.TeamFilter, matchQuery string, omit
sql, params := searchLike(sql, nil, matchQuery, teamSearchColumns...)
sql += "\nLIMIT 5"
teams := []*fleet.Team{}
if err := d.db.Select(&teams, sql, params...); err != nil {
if err := d.reader.Select(&teams, sql, params...); err != nil {
return nil, errors.Wrap(err, "search teams")
}
if err := d.loadSecretsForTeams(teams); err != nil {
@ -244,7 +244,7 @@ func (d *Datastore) TeamEnrollSecrets(teamID uint) ([]*fleet.EnrollSecret, error
WHERE team_id = ?
`
var secrets []*fleet.EnrollSecret
if err := d.db.Select(&secrets, sql, teamID); err != nil {
if err := d.reader.Select(&secrets, sql, teamID); err != nil {
return nil, errors.Wrap(err, "get secrets")
}
return secrets, nil

View File

@ -1,6 +1,7 @@
package mysql
import (
"context"
"fmt"
"io/ioutil"
"os"
@ -9,10 +10,12 @@ import (
"runtime"
"strings"
"testing"
"time"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/go-kit/kit/log"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/require"
)
@ -20,9 +23,10 @@ const (
testUsername = "root"
testPassword = "toor"
testAddress = "localhost:3307"
testReplicaDatabaseSuffix = "_replica"
)
func connectMySQL(t *testing.T, testName string) *Datastore {
func connectMySQL(t *testing.T, testName string, opts *DatastoreTestOptions) *Datastore {
config := config.MysqlConfig{
Username: testUsername,
Password: testPassword,
@ -31,15 +35,134 @@ func connectMySQL(t *testing.T, testName string) *Datastore {
}
// Create datastore client
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1))
var replicaOpt DBOption
if opts.Replica {
replicaConf := config
replicaConf.Database += testReplicaDatabaseSuffix
replicaOpt = Replica(&replicaConf)
}
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1), replicaOpt)
require.Nil(t, err)
if opts.Replica {
setupReadReplica(t, testName, ds, opts)
}
return ds
}
func setupReadReplica(t *testing.T, testName string, ds *Datastore, opts *DatastoreTestOptions) {
// create the context that will cancel the replication goroutine on test exit
var cancel func()
ctx := context.Background()
if dl, ok := t.Deadline(); ok {
ctx, cancel = context.WithDeadline(ctx, dl)
} else {
ctx, cancel = context.WithCancel(ctx)
}
t.Cleanup(cancel)
// start the replication goroutine that runs when signalled through a
// channel, the replication runs in lock-step - the test is in control of
// when the replication happens, by calling opts.RunReplication(), and when
// that call returns, the replication is guaranteed to be done. This supports
// simulating all kinds of replica lag.
ch := make(chan chan struct{})
go func() {
// if it exits because of a panic/failed replication, cancel the context
// immediately so that RunReplication is unblocked too.
defer cancel()
primary := ds.writer
replica := ds.reader.(*sqlx.DB)
replicaDB := testName + testReplicaDatabaseSuffix
last := time.Now().Add(-time.Minute)
// drop all foreign keys in the replica, as that causes issues even with
// FOREIGN_KEY_CHECKS=0
var fks []struct {
TableName string `db:"table_name"`
ConstraintName string `db:"constraint_name"`
}
err := primary.SelectContext(ctx, &fks, `
SELECT
table_name, constraint_name
FROM
information_schema.key_column_usage
WHERE
table_schema = ? AND
referenced_table_name IS NOT NULL`, testName)
require.NoError(t, err)
for _, fk := range fks {
stmt := fmt.Sprintf(`ALTER TABLE %s.%s DROP FOREIGN KEY %s`, replicaDB, fk.TableName, fk.ConstraintName)
_, err := replica.ExecContext(ctx, stmt)
require.NoError(t, err)
}
for {
select {
case out := <-ch:
// identify tables with changes since the last call
var tables []string
err := primary.SelectContext(ctx, &tables, `
SELECT
table_name
FROM
information_schema.tables
WHERE
table_schema = ? AND
table_type = 'BASE TABLE' AND
update_time >= ?`, testName, last)
require.NoError(t, err)
err = primary.GetContext(ctx, &last, `
SELECT
MAX(update_time)
FROM
information_schema.tables
WHERE
table_schema = ? AND
table_type = 'BASE TABLE'`, testName)
require.NoError(t, err)
// replicate by dropping the existing table and re-creating it from
// the primary.
for _, tbl := range tables {
stmt := fmt.Sprintf(`DROP TABLE IF EXISTS %s.%s`, replicaDB, tbl)
t.Log(stmt)
_, err = replica.ExecContext(ctx, stmt)
require.NoError(t, err)
stmt = fmt.Sprintf(`CREATE TABLE %s.%s SELECT * FROM %s.%s`, replicaDB, tbl, testName, tbl)
t.Log(stmt)
_, err = replica.ExecContext(ctx, stmt)
require.NoError(t, err)
}
out <- struct{}{}
t.Logf("replication step executed, next will consider updates since %s", last)
case <-ctx.Done():
return
}
}
}()
// set RunReplication to a function that triggers the replication and waits
// for it to complete.
opts.RunReplication = func() {
done := make(chan struct{})
ch <- done
select {
case <-done:
case <-ctx.Done():
}
}
}
// initializeDatabase loads the dumped schema into a newly created database in
// MySQL. This is much faster than running the full set of migrations on each
// test.
func initializeDatabase(t *testing.T, testName string) *Datastore {
func initializeDatabase(t *testing.T, testName string, opts *DatastoreTestOptions) *Datastore {
_, filename, _, _ := runtime.Caller(0)
base := path.Dir(filename)
schema, err := ioutil.ReadFile(path.Join(base, "schema.sql"))
@ -47,6 +170,14 @@ func initializeDatabase(t *testing.T, testName string) *Datastore {
t.Error(err)
t.FailNow()
}
// execute the schema for the test db, and once more for the replica db if
// that option is set.
dbs := []string{testName}
if opts.Replica {
dbs = append(dbs, testName+testReplicaDatabaseSuffix)
}
for _, dbName := range dbs {
// Load schema from dumpfile
if out, err := exec.Command(
"docker-compose", "exec", "-T", "mysql_test",
@ -56,24 +187,43 @@ func initializeDatabase(t *testing.T, testName string) *Datastore {
"-e",
fmt.Sprintf(
"DROP DATABASE IF EXISTS %s; CREATE DATABASE %s; USE %s; SET FOREIGN_KEY_CHECKS=0; %s;",
testName, testName, testName, schema,
dbName, dbName, dbName, schema,
),
).CombinedOutput(); err != nil {
t.Error(err)
t.Error(string(out))
t.FailNow()
}
return connectMySQL(t, testName)
}
return connectMySQL(t, testName, opts)
}
func CreateMySQLDS(t *testing.T) *Datastore {
// DatastoreTestOptions configures how the test datastore is created
// by CreateMySQLDSWithOptions.
type DatastoreTestOptions struct {
// Replica indicates that a read replica test database should be created.
Replica bool
// RunReplication is the function to call to execute the replication of all
// missing changes from the primary to the replica. The function is created
// and set automatically by CreateMySQLDSWithOptions. The test is in full
// control of when the replication is executed.
RunReplication func()
}
func createMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datastore {
if _, ok := os.LookupEnv("MYSQL_TEST"); !ok {
t.Skip("MySQL tests are disabled")
}
t.Parallel()
pc, _, _, ok := runtime.Caller(1)
if opts == nil {
// so it is never nil in internal helper functions
opts = new(DatastoreTestOptions)
}
pc, _, _, ok := runtime.Caller(2)
details := runtime.FuncForPC(pc)
if !ok || details == nil {
t.FailNow()
@ -83,5 +233,13 @@ func CreateMySQLDS(t *testing.T) *Datastore {
strings.TrimPrefix(details.Name(), "github.com/fleetdm/fleet/v4/"), "/", "_",
)
cleanName = strings.ReplaceAll(cleanName, ".", "_")
return initializeDatabase(t, cleanName)
return initializeDatabase(t, cleanName, opts)
}
func CreateMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datastore {
return createMySQLDSWithOptions(t, opts)
}
func CreateMySQLDS(t *testing.T) *Datastore {
return createMySQLDSWithOptions(t, nil)
}

View File

@ -72,7 +72,7 @@ func (d *Datastore) findUser(searchCol string, searchVal interface{}) (*fleet.Us
user := &fleet.User{}
err := d.db.Get(user, sqlStatement, searchVal)
err := d.reader.Get(user, sqlStatement, searchVal)
if err != nil && err == sql.ErrNoRows {
return nil, notFound("User").
WithMessage(fmt.Sprintf("with %s=%v", searchCol, searchVal))
@ -104,7 +104,7 @@ func (d *Datastore) ListUsers(opt fleet.UserListOptions) ([]*fleet.User, error)
sqlStatement = appendListOptionsToSQL(sqlStatement, opt.ListOptions)
users := []*fleet.User{}
if err := d.db.Select(&users, sqlStatement, params...); err != nil {
if err := d.reader.Select(&users, sqlStatement, params...); err != nil {
return nil, errors.Wrap(err, "list users")
}
@ -160,7 +160,7 @@ func (d *Datastore) saveUser(tx *sqlx.Tx, user *fleet.User) error {
global_role = ?
WHERE id = ?
`
result, err := d.db.Exec(sqlStatement,
result, err := tx.Exec(sqlStatement,
user.Password,
user.Salt,
user.Name,
@ -221,7 +221,7 @@ func (d *Datastore) loadTeamsForUsers(users []*fleet.User) error {
fleet.UserTeam
UserID uint `db:"user_id"`
}
if err := d.db.Select(&rows, sql, args...); err != nil {
if err := d.reader.Select(&rows, sql, args...); err != nil {
return errors.Wrap(err, "get loadTeamsForUsers")
}