mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Add support for MySQL read replicas (#1786)
Adds configuration options to use a read-only MySQL replica, and uses it instead of the primary for reads.
This commit is contained in:
parent
2116217b6e
commit
3755a58070
1
changes/issue-1695-support-mysql-read-replica
Normal file
1
changes/issue-1695-support-mysql-read-replica
Normal file
@ -0,0 +1 @@
|
||||
* Add support for MySQL read replica.
|
@ -149,7 +149,11 @@ the way that the Fleet server works.
|
||||
var carveStore fleet.CarveStore
|
||||
mailService := mail.NewService()
|
||||
|
||||
ds, err = mysql.New(config.Mysql, clock.C, mysql.Logger(logger))
|
||||
var replicaOpt mysql.DBOption
|
||||
if config.MysqlReadReplica.Address != "" {
|
||||
replicaOpt = mysql.Replica(&config.MysqlReadReplica)
|
||||
}
|
||||
ds, err = mysql.New(config.Mysql, clock.C, mysql.Logger(logger), replicaOpt)
|
||||
if err != nil {
|
||||
initFatal(err, "initializing datastore")
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ Your Fleet server's two main purposes are:
|
||||
- To serve as your [osquery TLS server](https://osquery.readthedocs.io/en/stable/deployment/remote/)
|
||||
- To serve the Fleet web UI, which allows you to manage osquery configuration, query hosts, etc.
|
||||
|
||||
The Fleet server allows you persist configuration, manage users, etc. Thus, it needs a database. Fleet uses MySQL and requires you to supply configurations to connect to a MySQL server. Fleet also uses Redis to perform some more high-speed data access action throughout the lifecycle of the application (for example, distributed query result ingestion). Thus, Fleet also requires that you supply Redis connection configurations.
|
||||
The Fleet server allows you persist configuration, manage users, etc. Thus, it needs a database. Fleet uses MySQL and requires you to supply configurations to connect to a MySQL server. It is also possible to configure connection to a MySQL replica in addition to the primary, to be used for reading only. Fleet also uses Redis to perform some more high-speed data access action throughout the lifecycle of the application (for example, distributed query result ingestion). Thus, Fleet also requires that you supply Redis connection configurations.
|
||||
|
||||
> Fleet does not support Redis Cluster or Redis Sentinel. Fleet can scale to hundreds of thousands of devices with a single Redis instance.
|
||||
|
||||
@ -118,12 +118,21 @@ mysql:
|
||||
address: 127.0.0.1:3306
|
||||
```
|
||||
|
||||
And `mysql_read_replica_address` would be:
|
||||
|
||||
```
|
||||
mysql_read_replica:
|
||||
address: 127.0.0.1:3307
|
||||
```
|
||||
|
||||
Basically, just capitalize the option and prepend `FLEET_` to it in order to get the environment variable. The conversion works the same the opposite way.
|
||||
|
||||
All duration-based settings accept valid time units of `s`, `m`, `h`.
|
||||
|
||||
##### MySQL
|
||||
|
||||
This section describes the configuration options for the primary - if you also want to setup a read replica, the options are the same, except that the yaml section is `mysql_read_replica`, and the flags have the `mysql_read_replica_` prefix instead of `mysql_` (the corresponding environment variables follow the same transformation). Note that there is no default value for `mysql_read_replica_address`, it must be set explicitly for fleet to use a read replica.
|
||||
|
||||
###### `mysql_address`
|
||||
|
||||
The address of the MySQL server which Fleet should connect to. Include the hostname and port.
|
||||
|
@ -177,6 +177,7 @@ type VulnerabilitiesConfig struct {
|
||||
// updated to set and retrieve the configurations as appropriate.
|
||||
type FleetConfig struct {
|
||||
Mysql MysqlConfig
|
||||
MysqlReadReplica MysqlConfig `yaml:"mysql_read_replica"`
|
||||
Redis RedisConfig
|
||||
Server ServerConfig
|
||||
Auth AuthConfig
|
||||
@ -197,32 +198,36 @@ type FleetConfig struct {
|
||||
// addConfigs adds the configuration keys and default values that will be
|
||||
// filled into the FleetConfig struct
|
||||
func (man Manager) addConfigs() {
|
||||
addMysqlConfig := func(prefix, defaultAddr, usageSuffix string) {
|
||||
man.addConfigString(prefix+".protocol", "tcp",
|
||||
"MySQL server communication protocol (tcp,unix,...)"+usageSuffix)
|
||||
man.addConfigString(prefix+".address", defaultAddr,
|
||||
"MySQL server address (host:port)"+usageSuffix)
|
||||
man.addConfigString(prefix+".username", "fleet",
|
||||
"MySQL server username"+usageSuffix)
|
||||
man.addConfigString(prefix+".password", "",
|
||||
"MySQL server password (prefer env variable for security)"+usageSuffix)
|
||||
man.addConfigString(prefix+".password_path", "",
|
||||
"Path to file containg MySQL server password"+usageSuffix)
|
||||
man.addConfigString(prefix+".database", "fleet",
|
||||
"MySQL database name"+usageSuffix)
|
||||
man.addConfigString(prefix+".tls_cert", "",
|
||||
"MySQL TLS client certificate path"+usageSuffix)
|
||||
man.addConfigString(prefix+".tls_key", "",
|
||||
"MySQL TLS client key path"+usageSuffix)
|
||||
man.addConfigString(prefix+".tls_ca", "",
|
||||
"MySQL TLS server CA"+usageSuffix)
|
||||
man.addConfigString(prefix+".tls_server_name", "",
|
||||
"MySQL TLS server name"+usageSuffix)
|
||||
man.addConfigString(prefix+".tls_config", "",
|
||||
"MySQL TLS config value"+usageSuffix+" Use skip-verify, true, false or custom key.")
|
||||
man.addConfigInt(prefix+".max_open_conns", 50, "MySQL maximum open connection handles"+usageSuffix)
|
||||
man.addConfigInt(prefix+".max_idle_conns", 50, "MySQL maximum idle connection handles"+usageSuffix)
|
||||
man.addConfigInt(prefix+".conn_max_lifetime", 0, "MySQL maximum amount of time a connection may be reused"+usageSuffix)
|
||||
}
|
||||
// MySQL
|
||||
man.addConfigString("mysql.protocol", "tcp",
|
||||
"MySQL server communication protocol (tcp,unix,...)")
|
||||
man.addConfigString("mysql.address", "localhost:3306",
|
||||
"MySQL server address (host:port)")
|
||||
man.addConfigString("mysql.username", "fleet",
|
||||
"MySQL server username")
|
||||
man.addConfigString("mysql.password", "",
|
||||
"MySQL server password (prefer env variable for security)")
|
||||
man.addConfigString("mysql.password_path", "",
|
||||
"Path to file containg MySQL server password")
|
||||
man.addConfigString("mysql.database", "fleet",
|
||||
"MySQL database name")
|
||||
man.addConfigString("mysql.tls_cert", "",
|
||||
"MySQL TLS client certificate path")
|
||||
man.addConfigString("mysql.tls_key", "",
|
||||
"MySQL TLS client key path")
|
||||
man.addConfigString("mysql.tls_ca", "",
|
||||
"MySQL TLS server CA")
|
||||
man.addConfigString("mysql.tls_server_name", "",
|
||||
"MySQL TLS server name")
|
||||
man.addConfigString("mysql.tls_config", "",
|
||||
"MySQL TLS config value. Use skip-verify, true, false or custom key.")
|
||||
man.addConfigInt("mysql.max_open_conns", 50, "MySQL maximum open connection handles.")
|
||||
man.addConfigInt("mysql.max_idle_conns", 50, "MySQL maximum idle connection handles.")
|
||||
man.addConfigInt("mysql.conn_max_lifetime", 0, "MySQL maximum amount of time a connection may be reused.")
|
||||
addMysqlConfig("mysql", "localhost:3306", ".")
|
||||
addMysqlConfig("mysql_read_replica", "", " for the read replica.")
|
||||
|
||||
// Redis
|
||||
man.addConfigString("redis.address", "localhost:6379",
|
||||
@ -382,23 +387,28 @@ func (man Manager) addConfigs() {
|
||||
func (man Manager) LoadConfig() FleetConfig {
|
||||
man.loadConfigFile()
|
||||
|
||||
loadMysqlConfig := func(prefix string) MysqlConfig {
|
||||
return MysqlConfig{
|
||||
Protocol: man.getConfigString(prefix + ".protocol"),
|
||||
Address: man.getConfigString(prefix + ".address"),
|
||||
Username: man.getConfigString(prefix + ".username"),
|
||||
Password: man.getConfigString(prefix + ".password"),
|
||||
PasswordPath: man.getConfigString(prefix + ".password_path"),
|
||||
Database: man.getConfigString(prefix + ".database"),
|
||||
TLSCert: man.getConfigString(prefix + ".tls_cert"),
|
||||
TLSKey: man.getConfigString(prefix + ".tls_key"),
|
||||
TLSCA: man.getConfigString(prefix + ".tls_ca"),
|
||||
TLSServerName: man.getConfigString(prefix + ".tls_server_name"),
|
||||
TLSConfig: man.getConfigString(prefix + ".tls_config"),
|
||||
MaxOpenConns: man.getConfigInt(prefix + ".max_open_conns"),
|
||||
MaxIdleConns: man.getConfigInt(prefix + ".max_idle_conns"),
|
||||
ConnMaxLifetime: man.getConfigInt(prefix + ".conn_max_lifetime"),
|
||||
}
|
||||
}
|
||||
|
||||
return FleetConfig{
|
||||
Mysql: MysqlConfig{
|
||||
Protocol: man.getConfigString("mysql.protocol"),
|
||||
Address: man.getConfigString("mysql.address"),
|
||||
Username: man.getConfigString("mysql.username"),
|
||||
Password: man.getConfigString("mysql.password"),
|
||||
PasswordPath: man.getConfigString("mysql.password_path"),
|
||||
Database: man.getConfigString("mysql.database"),
|
||||
TLSCert: man.getConfigString("mysql.tls_cert"),
|
||||
TLSKey: man.getConfigString("mysql.tls_key"),
|
||||
TLSCA: man.getConfigString("mysql.tls_ca"),
|
||||
TLSServerName: man.getConfigString("mysql.tls_server_name"),
|
||||
TLSConfig: man.getConfigString("mysql.tls_config"),
|
||||
MaxOpenConns: man.getConfigInt("mysql.max_open_conns"),
|
||||
MaxIdleConns: man.getConfigInt("mysql.max_idle_conns"),
|
||||
ConnMaxLifetime: man.getConfigInt("mysql.conn_max_lifetime"),
|
||||
},
|
||||
Mysql: loadMysqlConfig("mysql"),
|
||||
MysqlReadReplica: loadMysqlConfig("mysql_read_replica"),
|
||||
Redis: RedisConfig{
|
||||
Address: man.getConfigString("redis.address"),
|
||||
Password: man.getConfigString("redis.password"),
|
||||
|
@ -58,6 +58,7 @@ func TestConfigRoundtrip(t *testing.T) {
|
||||
// Marshal the generated config
|
||||
buf, err := yaml.Marshal(original)
|
||||
require.Nil(t, err)
|
||||
t.Log(string(buf))
|
||||
|
||||
// Manually load the serialized config
|
||||
man.viper.SetConfigType("yaml")
|
||||
|
@ -3,6 +3,7 @@ package mysql
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
@ -13,7 +14,7 @@ func (d *Datastore) NewActivity(user *fleet.User, activityType string, details *
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "marshaling activity details")
|
||||
}
|
||||
_, err = d.db.Exec(
|
||||
_, err = d.writer.Exec(
|
||||
`INSERT INTO activities (user_id, user_name, activity_type, details) VALUES(?,?,?,?)`,
|
||||
user.ID,
|
||||
user.Name,
|
||||
@ -34,7 +35,7 @@ func (d *Datastore) ListActivities(opt fleet.ListOptions) ([]*fleet.Activity, er
|
||||
WHERE true`
|
||||
query = appendListOptionsToSQL(query, opt)
|
||||
|
||||
err := d.db.Select(&activities, query)
|
||||
err := d.reader.Select(&activities, query)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Activity")
|
||||
} else if err != nil {
|
||||
|
@ -25,7 +25,7 @@ func (d *Datastore) NewAppConfig(info *fleet.AppConfig) (*fleet.AppConfig, error
|
||||
func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
|
||||
info := &fleet.AppConfig{}
|
||||
var bytes []byte
|
||||
err := d.db.Get(&bytes, `SELECT json_value FROM app_config_json LIMIT 1`)
|
||||
err := d.reader.Get(&bytes, `SELECT json_value FROM app_config_json LIMIT 1`)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "selecting app config")
|
||||
}
|
||||
@ -43,7 +43,7 @@ func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
|
||||
}
|
||||
|
||||
func (d *Datastore) isEventSchedulerEnabled() (bool, error) {
|
||||
rows, err := d.db.Query("SELECT @@event_scheduler")
|
||||
rows, err := d.writer.Query("SELECT @@event_scheduler")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -136,7 +136,7 @@ func (d *Datastore) SaveAppConfig(info *fleet.AppConfig) error {
|
||||
|
||||
func (d *Datastore) VerifyEnrollSecret(secret string) (*fleet.EnrollSecret, error) {
|
||||
var s fleet.EnrollSecret
|
||||
err := d.db.Get(&s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
|
||||
err := d.reader.Get(&s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
|
||||
if err != nil {
|
||||
return nil, errors.New("no matching secret found")
|
||||
}
|
||||
@ -184,7 +184,7 @@ func (d *Datastore) GetEnrollSecrets(teamID *uint) ([]*fleet.EnrollSecret, error
|
||||
args = append(args, teamID)
|
||||
}
|
||||
var secrets []*fleet.EnrollSecret
|
||||
if err := d.db.Select(&secrets, sql, args...); err != nil {
|
||||
if err := d.reader.Select(&secrets, sql, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "get secrets")
|
||||
}
|
||||
return secrets, nil
|
||||
|
@ -259,7 +259,7 @@ func TestAppConfigDefaults(t *testing.T) {
|
||||
defer ds.Close()
|
||||
|
||||
insertAppConfigQuery := `INSERT INTO app_config_json(json_value) VALUES(?) ON DUPLICATE KEY UPDATE json_value = VALUES(json_value)`
|
||||
_, err := ds.db.Exec(insertAppConfigQuery, `{}`)
|
||||
_, err := ds.writer.Exec(insertAppConfigQuery, `{}`)
|
||||
require.NoError(t, err)
|
||||
|
||||
ac, err := ds.AppConfig()
|
||||
@ -269,7 +269,7 @@ func TestAppConfigDefaults(t *testing.T) {
|
||||
require.True(t, ac.HostSettings.EnableHostUsers)
|
||||
require.False(t, ac.HostSettings.EnableSoftwareInventory)
|
||||
|
||||
_, err = ds.db.Exec(
|
||||
_, err = ds.writer.Exec(
|
||||
insertAppConfigQuery,
|
||||
`{"webhook_settings": {"interval": "12h"}, "host_settings": {"enable_host_users": false}}`,
|
||||
)
|
||||
|
@ -17,7 +17,7 @@ func (d *Datastore) NewDistributedQueryCampaign(camp *fleet.DistributedQueryCamp
|
||||
)
|
||||
VALUES(?,?,?)
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID)
|
||||
result, err := d.writer.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "inserting distributed query campaign")
|
||||
}
|
||||
@ -32,7 +32,7 @@ func (d *Datastore) DistributedQueryCampaign(id uint) (*fleet.DistributedQueryCa
|
||||
SELECT * FROM distributed_query_campaigns WHERE id = ?
|
||||
`
|
||||
campaign := &fleet.DistributedQueryCampaign{}
|
||||
if err := d.db.Get(campaign, sql, id); err != nil {
|
||||
if err := d.reader.Get(campaign, sql, id); err != nil {
|
||||
return nil, errors.Wrap(err, "selecting distributed query campaign")
|
||||
}
|
||||
|
||||
@ -47,7 +47,7 @@ func (d *Datastore) SaveDistributedQueryCampaign(camp *fleet.DistributedQueryCam
|
||||
user_id = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID, camp.ID)
|
||||
result, err := d.writer.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID, camp.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "updating distributed query campaign")
|
||||
}
|
||||
@ -68,7 +68,7 @@ func (d *Datastore) DistributedQueryCampaignTargetIDs(id uint) (*fleet.HostTarge
|
||||
`
|
||||
targets := []fleet.DistributedQueryCampaignTarget{}
|
||||
|
||||
if err := d.db.Select(&targets, sqlStatement, id); err != nil {
|
||||
if err := d.reader.Select(&targets, sqlStatement, id); err != nil {
|
||||
return nil, errors.Wrap(err, "select distributed campaign target")
|
||||
}
|
||||
|
||||
@ -100,7 +100,7 @@ func (d *Datastore) NewDistributedQueryCampaignTarget(target *fleet.DistributedQ
|
||||
)
|
||||
VALUES (?,?,?)
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, target.Type, target.DistributedQueryCampaignID, target.TargetID)
|
||||
result, err := d.writer.Exec(sqlStatement, target.Type, target.DistributedQueryCampaignID, target.TargetID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "insert distributed campaign target")
|
||||
}
|
||||
@ -118,7 +118,7 @@ func (d *Datastore) CleanupDistributedQueryCampaigns(now time.Time) (expired uin
|
||||
WHERE (status = ? AND created_at < ?)
|
||||
OR (status = ? AND created_at < ?)
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, fleet.QueryComplete,
|
||||
result, err := d.writer.Exec(sqlStatement, fleet.QueryComplete,
|
||||
fleet.QueryWaiting, now.Add(-1*time.Minute),
|
||||
fleet.QueryRunning, now.Add(-24*time.Hour))
|
||||
if err != nil {
|
||||
|
@ -33,7 +33,7 @@ func (d *Datastore) NewCarve(metadata *fleet.CarveMetadata) (*fleet.CarveMetadat
|
||||
?
|
||||
)`
|
||||
|
||||
result, err := d.db.Exec(
|
||||
result, err := d.writer.Exec(
|
||||
stmt,
|
||||
metadata.HostId,
|
||||
metadata.CreatedAt.Format(mySQLTimestampFormat),
|
||||
@ -64,7 +64,7 @@ func (d *Datastore) UpdateCarve(metadata *fleet.CarveMetadata) error {
|
||||
expired = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := d.db.Exec(
|
||||
_, err := d.writer.Exec(
|
||||
stmt,
|
||||
metadata.MaxBlock,
|
||||
metadata.Expired,
|
||||
@ -163,7 +163,7 @@ func (d *Datastore) Carve(carveId int64) (*fleet.CarveMetadata, error) {
|
||||
)
|
||||
|
||||
var metadata fleet.CarveMetadata
|
||||
if err := d.db.Get(&metadata, stmt, carveId); err != nil {
|
||||
if err := d.reader.Get(&metadata, stmt, carveId); err != nil {
|
||||
return nil, errors.Wrap(err, "get carve by ID")
|
||||
}
|
||||
|
||||
@ -179,7 +179,7 @@ func (d *Datastore) CarveBySessionId(sessionId string) (*fleet.CarveMetadata, er
|
||||
)
|
||||
|
||||
var metadata fleet.CarveMetadata
|
||||
if err := d.db.Get(&metadata, stmt, sessionId); err != nil {
|
||||
if err := d.reader.Get(&metadata, stmt, sessionId); err != nil {
|
||||
return nil, errors.Wrap(err, "get carve by session ID")
|
||||
}
|
||||
|
||||
@ -195,7 +195,7 @@ func (d *Datastore) CarveByName(name string) (*fleet.CarveMetadata, error) {
|
||||
)
|
||||
|
||||
var metadata fleet.CarveMetadata
|
||||
if err := d.db.Get(&metadata, stmt, name); err != nil {
|
||||
if err := d.reader.Get(&metadata, stmt, name); err != nil {
|
||||
return nil, errors.Wrap(err, "get carve by name")
|
||||
}
|
||||
|
||||
@ -213,7 +213,7 @@ func (d *Datastore) ListCarves(opt fleet.CarveListOptions) ([]*fleet.CarveMetada
|
||||
}
|
||||
stmt = appendListOptionsToSQL(stmt, opt.ListOptions)
|
||||
carves := []*fleet.CarveMetadata{}
|
||||
if err := d.db.Select(&carves, stmt); err != nil && err != sql.ErrNoRows {
|
||||
if err := d.reader.Select(&carves, stmt); err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "list carves")
|
||||
}
|
||||
|
||||
@ -231,13 +231,14 @@ func (d *Datastore) NewBlock(metadata *fleet.CarveMetadata, blockId int64, data
|
||||
?,
|
||||
?
|
||||
)`
|
||||
if _, err := d.db.Exec(stmt, metadata.ID, blockId, data); err != nil {
|
||||
if _, err := d.writer.Exec(stmt, metadata.ID, blockId, data); err != nil {
|
||||
return errors.Wrap(err, "insert carve block")
|
||||
}
|
||||
|
||||
if metadata.MaxBlock < blockId {
|
||||
// Update max_block
|
||||
metadata.MaxBlock = blockId
|
||||
// TODO: looks like this should all be done in an atomic transaction?
|
||||
if err := d.UpdateCarve(metadata); err != nil {
|
||||
return errors.Wrap(err, "insert carve block")
|
||||
}
|
||||
@ -253,7 +254,7 @@ func (d *Datastore) GetBlock(metadata *fleet.CarveMetadata, blockId int64) ([]by
|
||||
WHERE metadata_id = ? AND block_id = ?
|
||||
`
|
||||
var data []byte
|
||||
if err := d.db.Get(&data, stmt, metadata.ID, blockId); err != nil {
|
||||
if err := d.reader.Get(&data, stmt, metadata.ID, blockId); err != nil {
|
||||
return nil, errors.Wrap(err, "select data")
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,9 @@
|
||||
package mysql
|
||||
|
||||
import "github.com/go-kit/kit/log"
|
||||
import (
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/go-kit/kit/log"
|
||||
)
|
||||
|
||||
const defaultMaxAttempts int = 15
|
||||
|
||||
@ -11,6 +14,7 @@ type dbOptions struct {
|
||||
// maxAttempts configures the number of retries to connect to the DB
|
||||
maxAttempts int
|
||||
logger log.Logger
|
||||
replicaConfig *config.MysqlConfig
|
||||
}
|
||||
|
||||
// Logger adds a logger to the datastore
|
||||
@ -21,6 +25,14 @@ func Logger(l log.Logger) DBOption {
|
||||
}
|
||||
}
|
||||
|
||||
// Replica sets the configuration of the read replica for the datastore.
|
||||
func Replica(conf *config.MysqlConfig) DBOption {
|
||||
return func(o *dbOptions) error {
|
||||
o.replicaConfig = conf
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// LimitAttempts sets a the number of attempts
|
||||
// to try establishing a connection to the database backend
|
||||
// the default value is 15 attempts
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
// returning a notFound error if appropriate.
|
||||
func (d *Datastore) deleteEntity(dbTable string, id uint) error {
|
||||
deleteStmt := fmt.Sprintf(`DELETE FROM %s WHERE id = ?`, dbTable)
|
||||
result, err := d.db.Exec(deleteStmt, id)
|
||||
result, err := d.writer.Exec(deleteStmt, id)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "delete %s", dbTable)
|
||||
}
|
||||
@ -26,7 +26,7 @@ func (d *Datastore) deleteEntity(dbTable string, id uint) error {
|
||||
// table, returning a notFound error if appropriate.
|
||||
func (d *Datastore) deleteEntityByName(dbTable string, name string) error {
|
||||
deleteStmt := fmt.Sprintf("DELETE FROM %s WHERE name = ?", dbTable)
|
||||
result, err := d.db.Exec(deleteStmt, name)
|
||||
result, err := d.writer.Exec(deleteStmt, name)
|
||||
if err != nil {
|
||||
if isMySQLForeignKey(err) {
|
||||
return foreignKey(dbTable, name)
|
||||
@ -50,7 +50,7 @@ func (d *Datastore) deleteEntities(dbTable string, ids []uint) (uint, error) {
|
||||
return 0, errors.Wrapf(err, "building delete entities query %s", dbTable)
|
||||
}
|
||||
|
||||
result, err := d.db.Exec(query, args...)
|
||||
result, err := d.writer.Exec(query, args...)
|
||||
if err != nil {
|
||||
return 0, errors.Wrapf(err, "executing delete entities query %s", dbTable)
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ func (ds *Datastore) PendingEmailChange(uid uint, newEmail, token string) error
|
||||
new_email
|
||||
) VALUES( ?, ?, ? )
|
||||
`
|
||||
_, err := ds.db.Exec(sqlStatement, uid, token, newEmail)
|
||||
_, err := ds.writer.Exec(sqlStatement, uid, token, newEmail)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "inserting email change record")
|
||||
}
|
||||
@ -32,7 +32,7 @@ func (ds *Datastore) ConfirmPendingEmailChange(id uint, token string) (newEmail
|
||||
Token string
|
||||
NewEmail string `db:"new_email"`
|
||||
}{}
|
||||
err = ds.db.Get(&changeRecord, "SELECT * FROM email_changes WHERE token = ? AND user_id = ?", token, id)
|
||||
err = ds.writer.Get(&changeRecord, "SELECT * FROM email_changes WHERE token = ? AND user_id = ?", token, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", notFound("email change with token")
|
||||
|
@ -33,7 +33,7 @@ func (d *Datastore) NewHost(host *fleet.Host) (*fleet.Host, error) {
|
||||
)
|
||||
VALUES( ?,?,?,?,?,?,?,?,?,?,?,?,? )
|
||||
`
|
||||
result, err := d.db.Exec(
|
||||
result, err := d.writer.Exec(
|
||||
sqlStatement,
|
||||
host.OsqueryHostID,
|
||||
host.DetailUpdatedAt,
|
||||
@ -95,7 +95,7 @@ func (d *Datastore) SaveHost(host *fleet.Host) error {
|
||||
percent_disk_space_available = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := d.db.Exec(sqlStatement,
|
||||
_, err := d.writer.Exec(sqlStatement,
|
||||
host.DetailUpdatedAt,
|
||||
host.LabelUpdatedAt,
|
||||
host.NodeKey,
|
||||
@ -248,7 +248,7 @@ FROM scheduled_query_stats sqs
|
||||
WHERE host_id = ? AND p.pack_type IS NULL
|
||||
`
|
||||
var stats []fleet.ScheduledQueryStats
|
||||
if err := d.db.Select(&stats, sql, host.ID); err != nil {
|
||||
if err := d.reader.Select(&stats, sql, host.ID); err != nil {
|
||||
return errors.Wrap(err, "load pack stats")
|
||||
}
|
||||
|
||||
@ -268,9 +268,9 @@ WHERE host_id = ? AND p.pack_type IS NULL
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Datastore) loadHostUsers(host *fleet.Host) error {
|
||||
func (d *Datastore) loadHostUsers(db dbReader, host *fleet.Host) error {
|
||||
sql := `SELECT id, username, groupname, uid, user_type FROM host_users WHERE host_id = ? and removed_at IS NULL`
|
||||
if err := d.db.Select(&host.Users, sql, host.ID); err != nil {
|
||||
if err := db.Select(&host.Users, sql, host.ID); err != nil {
|
||||
return errors.Wrap(err, "load pack stats")
|
||||
}
|
||||
return nil
|
||||
@ -292,14 +292,14 @@ func (d *Datastore) Host(id uint) (*fleet.Host, error) {
|
||||
LIMIT 1
|
||||
`
|
||||
host := &fleet.Host{}
|
||||
err := d.db.Get(host, sqlStatement, id)
|
||||
err := d.reader.Get(host, sqlStatement, id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get host by id")
|
||||
}
|
||||
if err := d.loadHostPackStats(host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.loadHostUsers(host); err != nil {
|
||||
if err := d.loadHostUsers(d.reader, host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -308,7 +308,9 @@ func (d *Datastore) Host(id uint) (*fleet.Host, error) {
|
||||
|
||||
func (d *Datastore) amountEnrolledHosts() (int, error) {
|
||||
var amount int
|
||||
err := d.db.Get(&amount, `SELECT count(*) FROM hosts`)
|
||||
// uses the writer as it is called from ShouldSendStatistics, which is a
|
||||
// write method.
|
||||
err := d.writer.Get(&amount, `SELECT count(*) FROM hosts`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -363,7 +365,7 @@ func (d *Datastore) ListHosts(filter fleet.TeamFilter, opt fleet.HostListOptions
|
||||
sql = appendListOptionsToSQL(sql, opt.ListOptions)
|
||||
|
||||
hosts := []*fleet.Host{}
|
||||
if err := d.db.Select(&hosts, sql, params...); err != nil {
|
||||
if err := d.reader.Select(&hosts, sql, params...); err != nil {
|
||||
return nil, errors.Wrap(err, "list hosts")
|
||||
}
|
||||
|
||||
@ -413,7 +415,7 @@ func (d *Datastore) CleanupIncomingHosts(now time.Time) error {
|
||||
WHERE hostname = '' AND osquery_version = ''
|
||||
AND created_at < (? - INTERVAL 5 MINUTE)
|
||||
`
|
||||
if _, err := d.db.Exec(sqlStatement, now); err != nil {
|
||||
if _, err := d.writer.Exec(sqlStatement, now); err != nil {
|
||||
return errors.Wrap(err, "cleanup incoming hosts")
|
||||
}
|
||||
|
||||
@ -442,7 +444,7 @@ func (d *Datastore) GenerateHostStatusStatistics(filter fleet.TeamFilter, now ti
|
||||
Online uint `db:"online"`
|
||||
New uint `db:"new"`
|
||||
}{}
|
||||
err := d.db.Get(&counts, sqlStatement, now, now, now, now, now)
|
||||
err := d.reader.Get(&counts, sqlStatement, now, now, now, now, now)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
e = errors.Wrap(err, "generating host statistics")
|
||||
return
|
||||
@ -582,7 +584,7 @@ func (d *Datastore) AuthenticateHost(nodeKey string) (*fleet.Host, error) {
|
||||
`
|
||||
|
||||
host := &fleet.Host{}
|
||||
if err := d.db.Get(host, sqlStatement, nodeKey); err != nil {
|
||||
if err := d.reader.Get(host, sqlStatement, nodeKey); err != nil {
|
||||
switch err {
|
||||
case sql.ErrNoRows:
|
||||
return nil, notFound("Host")
|
||||
@ -601,7 +603,7 @@ func (d *Datastore) MarkHostSeen(host *fleet.Host, t time.Time) error {
|
||||
WHERE node_key=?
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(sqlStatement, t, host.NodeKey)
|
||||
_, err := d.writer.Exec(sqlStatement, t, host.NodeKey)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "marking host seen")
|
||||
}
|
||||
@ -659,11 +661,11 @@ func (d *Datastore) searchHostsWithOmits(filter fleet.TeamFilter, query string,
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching hosts")
|
||||
}
|
||||
sql = d.db.Rebind(sql)
|
||||
sql = d.reader.Rebind(sql)
|
||||
|
||||
hosts := []*fleet.Host{}
|
||||
|
||||
err = d.db.Select(&hosts, sql, args...)
|
||||
err = d.reader.Select(&hosts, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching hosts rebound")
|
||||
}
|
||||
@ -695,8 +697,8 @@ func (d *Datastore) searchHostsDefault(filter fleet.TeamFilter, omit ...uint) ([
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching default hosts")
|
||||
}
|
||||
sql = d.db.Rebind(sql)
|
||||
err = d.db.Select(&hosts, sql, args...)
|
||||
sql = d.reader.Rebind(sql)
|
||||
err = d.reader.Select(&hosts, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching default hosts rebound")
|
||||
}
|
||||
@ -730,7 +732,7 @@ func (d *Datastore) SearchHosts(filter fleet.TeamFilter, query string, omit ...u
|
||||
)
|
||||
|
||||
hosts := []*fleet.Host{}
|
||||
if err := d.db.Select(&hosts, sql, hostQuery, ipQuery); err != nil {
|
||||
if err := d.reader.Select(&hosts, sql, hostQuery, ipQuery); err != nil {
|
||||
return nil, errors.Wrap(err, "searching hosts")
|
||||
}
|
||||
|
||||
@ -755,7 +757,7 @@ func (d *Datastore) HostIDsByName(filter fleet.TeamFilter, hostnames []string) (
|
||||
}
|
||||
|
||||
var hostIDs []uint
|
||||
if err := d.db.Select(&hostIDs, sql, args...); err != nil {
|
||||
if err := d.reader.Select(&hostIDs, sql, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "get host IDs")
|
||||
}
|
||||
|
||||
@ -770,7 +772,7 @@ func (d *Datastore) HostByIdentifier(identifier string) (*fleet.Host, error) {
|
||||
LIMIT 1
|
||||
`
|
||||
host := &fleet.Host{}
|
||||
err := d.db.Get(host, sql, identifier)
|
||||
err := d.reader.Get(host, sql, identifier)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get host by identifier")
|
||||
}
|
||||
@ -796,7 +798,7 @@ func (d *Datastore) AddHostsToTeam(teamID *uint, hostIDs []uint) error {
|
||||
return errors.Wrap(err, "sqlx.In AddHostsToTeam")
|
||||
}
|
||||
|
||||
if _, err := d.db.Exec(sql, args...); err != nil {
|
||||
if _, err := d.writer.Exec(sql, args...); err != nil {
|
||||
return errors.Wrap(err, "exec AddHostsToTeam")
|
||||
}
|
||||
|
||||
@ -809,7 +811,7 @@ func (d *Datastore) SaveHostAdditional(host *fleet.Host) error {
|
||||
VALUES (?, ?)
|
||||
ON DUPLICATE KEY UPDATE additional = VALUES(additional)
|
||||
`
|
||||
if _, err := d.db.Exec(sql, host.ID, host.Additional); err != nil {
|
||||
if _, err := d.writer.Exec(sql, host.ID, host.Additional); err != nil {
|
||||
return errors.Wrap(err, "insert additional")
|
||||
}
|
||||
|
||||
@ -818,7 +820,7 @@ func (d *Datastore) SaveHostAdditional(host *fleet.Host) error {
|
||||
|
||||
func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||
if len(host.Users) == 0 {
|
||||
if _, err := d.db.Exec(
|
||||
if _, err := d.writer.Exec(
|
||||
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE host_id = ?`,
|
||||
host.ID,
|
||||
); err != nil {
|
||||
@ -829,7 +831,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||
}
|
||||
|
||||
currentHost := &fleet.Host{ID: host.ID}
|
||||
if err := d.loadHostUsers(currentHost); err != nil {
|
||||
if err := d.loadHostUsers(d.writer, currentHost); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -837,7 +839,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||
for _, u := range host.Users {
|
||||
incomingUsers[u.Uid] = true
|
||||
|
||||
if _, err := d.db.Exec(
|
||||
if _, err := d.writer.Exec(
|
||||
`INSERT IGNORE INTO host_users (host_id, uid, username, user_type, groupname) VALUES (?, ?, ?, ?, ?)`,
|
||||
host.ID, u.Uid, u.Username, u.Type, u.GroupName,
|
||||
); err != nil {
|
||||
@ -860,7 +862,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE id IN (%s)`,
|
||||
removedValues,
|
||||
)
|
||||
if _, err := d.db.Exec(removedSql, removedArgs...); err != nil {
|
||||
if _, err := d.writer.Exec(removedSql, removedArgs...); err != nil {
|
||||
return errors.Wrap(err, "mark users as removed")
|
||||
}
|
||||
|
||||
@ -869,15 +871,15 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||
|
||||
func (d *Datastore) TotalAndUnseenHostsSince(daysCount int) (int, int, error) {
|
||||
var totalCount, unseenCount int
|
||||
err := d.db.QueryRow("SELECT count(*) FROM hosts").Scan(&totalCount)
|
||||
err := d.reader.Get(&totalCount, "SELECT count(*) FROM hosts")
|
||||
if err != nil {
|
||||
return 0, 0, errors.Wrap(err, "getting total host count")
|
||||
}
|
||||
|
||||
err = d.db.QueryRow(
|
||||
err = d.reader.Get(&unseenCount,
|
||||
"SELECT count(*) FROM hosts WHERE DATEDIFF(CURRENT_DATE, seen_time) >= ?",
|
||||
daysCount,
|
||||
).Scan(&unseenCount)
|
||||
)
|
||||
if err != nil {
|
||||
return 0, 0, errors.Wrap(err, "getting unseen host count")
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ func (d *Datastore) NewInvite(i *fleet.Invite) (*fleet.Invite, error) {
|
||||
VALUES ( ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := d.db.Exec(sqlStmt, i.InvitedBy, i.Email,
|
||||
result, err := d.writer.Exec(sqlStmt, i.InvitedBy, i.Email,
|
||||
i.Name, i.Position, i.Token, i.SSOEnabled, i.GlobalRole)
|
||||
if err != nil && isDuplicate(err) {
|
||||
return nil, alreadyExists("Invite", i.Email)
|
||||
@ -45,10 +45,11 @@ func (d *Datastore) NewInvite(i *fleet.Invite) (*fleet.Invite, error) {
|
||||
for _, userTeam := range i.Teams {
|
||||
args = append(args, i.ID, userTeam.Team.ID, userTeam.Role)
|
||||
}
|
||||
// TODO: seems like this should be in a transaction?
|
||||
sql := "INSERT INTO invite_teams (invite_id, team_id, role) VALUES " +
|
||||
strings.Repeat(valueStr, len(i.Teams))
|
||||
sql = strings.TrimSuffix(sql, ",")
|
||||
if _, err := d.db.Exec(sql, args...); err != nil {
|
||||
if _, err := d.writer.Exec(sql, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "insert teams")
|
||||
}
|
||||
|
||||
@ -63,7 +64,7 @@ func (d *Datastore) ListInvites(opt fleet.ListOptions) ([]*fleet.Invite, error)
|
||||
query, params := searchLike(query, nil, opt.MatchQuery, inviteSearchColumns...)
|
||||
query = appendListOptionsToSQL(query, opt)
|
||||
|
||||
err := d.db.Select(&invites, query, params...)
|
||||
err := d.reader.Select(&invites, query, params...)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Invite")
|
||||
} else if err != nil {
|
||||
@ -80,7 +81,7 @@ func (d *Datastore) ListInvites(opt fleet.ListOptions) ([]*fleet.Invite, error)
|
||||
// Invite returns Invite identified by id.
|
||||
func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
|
||||
var invite fleet.Invite
|
||||
err := d.db.Get(&invite, "SELECT * FROM invites WHERE id = ?", id)
|
||||
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE id = ?", id)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Invite").WithID(id)
|
||||
} else if err != nil {
|
||||
@ -97,7 +98,7 @@ func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
|
||||
// InviteByEmail finds an Invite with a particular email, if one exists.
|
||||
func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
|
||||
var invite fleet.Invite
|
||||
err := d.db.Get(&invite, "SELECT * FROM invites WHERE email = ?", email)
|
||||
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE email = ?", email)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Invite").
|
||||
WithMessage(fmt.Sprintf("with email %s", email))
|
||||
@ -115,7 +116,7 @@ func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
|
||||
// InviteByToken finds an Invite with a particular token, if one exists.
|
||||
func (d *Datastore) InviteByToken(token string) (*fleet.Invite, error) {
|
||||
var invite fleet.Invite
|
||||
err := d.db.Get(&invite, "SELECT * FROM invites WHERE token = ?", token)
|
||||
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE token = ?", token)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Invite").
|
||||
WithMessage(fmt.Sprintf("with token %s", token))
|
||||
@ -163,7 +164,7 @@ func (d *Datastore) loadTeamsForInvites(invites []*fleet.Invite) error {
|
||||
fleet.UserTeam
|
||||
InviteID uint `db:"invite_id"`
|
||||
}
|
||||
if err := d.db.Select(&rows, sql, args...); err != nil {
|
||||
if err := d.reader.Select(&rows, sql, args...); err != nil {
|
||||
return errors.Wrap(err, "get loadTeamsForInvites")
|
||||
}
|
||||
|
||||
|
@ -115,7 +115,7 @@ func (d *Datastore) GetLabelSpecs() ([]*fleet.LabelSpec, error) {
|
||||
var specs []*fleet.LabelSpec
|
||||
// Get basic specs
|
||||
query := "SELECT id, name, description, query, platform, label_type, label_membership_type FROM labels"
|
||||
if err := d.db.Select(&specs, query); err != nil {
|
||||
if err := d.reader.Select(&specs, query); err != nil {
|
||||
return nil, errors.Wrap(err, "get labels")
|
||||
}
|
||||
|
||||
@ -138,7 +138,7 @@ SELECT name, description, query, platform, label_type, label_membership_type
|
||||
FROM labels
|
||||
WHERE name = ?
|
||||
`
|
||||
if err := d.db.Select(&specs, query, name); err != nil {
|
||||
if err := d.reader.Select(&specs, query, name); err != nil {
|
||||
return nil, errors.Wrap(err, "get label")
|
||||
}
|
||||
if len(specs) == 0 {
|
||||
@ -171,7 +171,7 @@ func (d *Datastore) getLabelHostnames(label *fleet.LabelSpec) error {
|
||||
WHERE label_id = (SELECT id FROM labels WHERE name = ?)
|
||||
)
|
||||
`
|
||||
err := d.db.Select(&label.Hosts, sql, label.Name)
|
||||
err := d.reader.Select(&label.Hosts, sql, label.Name)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get hostnames for label")
|
||||
}
|
||||
@ -190,7 +190,7 @@ func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fl
|
||||
label_membership_type
|
||||
) VALUES ( ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
result, err := d.db.Exec(
|
||||
result, err := d.writer.Exec(
|
||||
query,
|
||||
label.Name,
|
||||
label.Description,
|
||||
@ -216,7 +216,7 @@ func (d *Datastore) SaveLabel(label *fleet.Label) (*fleet.Label, error) {
|
||||
description = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := d.db.Exec(query, label.Name, label.Description, label.ID)
|
||||
_, err := d.writer.Exec(query, label.Name, label.Description, label.ID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "saving label")
|
||||
}
|
||||
@ -236,7 +236,7 @@ func (d *Datastore) Label(lid uint) (*fleet.Label, error) {
|
||||
`
|
||||
label := &fleet.Label{}
|
||||
|
||||
if err := d.db.Get(label, sql, lid); err != nil {
|
||||
if err := d.reader.Get(label, sql, lid); err != nil {
|
||||
return nil, errors.Wrap(err, "selecting label")
|
||||
}
|
||||
|
||||
@ -255,7 +255,7 @@ func (d *Datastore) ListLabels(filter fleet.TeamFilter, opt fleet.ListOptions) (
|
||||
query = appendListOptionsToSQL(query, opt)
|
||||
labels := []*fleet.Label{}
|
||||
|
||||
if err := d.db.Select(&labels, query); err != nil {
|
||||
if err := d.reader.Select(&labels, query); err != nil {
|
||||
// it's ok if no labels exist
|
||||
if err == sql.ErrNoRows {
|
||||
return labels, nil
|
||||
@ -288,7 +288,7 @@ func (d *Datastore) LabelQueriesForHost(host *fleet.Host, cutoff time.Time) (map
|
||||
WHERE platform = ? OR platform = ''
|
||||
AND label_membership_type = ?
|
||||
`
|
||||
rows, err = d.db.Query(sql, platform, fleet.LabelMembershipTypeDynamic)
|
||||
rows, err = d.reader.Query(sql, platform, fleet.LabelMembershipTypeDynamic)
|
||||
} else {
|
||||
// Retrieve all labels (with matching platform) iff there is a label
|
||||
// that has been created since this host last reported label query
|
||||
@ -300,7 +300,7 @@ func (d *Datastore) LabelQueriesForHost(host *fleet.Host, cutoff time.Time) (map
|
||||
AND (platform = ? OR platform = '')
|
||||
AND label_membership_type = ?
|
||||
`
|
||||
rows, err = d.db.Query(
|
||||
rows, err = d.reader.Query(
|
||||
sql,
|
||||
platform,
|
||||
host.LabelUpdatedAt,
|
||||
@ -369,7 +369,7 @@ func (d *Datastore) RecordLabelQueryExecutions(host *fleet.Host, results map[uin
|
||||
updated_at = VALUES(updated_at)
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(sql, vals...)
|
||||
_, err := d.writer.Exec(sql, vals...)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "insert label query executions (%v)", vals)
|
||||
}
|
||||
@ -384,8 +384,8 @@ func (d *Datastore) RecordLabelQueryExecutions(host *fleet.Host, results map[uin
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "IN for DELETE FROM label_membership")
|
||||
}
|
||||
query = d.db.Rebind(query)
|
||||
_, err = d.db.Exec(query, args...)
|
||||
query = d.writer.Rebind(query)
|
||||
_, err = d.writer.Exec(query, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "delete label query executions")
|
||||
}
|
||||
@ -403,7 +403,7 @@ func (d *Datastore) ListLabelsForHost(hid uint) ([]*fleet.Label, error) {
|
||||
`
|
||||
|
||||
labels := []*fleet.Label{}
|
||||
err := d.db.Select(&labels, sqlStatement, hid)
|
||||
err := d.reader.Select(&labels, sqlStatement, hid)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting host labels")
|
||||
}
|
||||
@ -432,7 +432,7 @@ func (d *Datastore) ListHostsInLabel(filter fleet.TeamFilter, lid uint, opt flee
|
||||
|
||||
sql = appendListOptionsToSQL(sql, opt.ListOptions)
|
||||
hosts := []*fleet.Host{}
|
||||
err := d.db.Select(&hosts, sql, params...)
|
||||
err := d.reader.Select(&hosts, sql, params...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting label query executions")
|
||||
}
|
||||
@ -458,9 +458,9 @@ func (d *Datastore) ListUniqueHostsInLabels(filter fleet.TeamFilter, labels []ui
|
||||
return nil, errors.Wrap(err, "building query listing unique hosts in labels")
|
||||
}
|
||||
|
||||
query = d.db.Rebind(query)
|
||||
query = d.reader.Rebind(query)
|
||||
hosts := []*fleet.Host{}
|
||||
err = d.db.Select(&hosts, query, args...)
|
||||
err = d.reader.Select(&hosts, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "listing unique hosts in labels")
|
||||
}
|
||||
@ -492,10 +492,10 @@ func (d *Datastore) searchLabelsWithOmits(filter fleet.TeamFilter, query string,
|
||||
return nil, errors.Wrap(err, "building query for labels with omits")
|
||||
}
|
||||
|
||||
sql = d.db.Rebind(sql)
|
||||
sql = d.reader.Rebind(sql)
|
||||
|
||||
matches := []*fleet.Label{}
|
||||
err = d.db.Select(&matches, sql, args...)
|
||||
err = d.reader.Select(&matches, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting labels with omits")
|
||||
}
|
||||
@ -527,7 +527,7 @@ func (d *Datastore) addAllHostsLabelToList(filter fleet.TeamFilter, labels []*fl
|
||||
)
|
||||
|
||||
var allHosts fleet.Label
|
||||
if err := d.db.Get(&allHosts, sql, fleet.LabelTypeBuiltIn); err != nil {
|
||||
if err := d.reader.Get(&allHosts, sql, fleet.LabelTypeBuiltIn); err != nil {
|
||||
return nil, errors.Wrap(err, "get all hosts label")
|
||||
}
|
||||
|
||||
@ -575,8 +575,8 @@ func (d *Datastore) searchLabelsDefault(filter fleet.TeamFilter, omit ...uint) (
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching default labels")
|
||||
}
|
||||
sql = d.db.Rebind(sql)
|
||||
if err := d.db.Select(&labels, sql, args...); err != nil {
|
||||
sql = d.reader.Rebind(sql)
|
||||
if err := d.reader.Select(&labels, sql, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "searching default labels rebound")
|
||||
}
|
||||
|
||||
@ -617,7 +617,7 @@ func (d *Datastore) SearchLabels(filter fleet.TeamFilter, query string, omit ...
|
||||
)
|
||||
|
||||
matches := []*fleet.Label{}
|
||||
if err := d.db.Select(&matches, sql, transformedQuery); err != nil {
|
||||
if err := d.reader.Select(&matches, sql, transformedQuery); err != nil {
|
||||
return nil, errors.Wrap(err, "selecting labels for search")
|
||||
}
|
||||
|
||||
@ -645,7 +645,7 @@ func (d *Datastore) LabelIDsByName(labels []string) ([]uint, error) {
|
||||
}
|
||||
|
||||
var labelIDs []uint
|
||||
if err := d.db.Select(&labelIDs, sql, args...); err != nil {
|
||||
if err := d.reader.Select(&labelIDs, sql, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "get label IDs")
|
||||
}
|
||||
|
||||
|
@ -29,27 +29,27 @@ func (d *Datastore) Lock(name string, owner string, expiration time.Duration) (b
|
||||
}
|
||||
|
||||
func (d *Datastore) createLock(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
||||
return d.db.Exec(
|
||||
return d.writer.Exec(
|
||||
`INSERT IGNORE INTO locks (name, owner, expires_at) VALUES (?, ?, ?)`,
|
||||
name, owner, time.Now().Add(expiration),
|
||||
)
|
||||
}
|
||||
|
||||
func (d *Datastore) extendLockIfAlreadyAcquired(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
||||
return d.db.Exec(
|
||||
return d.writer.Exec(
|
||||
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE name = ? and owner = ?`,
|
||||
name, owner, time.Now().Add(expiration), name, owner,
|
||||
)
|
||||
}
|
||||
|
||||
func (d *Datastore) overwriteLockIfExpired(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
||||
return d.db.Exec(
|
||||
return d.writer.Exec(
|
||||
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE expires_at < CURRENT_TIMESTAMP and name = ?`,
|
||||
name, owner, time.Now().Add(expiration), name,
|
||||
)
|
||||
}
|
||||
|
||||
func (d *Datastore) Unlock(name string, owner string) error {
|
||||
_, err := d.db.Exec(`DELETE FROM locks WHERE name = ? and owner = ?`, name, owner)
|
||||
_, err := d.writer.Exec(`DELETE FROM locks WHERE name = ? and owner = ?`, name, owner)
|
||||
return err
|
||||
}
|
||||
|
@ -85,14 +85,14 @@ func Test20210819131107_AddCascadeToHostSoftware(t *testing.T) {
|
||||
defer ds.Close()
|
||||
|
||||
for {
|
||||
version, err := tables.MigrationClient.GetDBVersion(ds.db.DB)
|
||||
version, err := tables.MigrationClient.GetDBVersion(ds.writer.DB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// break right before the the constraint migration
|
||||
if version == 20210818182258 {
|
||||
break
|
||||
}
|
||||
require.NoError(t, tables.MigrationClient.UpByOne(ds.db.DB, ""))
|
||||
require.NoError(t, tables.MigrationClient.UpByOne(ds.writer.DB, ""))
|
||||
}
|
||||
|
||||
host1 := test.NewHost(t, ds, "host1", "", "host1key", "host1uuid", time.Now())
|
||||
@ -122,7 +122,7 @@ func Test20210819131107_AddCascadeToHostSoftware(t *testing.T) {
|
||||
|
||||
require.NoError(t, ds.DeleteHost(host1.ID))
|
||||
|
||||
require.NoError(t, tables.MigrationClient.UpByOne(ds.db.DB, ""))
|
||||
require.NoError(t, tables.MigrationClient.UpByOne(ds.writer.DB, ""))
|
||||
|
||||
// Make sure we don't delete more than we need
|
||||
hostCheck, err := ds.Host(host2.ID)
|
||||
|
@ -37,13 +37,28 @@ var (
|
||||
columnCharsRegexp = regexp.MustCompile(`[^\w-]`)
|
||||
)
|
||||
|
||||
// dbReader is an interface that defines the methods required for reads.
|
||||
type dbReader interface {
|
||||
sqlx.Queryer
|
||||
|
||||
Close() error
|
||||
Rebind(string) string
|
||||
Select(interface{}, string, ...interface{}) error
|
||||
Get(interface{}, string, ...interface{}) error
|
||||
}
|
||||
|
||||
// Datastore is an implementation of fleet.Datastore interface backed by
|
||||
// MySQL
|
||||
type Datastore struct {
|
||||
db *sqlx.DB
|
||||
reader dbReader // so it cannot be used to perform writes
|
||||
writer *sqlx.DB
|
||||
|
||||
logger log.Logger
|
||||
clock clock.Clock
|
||||
config config.MysqlConfig
|
||||
|
||||
// nil if no read replica
|
||||
readReplicaConfig *config.MysqlConfig
|
||||
}
|
||||
|
||||
type txFn func(*sqlx.Tx) error
|
||||
@ -67,7 +82,7 @@ func retryableError(err error) bool {
|
||||
// withRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff
|
||||
func (d *Datastore) withRetryTxx(fn txFn) (err error) {
|
||||
operation := func() error {
|
||||
tx, err := d.db.Beginx()
|
||||
tx, err := d.writer.Beginx()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create transaction")
|
||||
}
|
||||
@ -116,7 +131,7 @@ func (d *Datastore) withRetryTxx(fn txFn) (err error) {
|
||||
|
||||
// withTx provides a common way to commit/rollback a txFn
|
||||
func (d *Datastore) withTx(fn txFn) (err error) {
|
||||
tx, err := d.db.Beginx()
|
||||
tx, err := d.writer.Beginx()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create transaction")
|
||||
}
|
||||
@ -153,51 +168,64 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
|
||||
}
|
||||
|
||||
for _, setOpt := range opts {
|
||||
if setOpt != nil {
|
||||
setOpt(options)
|
||||
}
|
||||
|
||||
if config.PasswordPath != "" && config.Password != "" {
|
||||
return nil, errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
|
||||
}
|
||||
|
||||
// Check to see if the flag is populated
|
||||
// Check if file exists on disk
|
||||
// If file exists read contents
|
||||
if config.PasswordPath != "" {
|
||||
fileContents, err := ioutil.ReadFile(config.PasswordPath)
|
||||
if err := checkConfig(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if options.replicaConfig != nil {
|
||||
if err := checkConfig(options.replicaConfig); err != nil {
|
||||
return nil, errors.Wrap(err, "replica")
|
||||
}
|
||||
}
|
||||
|
||||
dbWriter, err := newDB(&config, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.Password = strings.TrimSpace(string(fileContents))
|
||||
}
|
||||
|
||||
if config.TLSCA != "" {
|
||||
config.TLSConfig = "custom"
|
||||
err := registerTLS(config)
|
||||
dbReader := dbWriter
|
||||
if options.replicaConfig != nil {
|
||||
dbReader, err = newDB(options.replicaConfig, options)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "register TLS config for mysql")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
dsn := generateMysqlConnectionString(config)
|
||||
ds := &Datastore{
|
||||
writer: dbWriter,
|
||||
reader: dbReader,
|
||||
logger: options.logger,
|
||||
clock: c,
|
||||
config: config,
|
||||
readReplicaConfig: options.replicaConfig,
|
||||
}
|
||||
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
func newDB(conf *config.MysqlConfig, opts *dbOptions) (*sqlx.DB, error) {
|
||||
dsn := generateMysqlConnectionString(*conf)
|
||||
db, err := sqlx.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.SetMaxIdleConns(config.MaxIdleConns)
|
||||
db.SetMaxOpenConns(config.MaxOpenConns)
|
||||
db.SetConnMaxLifetime(time.Second * time.Duration(config.ConnMaxLifetime))
|
||||
db.SetMaxIdleConns(conf.MaxIdleConns)
|
||||
db.SetMaxOpenConns(conf.MaxOpenConns)
|
||||
db.SetConnMaxLifetime(time.Second * time.Duration(conf.ConnMaxLifetime))
|
||||
|
||||
var dbError error
|
||||
for attempt := 0; attempt < options.maxAttempts; attempt++ {
|
||||
for attempt := 0; attempt < opts.maxAttempts; attempt++ {
|
||||
dbError = db.Ping()
|
||||
if dbError == nil {
|
||||
// we're connected!
|
||||
break
|
||||
}
|
||||
interval := time.Duration(attempt) * time.Second
|
||||
options.logger.Log("mysql", fmt.Sprintf(
|
||||
opts.logger.Log("mysql", fmt.Sprintf(
|
||||
"could not connect to db: %v, sleeping %v", dbError, interval))
|
||||
time.Sleep(interval)
|
||||
}
|
||||
@ -205,20 +233,37 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
|
||||
if dbError != nil {
|
||||
return nil, dbError
|
||||
}
|
||||
|
||||
ds := &Datastore{
|
||||
db: db,
|
||||
logger: options.logger,
|
||||
clock: c,
|
||||
config: config,
|
||||
return db, nil
|
||||
}
|
||||
|
||||
return ds, nil
|
||||
func checkConfig(conf *config.MysqlConfig) error {
|
||||
if conf.PasswordPath != "" && conf.Password != "" {
|
||||
return errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
|
||||
}
|
||||
|
||||
// Check to see if the flag is populated
|
||||
// Check if file exists on disk
|
||||
// If file exists read contents
|
||||
if conf.PasswordPath != "" {
|
||||
fileContents, err := ioutil.ReadFile(conf.PasswordPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conf.Password = strings.TrimSpace(string(fileContents))
|
||||
}
|
||||
|
||||
if conf.TLSCA != "" {
|
||||
conf.TLSConfig = "custom"
|
||||
err := registerTLS(*conf)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "register TLS config for mysql")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Datastore) Begin() (fleet.Transaction, error) {
|
||||
return d.db.Beginx()
|
||||
return d.writer.Beginx()
|
||||
}
|
||||
|
||||
func (d *Datastore) Name() string {
|
||||
@ -226,11 +271,11 @@ func (d *Datastore) Name() string {
|
||||
}
|
||||
|
||||
func (d *Datastore) MigrateTables() error {
|
||||
return tables.MigrationClient.Up(d.db.DB, "")
|
||||
return tables.MigrationClient.Up(d.writer.DB, "")
|
||||
}
|
||||
|
||||
func (d *Datastore) MigrateData() error {
|
||||
return data.MigrationClient.Up(d.db.DB, "")
|
||||
return data.MigrationClient.Up(d.writer.DB, "")
|
||||
}
|
||||
|
||||
func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
|
||||
@ -243,7 +288,7 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
|
||||
return 0, errors.Wrap(err, "missing tables migrations")
|
||||
}
|
||||
|
||||
currentTablesVersion, err := tables.MigrationClient.GetDBVersion(d.db.DB)
|
||||
currentTablesVersion, err := tables.MigrationClient.GetDBVersion(d.writer.DB)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "cannot get table migration status")
|
||||
}
|
||||
@ -253,7 +298,7 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
|
||||
return 0, errors.Wrap(err, "missing data migrations")
|
||||
}
|
||||
|
||||
currentDataVersion, err := data.MigrationClient.GetDBVersion(d.db.DB)
|
||||
currentDataVersion, err := data.MigrationClient.GetDBVersion(d.writer.DB)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "cannot get data migration status")
|
||||
}
|
||||
@ -283,11 +328,11 @@ func (d *Datastore) Drop() error {
|
||||
WHERE TABLE_SCHEMA = ?;
|
||||
`
|
||||
|
||||
if err := d.db.Select(&tables, sql, d.config.Database); err != nil {
|
||||
if err := d.writer.Select(&tables, sql, d.config.Database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := d.db.Begin()
|
||||
tx, err := d.writer.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -312,13 +357,28 @@ func (d *Datastore) Drop() error {
|
||||
|
||||
// HealthCheck returns an error if the MySQL backend is not healthy.
|
||||
func (d *Datastore) HealthCheck() error {
|
||||
_, err := d.db.Exec("select 1")
|
||||
if _, err := d.writer.Exec("select 1"); err != nil {
|
||||
return err
|
||||
}
|
||||
if d.readReplicaConfig != nil {
|
||||
var dst int
|
||||
if err := d.reader.Get(&dst, "select 1"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close frees resources associated with underlying mysql connection
|
||||
func (d *Datastore) Close() error {
|
||||
return d.db.Close()
|
||||
err := d.writer.Close()
|
||||
if d.readReplicaConfig != nil {
|
||||
errRead := d.reader.Close()
|
||||
if err == nil {
|
||||
err = errRead
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sanitizeColumn(col string) string {
|
||||
|
@ -1,7 +1,9 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/VividCortex/mysqlerr"
|
||||
@ -15,6 +17,51 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDatastoreReplica(t *testing.T) {
|
||||
// a bit unfortunate to create temp databases just for this - could be mixed
|
||||
// with other tests when/if we move to subtests to minimize the number of
|
||||
// databases created for tests (see #1805).
|
||||
|
||||
t.Run("noreplica", func(t *testing.T) {
|
||||
ds := CreateMySQLDSWithOptions(t, nil)
|
||||
defer ds.Close()
|
||||
require.Equal(t, ds.reader, ds.writer)
|
||||
})
|
||||
|
||||
t.Run("replica", func(t *testing.T) {
|
||||
opts := &DatastoreTestOptions{Replica: true}
|
||||
ds := CreateMySQLDSWithOptions(t, opts)
|
||||
defer ds.Close()
|
||||
require.NotEqual(t, ds.reader, ds.writer)
|
||||
|
||||
// create a new host
|
||||
host, err := ds.NewHost(&fleet.Host{
|
||||
DetailUpdatedAt: time.Now(),
|
||||
LabelUpdatedAt: time.Now(),
|
||||
SeenTime: time.Now(),
|
||||
NodeKey: "1",
|
||||
UUID: "1",
|
||||
Hostname: "foo.local",
|
||||
PrimaryIP: "192.168.1.1",
|
||||
PrimaryMac: "30-65-EC-6F-C4-58",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, host)
|
||||
|
||||
// trying to read it fails, not replicated yet
|
||||
_, err = ds.Host(host.ID)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, sql.ErrNoRows))
|
||||
|
||||
opts.RunReplication()
|
||||
|
||||
// now it can read it
|
||||
host2, err := ds.Host(host.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, host.ID, host2.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizeColumn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -121,8 +168,10 @@ func TestSearchLike(t *testing.T) {
|
||||
func mockDatastore(t *testing.T) (sqlmock.Sqlmock, *Datastore) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
dbmock := sqlx.NewDb(db, "sqlmock")
|
||||
ds := &Datastore{
|
||||
db: sqlx.NewDb(db, "sqlmock"),
|
||||
writer: dbmock,
|
||||
reader: dbmock,
|
||||
logger: log.NewNopLogger(),
|
||||
}
|
||||
|
||||
|
@ -205,7 +205,7 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
|
||||
WHERE name = ?
|
||||
`
|
||||
var pack fleet.Pack
|
||||
err := d.db.Get(&pack, sqlStatement, name)
|
||||
err := d.reader.Get(&pack, sqlStatement, name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, false, nil
|
||||
@ -213,7 +213,7 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
|
||||
return nil, false, errors.Wrap(err, "fetch pack by name")
|
||||
}
|
||||
|
||||
if err := d.loadPackTargets(&pack); err != nil {
|
||||
if err := d.loadPackTargets(d.reader, &pack); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
@ -313,10 +313,10 @@ func (d *Datastore) replacePackTargets(tx *sqlx.Tx, pack *fleet.Pack) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Datastore) loadPackTargets(pack *fleet.Pack) error {
|
||||
func (d *Datastore) loadPackTargets(db dbReader, pack *fleet.Pack) error {
|
||||
var targets []fleet.PackTarget
|
||||
sql := `SELECT * FROM pack_targets WHERE pack_id = ?`
|
||||
if err := d.db.Select(&targets, sql, pack.ID); err != nil {
|
||||
if err := db.Select(&targets, sql, pack.ID); err != nil {
|
||||
return errors.Wrap(err, "select pack targets")
|
||||
}
|
||||
|
||||
@ -371,14 +371,14 @@ func (d *Datastore) DeletePack(name string) error {
|
||||
func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
|
||||
query := `SELECT * FROM packs WHERE id = ?`
|
||||
pack := &fleet.Pack{}
|
||||
err := d.db.Get(pack, query, pid)
|
||||
err := d.reader.Get(pack, query, pid)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Pack").WithID(pid)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrap(err, "get pack")
|
||||
}
|
||||
|
||||
if err := d.loadPackTargets(pack); err != nil {
|
||||
if err := d.loadPackTargets(d.reader, pack); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -388,14 +388,15 @@ func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
|
||||
// EnsureGlobalPack gets or inserts a pack with type global
|
||||
func (d *Datastore) EnsureGlobalPack() (*fleet.Pack, error) {
|
||||
pack := &fleet.Pack{}
|
||||
err := d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = 'global'`)
|
||||
// read from primary as we will create the pack if it doesn't exist
|
||||
err := d.writer.Get(pack, `SELECT * FROM packs WHERE pack_type = 'global'`)
|
||||
if err == sql.ErrNoRows {
|
||||
return d.insertNewGlobalPack()
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrap(err, "get pack")
|
||||
}
|
||||
|
||||
if err := d.loadPackTargets(pack); err != nil {
|
||||
if err := d.loadPackTargets(d.writer, pack); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -439,14 +440,15 @@ func (d *Datastore) EnsureTeamPack(teamID uint) (*fleet.Pack, error) {
|
||||
}
|
||||
|
||||
teamType := fmt.Sprintf("team-%d", teamID)
|
||||
err = d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType)
|
||||
// read from primary as we will create the team pack if it doesn't exist
|
||||
err = d.writer.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType)
|
||||
if err == sql.ErrNoRows {
|
||||
return d.insertNewTeamPack(t)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrap(err, "get pack")
|
||||
}
|
||||
|
||||
if err := d.loadPackTargets(pack); err != nil {
|
||||
if err := d.loadPackTargets(d.writer, pack); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -499,13 +501,13 @@ func (d *Datastore) ListPacks(opt fleet.PackListOptions) ([]*fleet.Pack, error)
|
||||
query = `SELECT * FROM packs`
|
||||
}
|
||||
var packs []*fleet.Pack
|
||||
err := d.db.Select(&packs, appendListOptionsToSQL(query, opt.ListOptions))
|
||||
err := d.reader.Select(&packs, appendListOptionsToSQL(query, opt.ListOptions))
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "listing packs")
|
||||
}
|
||||
|
||||
for _, pack := range packs {
|
||||
if err := d.loadPackTargets(pack); err != nil {
|
||||
if err := d.loadPackTargets(d.reader, pack); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@ -541,7 +543,7 @@ func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) {
|
||||
`
|
||||
|
||||
packs := []*fleet.Pack{}
|
||||
if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
|
||||
if err := d.reader.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "listing hosts in pack")
|
||||
}
|
||||
return packs, nil
|
||||
|
@ -524,7 +524,7 @@ func TestTeamScheduleNamesMigrateToNewFormat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// insert team pack by hand with the old naming scheme
|
||||
_, err = ds.db.Exec(
|
||||
_, err = ds.writer.Exec(
|
||||
"INSERT INTO packs(name, description, platform, disabled, pack_type) VALUES (?, ?, ?, ?, ?)",
|
||||
teamSchedulePackType(team1), "desc", "windows", false, teamSchedulePackType(team1),
|
||||
)
|
||||
|
@ -11,7 +11,7 @@ func (d *Datastore) NewPasswordResetRequest(req *fleet.PasswordResetRequest) (*f
|
||||
( user_id, token, expires_at)
|
||||
VALUES (?,?, NOW())
|
||||
`
|
||||
response, err := d.db.Exec(sqlStatement, req.UserID, req.Token)
|
||||
response, err := d.writer.Exec(sqlStatement, req.UserID, req.Token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "inserting password reset requests")
|
||||
}
|
||||
@ -30,7 +30,7 @@ func (d *Datastore) SavePasswordResetRequest(req *fleet.PasswordResetRequest) er
|
||||
token = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID)
|
||||
result, err := d.writer.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "updating password reset requests")
|
||||
}
|
||||
@ -58,7 +58,7 @@ func (d *Datastore) DeletePasswordResetRequestsForUser(userID uint) error {
|
||||
sqlStatement := `
|
||||
DELETE FROM password_reset_requests WHERE user_id = ?
|
||||
`
|
||||
_, err := d.db.Exec(sqlStatement, userID)
|
||||
_, err := d.writer.Exec(sqlStatement, userID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "deleting password reset request by user")
|
||||
}
|
||||
@ -72,7 +72,7 @@ func (d *Datastore) FindPassswordResetByID(id uint) (*fleet.PasswordResetRequest
|
||||
WHERE id = ? LIMIT 1
|
||||
`
|
||||
passwordResetRequest := &fleet.PasswordResetRequest{}
|
||||
err := d.db.Get(&passwordResetRequest, sqlStatement, id)
|
||||
err := d.reader.Get(&passwordResetRequest, sqlStatement, id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting password reset by id")
|
||||
}
|
||||
@ -87,7 +87,7 @@ func (d *Datastore) FindPassswordResetsByUserID(id uint) ([]*fleet.PasswordReset
|
||||
`
|
||||
|
||||
passwordResetRequests := []*fleet.PasswordResetRequest{}
|
||||
err := d.db.Select(&passwordResetRequests, sqlStatement, id)
|
||||
err := d.reader.Select(&passwordResetRequests, sqlStatement, id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "finding password resets by user id")
|
||||
}
|
||||
@ -102,7 +102,7 @@ func (d *Datastore) FindPassswordResetByToken(token string) (*fleet.PasswordRese
|
||||
WHERE token = ? LIMIT 1
|
||||
`
|
||||
passwordResetRequest := &fleet.PasswordResetRequest{}
|
||||
err := d.db.Get(passwordResetRequest, sqlStatement, token)
|
||||
err := d.reader.Get(passwordResetRequest, sqlStatement, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting password reset requests")
|
||||
}
|
||||
@ -118,7 +118,7 @@ func (d *Datastore) FindPassswordResetByTokenAndUserID(token string, id uint) (*
|
||||
LIMIT 1
|
||||
`
|
||||
passwordResetRequest := &fleet.PasswordResetRequest{}
|
||||
err := d.db.Get(passwordResetRequest, sqlStatement, id, token)
|
||||
err := d.reader.Get(passwordResetRequest, sqlStatement, id, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting password reset by token and user id")
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
|
||||
res, err := ds.db.Exec(`INSERT INTO policies (query_id) VALUES (?)`, queryID)
|
||||
res, err := ds.writer.Exec(`INSERT INTO policies (query_id) VALUES (?)`, queryID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "inserting new policy")
|
||||
}
|
||||
@ -26,7 +26,7 @@ func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
|
||||
|
||||
func (ds *Datastore) Policy(id uint) (*fleet.Policy, error) {
|
||||
var policy fleet.Policy
|
||||
err := ds.db.Get(
|
||||
err := ds.reader.Get(
|
||||
&policy,
|
||||
`SELECT
|
||||
p.*,
|
||||
@ -66,7 +66,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(host *fleet.Host, results map[u
|
||||
strings.Join(bindvars, ","),
|
||||
)
|
||||
|
||||
_, err := ds.db.Exec(query, vals...)
|
||||
_, err := ds.writer.Exec(query, vals...)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "insert policy_membership (%v)", vals)
|
||||
}
|
||||
@ -76,7 +76,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(host *fleet.Host, results map[u
|
||||
|
||||
func (ds *Datastore) ListGlobalPolicies() ([]*fleet.Policy, error) {
|
||||
var policies []*fleet.Policy
|
||||
err := ds.db.Select(
|
||||
err := ds.reader.Select(
|
||||
&policies,
|
||||
`SELECT
|
||||
p.*,
|
||||
@ -97,8 +97,8 @@ func (ds *Datastore) DeleteGlobalPolicies(ids []uint) ([]uint, error) {
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "IN for DELETE FROM policies")
|
||||
}
|
||||
stmt = ds.db.Rebind(stmt)
|
||||
if _, err := ds.db.Exec(stmt, args...); err != nil {
|
||||
stmt = ds.writer.Rebind(stmt)
|
||||
if _, err := ds.writer.Exec(stmt, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "delete policies")
|
||||
}
|
||||
return ids, nil
|
||||
@ -109,7 +109,7 @@ func (ds *Datastore) PolicyQueriesForHost(_ *fleet.Host) (map[string]string, err
|
||||
Id string `db:"id"`
|
||||
Query string `db:"query"`
|
||||
}
|
||||
err := ds.db.Select(&rows, `SELECT p.id, q.query FROM policies p JOIN queries q ON (p.query_id=q.id)`)
|
||||
err := ds.reader.Select(&rows, `SELECT p.id, q.query FROM policies p JOIN queries q ON (p.query_id=q.id)`)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting policies for host")
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func (d *Datastore) ApplyQueries(authorID uint, queries []*fleet.Query) (err error) {
|
||||
tx, err := d.db.Begin()
|
||||
tx, err := d.writer.Begin()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "begin ApplyQueries transaction")
|
||||
}
|
||||
@ -72,7 +72,7 @@ func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet.
|
||||
WHERE name = ?
|
||||
`
|
||||
var query fleet.Query
|
||||
err := d.db.Get(&query, sqlStatement, name)
|
||||
err := d.reader.Get(&query, sqlStatement, name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Query").WithName(name)
|
||||
@ -99,7 +99,7 @@ func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fl
|
||||
observer_can_run
|
||||
) VALUES ( ?, ?, ?, ?, ?, ? )
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
|
||||
result, err := d.writer.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
|
||||
|
||||
if err != nil && isDuplicate(err) {
|
||||
return nil, alreadyExists("Query", 0)
|
||||
@ -120,7 +120,7 @@ func (d *Datastore) SaveQuery(q *fleet.Query) error {
|
||||
SET name = ?, description = ?, query = ?, author_id = ?, saved = ?, observer_can_run = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := d.db.Exec(sql, q.Name, q.Description, q.Query, q.AuthorID, q.Saved, q.ObserverCanRun, q.ID)
|
||||
result, err := d.writer.Exec(sql, q.Name, q.Description, q.Query, q.AuthorID, q.Saved, q.ObserverCanRun, q.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "updating query")
|
||||
}
|
||||
@ -156,7 +156,7 @@ func (d *Datastore) Query(id uint) (*fleet.Query, error) {
|
||||
WHERE q.id = ?
|
||||
`
|
||||
query := &fleet.Query{}
|
||||
if err := d.db.Get(query, sql, id); err != nil {
|
||||
if err := d.reader.Get(query, sql, id); err != nil {
|
||||
return nil, errors.Wrap(err, "selecting query")
|
||||
}
|
||||
|
||||
@ -180,7 +180,7 @@ func (d *Datastore) ListQueries(opt fleet.ListOptions) ([]*fleet.Query, error) {
|
||||
sql = appendListOptionsToSQL(sql, opt)
|
||||
results := []*fleet.Query{}
|
||||
|
||||
if err := d.db.Select(&results, sql); err != nil {
|
||||
if err := d.reader.Select(&results, sql); err != nil {
|
||||
return nil, errors.Wrap(err, "listing queries")
|
||||
}
|
||||
|
||||
@ -225,7 +225,7 @@ func (d *Datastore) loadPacksForQueries(queries []*fleet.Query) error {
|
||||
fleet.Pack
|
||||
}{}
|
||||
|
||||
err = d.db.Select(&rows, query, args...)
|
||||
err = d.reader.Select(&rows, query, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "selecting load packs for queries")
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ func (d *Datastore) ListScheduledQueriesInPack(id uint, opts fleet.ListOptions)
|
||||
query = appendListOptionsToSQL(query, opts)
|
||||
results := []*fleet.ScheduledQuery{}
|
||||
|
||||
if err := d.db.Select(&results, query, id); err != nil {
|
||||
if err := d.reader.Select(&results, query, id); err != nil {
|
||||
return nil, errors.Wrap(err, "listing scheduled queries")
|
||||
}
|
||||
|
||||
@ -46,8 +46,8 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op
|
||||
}
|
||||
|
||||
func (d *Datastore) insertScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
|
||||
selectFunc := d.db.Select
|
||||
execFunc := d.db.Exec
|
||||
selectFunc := d.writer.Select
|
||||
execFunc := d.writer.Exec
|
||||
if tx != nil {
|
||||
selectFunc = tx.Select
|
||||
execFunc = tx.Exec
|
||||
@ -108,7 +108,7 @@ func (d *Datastore) SaveScheduledQuery(sq *fleet.ScheduledQuery) (*fleet.Schedul
|
||||
}
|
||||
|
||||
func (d *Datastore) saveScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
|
||||
updateFunc := d.db.Exec
|
||||
updateFunc := d.writer.Exec
|
||||
if tx != nil {
|
||||
updateFunc = tx.Exec
|
||||
}
|
||||
@ -160,7 +160,7 @@ func (d *Datastore) ScheduledQuery(id uint) (*fleet.ScheduledQuery, error) {
|
||||
WHERE sq.id = ?
|
||||
`
|
||||
sq := &fleet.ScheduledQuery{}
|
||||
if err := d.db.Get(sq, query, id); err != nil {
|
||||
if err := d.reader.Get(sq, query, id); err != nil {
|
||||
return nil, errors.Wrap(err, "select scheduled query")
|
||||
}
|
||||
|
||||
@ -168,11 +168,11 @@ func (d *Datastore) ScheduledQuery(id uint) (*fleet.ScheduledQuery, error) {
|
||||
}
|
||||
|
||||
func (d *Datastore) CleanupOrphanScheduledQueryStats() error {
|
||||
_, err := d.db.Exec(`DELETE FROM scheduled_query_stats where scheduled_query_id not in (select id from scheduled_queries where id=scheduled_query_id)`)
|
||||
_, err := d.writer.Exec(`DELETE FROM scheduled_query_stats where scheduled_query_id not in (select id from scheduled_queries where id=scheduled_query_id)`)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by scheduled_query")
|
||||
}
|
||||
_, err = d.db.Exec(`DELETE FROM scheduled_query_stats where host_id not in (select id from hosts where id=host_id)`)
|
||||
_, err = d.writer.Exec(`DELETE FROM scheduled_query_stats where host_id not in (select id from hosts where id=host_id)`)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by host")
|
||||
}
|
||||
|
@ -209,7 +209,7 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
|
||||
test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "1")
|
||||
sq1 := test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "2")
|
||||
|
||||
_, err := ds.db.Exec(`INSERT INTO scheduled_query_stats (
|
||||
_, err := ds.writer.Exec(`INSERT INTO scheduled_query_stats (
|
||||
host_id, scheduled_query_id, average_memory, denylisted,
|
||||
executions, schedule_interval, output_size, system_time,
|
||||
user_time, wall_time
|
||||
@ -224,12 +224,12 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
|
||||
require.Len(t, h1.PackStats, 1)
|
||||
|
||||
// now we insert a bogus stat
|
||||
_, err = ds.db.Exec(`INSERT INTO scheduled_query_stats (
|
||||
_, err = ds.writer.Exec(`INSERT INTO scheduled_query_stats (
|
||||
host_id, scheduled_query_id, average_memory, denylisted, executions
|
||||
) VALUES (?, 999, 32, false, 2);`, h1.ID)
|
||||
require.NoError(t, err)
|
||||
// and also for an unknown host
|
||||
_, err = ds.db.Exec(`INSERT INTO scheduled_query_stats (
|
||||
_, err = ds.writer.Exec(`INSERT INTO scheduled_query_stats (
|
||||
host_id, scheduled_query_id, average_memory, denylisted, executions
|
||||
) VALUES (888, 999, 32, true, 4);`)
|
||||
require.NoError(t, err)
|
||||
@ -241,14 +241,14 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
|
||||
|
||||
// but there are definitely there
|
||||
var count int
|
||||
err = ds.db.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
||||
err = ds.writer.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, count)
|
||||
|
||||
// now we clean it up
|
||||
require.NoError(t, ds.CleanupOrphanScheduledQueryStats())
|
||||
|
||||
err = ds.db.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
||||
err = ds.writer.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
|
||||
|
@ -11,7 +11,7 @@ func (d *Datastore) SessionByKey(key string) (*fleet.Session, error) {
|
||||
WHERE ` + "`key`" + ` = ? LIMIT 1
|
||||
`
|
||||
session := &fleet.Session{}
|
||||
err := d.db.Get(session, sqlStatement, key)
|
||||
err := d.reader.Get(session, sqlStatement, key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting sessions")
|
||||
}
|
||||
@ -26,7 +26,7 @@ func (d *Datastore) SessionByID(id uint) (*fleet.Session, error) {
|
||||
LIMIT 1
|
||||
`
|
||||
session := &fleet.Session{}
|
||||
err := d.db.Get(session, sqlStatement, id)
|
||||
err := d.reader.Get(session, sqlStatement, id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting session by id")
|
||||
}
|
||||
@ -40,7 +40,7 @@ func (d *Datastore) ListSessionsForUser(id uint) ([]*fleet.Session, error) {
|
||||
WHERE user_id = ?
|
||||
`
|
||||
sessions := []*fleet.Session{}
|
||||
err := d.db.Select(&sessions, sqlStatement, id)
|
||||
err := d.reader.Select(&sessions, sqlStatement, id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting sessions for user")
|
||||
}
|
||||
@ -57,7 +57,7 @@ func (d *Datastore) NewSession(session *fleet.Session) (*fleet.Session, error) {
|
||||
)
|
||||
VALUES(?,?)
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, session.UserID, session.Key)
|
||||
result, err := d.writer.Exec(sqlStatement, session.UserID, session.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "inserting session")
|
||||
}
|
||||
@ -80,7 +80,7 @@ func (d *Datastore) DestroyAllSessionsForUser(id uint) error {
|
||||
sqlStatement := `
|
||||
DELETE FROM sessions WHERE user_id = ?
|
||||
`
|
||||
_, err := d.db.Exec(sqlStatement, id)
|
||||
_, err := d.writer.Exec(sqlStatement, id)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "deleting sessions for user")
|
||||
}
|
||||
@ -94,7 +94,7 @@ func (d *Datastore) MarkSessionAccessed(session *fleet.Session) error {
|
||||
accessed_at = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
results, err := d.db.Exec(sqlStatement, d.clock.Now(), session.ID)
|
||||
results, err := d.writer.Exec(sqlStatement, d.clock.Now(), session.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "updating mark session as accessed")
|
||||
}
|
||||
|
@ -206,7 +206,7 @@ func (d *Datastore) insertNewInstalledHostSoftware(
|
||||
}
|
||||
|
||||
func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Software, error) {
|
||||
selectFunc := d.db.Select
|
||||
selectFunc := d.reader.Select
|
||||
if tx != nil {
|
||||
selectFunc = tx.Select
|
||||
}
|
||||
@ -231,7 +231,7 @@ func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Softwa
|
||||
WHERE s.id IN
|
||||
(SELECT software_id FROM host_software WHERE host_id = ?)
|
||||
`
|
||||
queryFunc := d.db.Queryx
|
||||
queryFunc := d.reader.Queryx
|
||||
if tx != nil {
|
||||
queryFunc = tx.Queryx
|
||||
}
|
||||
@ -305,7 +305,7 @@ func (d *Datastore) AllSoftwareWithoutCPEIterator() (fleet.SoftwareIterator, err
|
||||
sql := `SELECT s.* FROM software s LEFT JOIN software_cpe sc on (s.id=sc.software_id) WHERE sc.id is null`
|
||||
// The rows.Close call is done by the caller once iteration using the
|
||||
// returned fleet.SoftwareIterator is done.
|
||||
rows, err := d.db.Queryx(sql) //nolint:sqlclosecheck
|
||||
rows, err := d.reader.Queryx(sql) //nolint:sqlclosecheck
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "load host software")
|
||||
}
|
||||
@ -314,7 +314,7 @@ func (d *Datastore) AllSoftwareWithoutCPEIterator() (fleet.SoftwareIterator, err
|
||||
|
||||
func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error {
|
||||
sql := `INSERT INTO software_cpe (software_id, cpe) VALUES (?, ?)`
|
||||
if _, err := d.db.Exec(sql, software.ID, cpe); err != nil {
|
||||
if _, err := d.writer.Exec(sql, software.ID, cpe); err != nil {
|
||||
return errors.Wrap(err, "insert software cpe")
|
||||
}
|
||||
return nil
|
||||
@ -323,7 +323,7 @@ func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error
|
||||
func (d *Datastore) AllCPEs() ([]string, error) {
|
||||
sql := `SELECT cpe FROM software_cpe`
|
||||
var cpes []string
|
||||
err := d.db.Select(&cpes, sql)
|
||||
err := d.reader.Select(&cpes, sql)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "loads cpes")
|
||||
}
|
||||
@ -337,7 +337,7 @@ func (d *Datastore) InsertCVEForCPE(cve string, cpes []string) error {
|
||||
for _, cpe := range cpes {
|
||||
args = append(args, cpe, cve)
|
||||
}
|
||||
_, err := d.db.Exec(sql, args...)
|
||||
_, err := d.writer.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "insert software cve")
|
||||
}
|
||||
|
@ -207,7 +207,7 @@ func TestHostSoftwareDuplicates(t *testing.T) {
|
||||
})
|
||||
incoming[soft2Key] = true
|
||||
|
||||
tx, err := ds.db.Beginx()
|
||||
tx, err := ds.writer.Beginx()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
|
||||
require.NoError(t, tx.Commit())
|
||||
@ -220,7 +220,7 @@ func TestHostSoftwareDuplicates(t *testing.T) {
|
||||
})
|
||||
incoming[soft3Key] = true
|
||||
|
||||
tx, err = ds.db.Beginx()
|
||||
tx, err = ds.writer.Beginx()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
|
||||
require.NoError(t, tx.Commit())
|
||||
|
@ -21,14 +21,14 @@ func (d *Datastore) ShouldSendStatistics(frequency time.Duration) (fleet.Statist
|
||||
}
|
||||
|
||||
dest := statistics{}
|
||||
err = d.db.Get(&dest, `SELECT created_at, updated_at, anonymous_identifier FROM statistics LIMIT 1`)
|
||||
err = d.writer.Get(&dest, `SELECT created_at, updated_at, anonymous_identifier FROM statistics LIMIT 1`)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
anonIdentifier, err := server.GenerateRandomText(64)
|
||||
if err != nil {
|
||||
return fleet.StatisticsPayload{}, false, err
|
||||
}
|
||||
_, err = d.db.Exec(`INSERT INTO statistics(anonymous_identifier) VALUES (?)`, anonIdentifier)
|
||||
_, err = d.writer.Exec(`INSERT INTO statistics(anonymous_identifier) VALUES (?)`, anonIdentifier)
|
||||
if err != nil {
|
||||
return fleet.StatisticsPayload{}, false, err
|
||||
}
|
||||
@ -55,6 +55,6 @@ func (d *Datastore) ShouldSendStatistics(frequency time.Duration) (fleet.Statist
|
||||
}
|
||||
|
||||
func (d *Datastore) RecordStatisticsSent() error {
|
||||
_, err := d.db.Exec(`UPDATE statistics SET updated_at = CURRENT_TIMESTAMP LIMIT 1`)
|
||||
_, err := d.writer.Exec(`UPDATE statistics SET updated_at = CURRENT_TIMESTAMP LIMIT 1`)
|
||||
return err
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ func (d *Datastore) CountHostsInTargets(filter fleet.TeamFilter, targets fleet.H
|
||||
}
|
||||
|
||||
res := fleet.TargetMetrics{}
|
||||
err = d.db.Get(&res, query, args...)
|
||||
err = d.reader.Get(&res, query, args...)
|
||||
if err != nil {
|
||||
return fleet.TargetMetrics{}, errors.Wrap(err, "sqlx.Get CountHostsInTargets")
|
||||
}
|
||||
@ -98,7 +98,7 @@ func (d *Datastore) HostIDsInTargets(filter fleet.TeamFilter, targets fleet.Host
|
||||
}
|
||||
|
||||
var res []uint
|
||||
err = d.db.Select(&res, query, args...)
|
||||
err = d.reader.Select(&res, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sqlx.Get HostIDsInTargets")
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ func (d *Datastore) NewTeam(team *fleet.Team) (*fleet.Team, error) {
|
||||
description
|
||||
) VALUES ( ?, ?, ? )
|
||||
`
|
||||
result, err := d.db.Exec(
|
||||
result, err := d.writer.Exec(
|
||||
query,
|
||||
team.Name,
|
||||
team.AgentOptions,
|
||||
@ -47,7 +47,7 @@ func (d *Datastore) Team(tid uint) (*fleet.Team, error) {
|
||||
`
|
||||
team := &fleet.Team{}
|
||||
|
||||
if err := d.db.Get(team, sql, tid); err != nil {
|
||||
if err := d.reader.Get(team, sql, tid); err != nil {
|
||||
return nil, errors.Wrap(err, "select team")
|
||||
}
|
||||
|
||||
@ -84,7 +84,7 @@ func (d *Datastore) TeamByName(name string) (*fleet.Team, error) {
|
||||
`
|
||||
team := &fleet.Team{}
|
||||
|
||||
if err := d.db.Get(team, sql, name); err != nil {
|
||||
if err := d.reader.Get(team, sql, name); err != nil {
|
||||
return nil, errors.Wrap(err, "select team")
|
||||
}
|
||||
|
||||
@ -106,7 +106,7 @@ func (d *Datastore) loadUsersForTeam(team *fleet.Team) error {
|
||||
WHERE ut.team_id = ?
|
||||
`
|
||||
rows := []fleet.TeamUser{}
|
||||
if err := d.db.Select(&rows, sql, team.ID); err != nil {
|
||||
if err := d.reader.Select(&rows, sql, team.ID); err != nil {
|
||||
return errors.Wrap(err, "load users for team")
|
||||
}
|
||||
|
||||
@ -157,7 +157,7 @@ func (d *Datastore) SaveTeam(team *fleet.Team) (*fleet.Team, error) {
|
||||
description = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := d.db.Exec(query, team.Name, team.AgentOptions, team.Description, team.ID)
|
||||
_, err := d.writer.Exec(query, team.Name, team.AgentOptions, team.Description, team.ID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "saving team")
|
||||
}
|
||||
@ -174,7 +174,7 @@ func (d *Datastore) SaveTeam(team *fleet.Team) (*fleet.Team, error) {
|
||||
}
|
||||
|
||||
func (d *Datastore) updateTeamSchedule(team *fleet.Team) error {
|
||||
_, err := d.db.Exec(
|
||||
_, err := d.writer.Exec(
|
||||
`UPDATE packs SET name = ? WHERE pack_type = ?`, teamScheduleName(team), teamSchedulePackType(team),
|
||||
)
|
||||
return err
|
||||
@ -195,7 +195,7 @@ func (d *Datastore) ListTeams(filter fleet.TeamFilter, opt fleet.ListOptions) ([
|
||||
query, params := searchLike(query, nil, opt.MatchQuery, teamSearchColumns...)
|
||||
query = appendListOptionsToSQL(query, opt)
|
||||
teams := []*fleet.Team{}
|
||||
if err := d.db.Select(&teams, query, params...); err != nil {
|
||||
if err := d.reader.Select(&teams, query, params...); err != nil {
|
||||
return nil, errors.Wrap(err, "list teams")
|
||||
}
|
||||
if err := d.loadSecretsForTeams(teams); err != nil {
|
||||
@ -229,7 +229,7 @@ func (d *Datastore) SearchTeams(filter fleet.TeamFilter, matchQuery string, omit
|
||||
sql, params := searchLike(sql, nil, matchQuery, teamSearchColumns...)
|
||||
sql += "\nLIMIT 5"
|
||||
teams := []*fleet.Team{}
|
||||
if err := d.db.Select(&teams, sql, params...); err != nil {
|
||||
if err := d.reader.Select(&teams, sql, params...); err != nil {
|
||||
return nil, errors.Wrap(err, "search teams")
|
||||
}
|
||||
if err := d.loadSecretsForTeams(teams); err != nil {
|
||||
@ -244,7 +244,7 @@ func (d *Datastore) TeamEnrollSecrets(teamID uint) ([]*fleet.EnrollSecret, error
|
||||
WHERE team_id = ?
|
||||
`
|
||||
var secrets []*fleet.EnrollSecret
|
||||
if err := d.db.Select(&secrets, sql, teamID); err != nil {
|
||||
if err := d.reader.Select(&secrets, sql, teamID); err != nil {
|
||||
return nil, errors.Wrap(err, "get secrets")
|
||||
}
|
||||
return secrets, nil
|
||||
|
@ -1,6 +1,7 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@ -9,10 +10,12 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/WatchBeam/clock"
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -20,9 +23,10 @@ const (
|
||||
testUsername = "root"
|
||||
testPassword = "toor"
|
||||
testAddress = "localhost:3307"
|
||||
testReplicaDatabaseSuffix = "_replica"
|
||||
)
|
||||
|
||||
func connectMySQL(t *testing.T, testName string) *Datastore {
|
||||
func connectMySQL(t *testing.T, testName string, opts *DatastoreTestOptions) *Datastore {
|
||||
config := config.MysqlConfig{
|
||||
Username: testUsername,
|
||||
Password: testPassword,
|
||||
@ -31,15 +35,134 @@ func connectMySQL(t *testing.T, testName string) *Datastore {
|
||||
}
|
||||
|
||||
// Create datastore client
|
||||
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1))
|
||||
var replicaOpt DBOption
|
||||
if opts.Replica {
|
||||
replicaConf := config
|
||||
replicaConf.Database += testReplicaDatabaseSuffix
|
||||
replicaOpt = Replica(&replicaConf)
|
||||
}
|
||||
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1), replicaOpt)
|
||||
require.Nil(t, err)
|
||||
|
||||
if opts.Replica {
|
||||
setupReadReplica(t, testName, ds, opts)
|
||||
}
|
||||
|
||||
return ds
|
||||
}
|
||||
|
||||
func setupReadReplica(t *testing.T, testName string, ds *Datastore, opts *DatastoreTestOptions) {
|
||||
// create the context that will cancel the replication goroutine on test exit
|
||||
var cancel func()
|
||||
ctx := context.Background()
|
||||
if dl, ok := t.Deadline(); ok {
|
||||
ctx, cancel = context.WithDeadline(ctx, dl)
|
||||
} else {
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
}
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// start the replication goroutine that runs when signalled through a
|
||||
// channel, the replication runs in lock-step - the test is in control of
|
||||
// when the replication happens, by calling opts.RunReplication(), and when
|
||||
// that call returns, the replication is guaranteed to be done. This supports
|
||||
// simulating all kinds of replica lag.
|
||||
ch := make(chan chan struct{})
|
||||
go func() {
|
||||
// if it exits because of a panic/failed replication, cancel the context
|
||||
// immediately so that RunReplication is unblocked too.
|
||||
defer cancel()
|
||||
|
||||
primary := ds.writer
|
||||
replica := ds.reader.(*sqlx.DB)
|
||||
replicaDB := testName + testReplicaDatabaseSuffix
|
||||
last := time.Now().Add(-time.Minute)
|
||||
|
||||
// drop all foreign keys in the replica, as that causes issues even with
|
||||
// FOREIGN_KEY_CHECKS=0
|
||||
var fks []struct {
|
||||
TableName string `db:"table_name"`
|
||||
ConstraintName string `db:"constraint_name"`
|
||||
}
|
||||
err := primary.SelectContext(ctx, &fks, `
|
||||
SELECT
|
||||
table_name, constraint_name
|
||||
FROM
|
||||
information_schema.key_column_usage
|
||||
WHERE
|
||||
table_schema = ? AND
|
||||
referenced_table_name IS NOT NULL`, testName)
|
||||
require.NoError(t, err)
|
||||
for _, fk := range fks {
|
||||
stmt := fmt.Sprintf(`ALTER TABLE %s.%s DROP FOREIGN KEY %s`, replicaDB, fk.TableName, fk.ConstraintName)
|
||||
_, err := replica.ExecContext(ctx, stmt)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case out := <-ch:
|
||||
// identify tables with changes since the last call
|
||||
var tables []string
|
||||
err := primary.SelectContext(ctx, &tables, `
|
||||
SELECT
|
||||
table_name
|
||||
FROM
|
||||
information_schema.tables
|
||||
WHERE
|
||||
table_schema = ? AND
|
||||
table_type = 'BASE TABLE' AND
|
||||
update_time >= ?`, testName, last)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = primary.GetContext(ctx, &last, `
|
||||
SELECT
|
||||
MAX(update_time)
|
||||
FROM
|
||||
information_schema.tables
|
||||
WHERE
|
||||
table_schema = ? AND
|
||||
table_type = 'BASE TABLE'`, testName)
|
||||
require.NoError(t, err)
|
||||
|
||||
// replicate by dropping the existing table and re-creating it from
|
||||
// the primary.
|
||||
for _, tbl := range tables {
|
||||
stmt := fmt.Sprintf(`DROP TABLE IF EXISTS %s.%s`, replicaDB, tbl)
|
||||
t.Log(stmt)
|
||||
_, err = replica.ExecContext(ctx, stmt)
|
||||
require.NoError(t, err)
|
||||
stmt = fmt.Sprintf(`CREATE TABLE %s.%s SELECT * FROM %s.%s`, replicaDB, tbl, testName, tbl)
|
||||
t.Log(stmt)
|
||||
_, err = replica.ExecContext(ctx, stmt)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
out <- struct{}{}
|
||||
t.Logf("replication step executed, next will consider updates since %s", last)
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// set RunReplication to a function that triggers the replication and waits
|
||||
// for it to complete.
|
||||
opts.RunReplication = func() {
|
||||
done := make(chan struct{})
|
||||
ch <- done
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initializeDatabase loads the dumped schema into a newly created database in
|
||||
// MySQL. This is much faster than running the full set of migrations on each
|
||||
// test.
|
||||
func initializeDatabase(t *testing.T, testName string) *Datastore {
|
||||
func initializeDatabase(t *testing.T, testName string, opts *DatastoreTestOptions) *Datastore {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
base := path.Dir(filename)
|
||||
schema, err := ioutil.ReadFile(path.Join(base, "schema.sql"))
|
||||
@ -47,6 +170,14 @@ func initializeDatabase(t *testing.T, testName string) *Datastore {
|
||||
t.Error(err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// execute the schema for the test db, and once more for the replica db if
|
||||
// that option is set.
|
||||
dbs := []string{testName}
|
||||
if opts.Replica {
|
||||
dbs = append(dbs, testName+testReplicaDatabaseSuffix)
|
||||
}
|
||||
for _, dbName := range dbs {
|
||||
// Load schema from dumpfile
|
||||
if out, err := exec.Command(
|
||||
"docker-compose", "exec", "-T", "mysql_test",
|
||||
@ -56,24 +187,43 @@ func initializeDatabase(t *testing.T, testName string) *Datastore {
|
||||
"-e",
|
||||
fmt.Sprintf(
|
||||
"DROP DATABASE IF EXISTS %s; CREATE DATABASE %s; USE %s; SET FOREIGN_KEY_CHECKS=0; %s;",
|
||||
testName, testName, testName, schema,
|
||||
dbName, dbName, dbName, schema,
|
||||
),
|
||||
).CombinedOutput(); err != nil {
|
||||
t.Error(err)
|
||||
t.Error(string(out))
|
||||
t.FailNow()
|
||||
}
|
||||
return connectMySQL(t, testName)
|
||||
}
|
||||
return connectMySQL(t, testName, opts)
|
||||
}
|
||||
|
||||
func CreateMySQLDS(t *testing.T) *Datastore {
|
||||
// DatastoreTestOptions configures how the test datastore is created
|
||||
// by CreateMySQLDSWithOptions.
|
||||
type DatastoreTestOptions struct {
|
||||
// Replica indicates that a read replica test database should be created.
|
||||
Replica bool
|
||||
|
||||
// RunReplication is the function to call to execute the replication of all
|
||||
// missing changes from the primary to the replica. The function is created
|
||||
// and set automatically by CreateMySQLDSWithOptions. The test is in full
|
||||
// control of when the replication is executed.
|
||||
RunReplication func()
|
||||
}
|
||||
|
||||
func createMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datastore {
|
||||
if _, ok := os.LookupEnv("MYSQL_TEST"); !ok {
|
||||
t.Skip("MySQL tests are disabled")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
pc, _, _, ok := runtime.Caller(1)
|
||||
if opts == nil {
|
||||
// so it is never nil in internal helper functions
|
||||
opts = new(DatastoreTestOptions)
|
||||
}
|
||||
|
||||
pc, _, _, ok := runtime.Caller(2)
|
||||
details := runtime.FuncForPC(pc)
|
||||
if !ok || details == nil {
|
||||
t.FailNow()
|
||||
@ -83,5 +233,13 @@ func CreateMySQLDS(t *testing.T) *Datastore {
|
||||
strings.TrimPrefix(details.Name(), "github.com/fleetdm/fleet/v4/"), "/", "_",
|
||||
)
|
||||
cleanName = strings.ReplaceAll(cleanName, ".", "_")
|
||||
return initializeDatabase(t, cleanName)
|
||||
return initializeDatabase(t, cleanName, opts)
|
||||
}
|
||||
|
||||
func CreateMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datastore {
|
||||
return createMySQLDSWithOptions(t, opts)
|
||||
}
|
||||
|
||||
func CreateMySQLDS(t *testing.T) *Datastore {
|
||||
return createMySQLDSWithOptions(t, nil)
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ func (d *Datastore) findUser(searchCol string, searchVal interface{}) (*fleet.Us
|
||||
|
||||
user := &fleet.User{}
|
||||
|
||||
err := d.db.Get(user, sqlStatement, searchVal)
|
||||
err := d.reader.Get(user, sqlStatement, searchVal)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, notFound("User").
|
||||
WithMessage(fmt.Sprintf("with %s=%v", searchCol, searchVal))
|
||||
@ -104,7 +104,7 @@ func (d *Datastore) ListUsers(opt fleet.UserListOptions) ([]*fleet.User, error)
|
||||
sqlStatement = appendListOptionsToSQL(sqlStatement, opt.ListOptions)
|
||||
users := []*fleet.User{}
|
||||
|
||||
if err := d.db.Select(&users, sqlStatement, params...); err != nil {
|
||||
if err := d.reader.Select(&users, sqlStatement, params...); err != nil {
|
||||
return nil, errors.Wrap(err, "list users")
|
||||
}
|
||||
|
||||
@ -160,7 +160,7 @@ func (d *Datastore) saveUser(tx *sqlx.Tx, user *fleet.User) error {
|
||||
global_role = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement,
|
||||
result, err := tx.Exec(sqlStatement,
|
||||
user.Password,
|
||||
user.Salt,
|
||||
user.Name,
|
||||
@ -221,7 +221,7 @@ func (d *Datastore) loadTeamsForUsers(users []*fleet.User) error {
|
||||
fleet.UserTeam
|
||||
UserID uint `db:"user_id"`
|
||||
}
|
||||
if err := d.db.Select(&rows, sql, args...); err != nil {
|
||||
if err := d.reader.Select(&rows, sql, args...); err != nil {
|
||||
return errors.Wrap(err, "get loadTeamsForUsers")
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user