mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Add support for MySQL read replicas (#1786)
Adds configuration options to use a read-only MySQL replica, and uses it instead of the primary for reads.
This commit is contained in:
parent
2116217b6e
commit
3755a58070
1
changes/issue-1695-support-mysql-read-replica
Normal file
1
changes/issue-1695-support-mysql-read-replica
Normal file
@ -0,0 +1 @@
|
||||
* Add support for MySQL read replica.
|
@ -149,7 +149,11 @@ the way that the Fleet server works.
|
||||
var carveStore fleet.CarveStore
|
||||
mailService := mail.NewService()
|
||||
|
||||
ds, err = mysql.New(config.Mysql, clock.C, mysql.Logger(logger))
|
||||
var replicaOpt mysql.DBOption
|
||||
if config.MysqlReadReplica.Address != "" {
|
||||
replicaOpt = mysql.Replica(&config.MysqlReadReplica)
|
||||
}
|
||||
ds, err = mysql.New(config.Mysql, clock.C, mysql.Logger(logger), replicaOpt)
|
||||
if err != nil {
|
||||
initFatal(err, "initializing datastore")
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ Your Fleet server's two main purposes are:
|
||||
- To serve as your [osquery TLS server](https://osquery.readthedocs.io/en/stable/deployment/remote/)
|
||||
- To serve the Fleet web UI, which allows you to manage osquery configuration, query hosts, etc.
|
||||
|
||||
The Fleet server allows you persist configuration, manage users, etc. Thus, it needs a database. Fleet uses MySQL and requires you to supply configurations to connect to a MySQL server. Fleet also uses Redis to perform some more high-speed data access action throughout the lifecycle of the application (for example, distributed query result ingestion). Thus, Fleet also requires that you supply Redis connection configurations.
|
||||
The Fleet server allows you persist configuration, manage users, etc. Thus, it needs a database. Fleet uses MySQL and requires you to supply configurations to connect to a MySQL server. It is also possible to configure connection to a MySQL replica in addition to the primary, to be used for reading only. Fleet also uses Redis to perform some more high-speed data access action throughout the lifecycle of the application (for example, distributed query result ingestion). Thus, Fleet also requires that you supply Redis connection configurations.
|
||||
|
||||
> Fleet does not support Redis Cluster or Redis Sentinel. Fleet can scale to hundreds of thousands of devices with a single Redis instance.
|
||||
|
||||
@ -118,12 +118,21 @@ mysql:
|
||||
address: 127.0.0.1:3306
|
||||
```
|
||||
|
||||
And `mysql_read_replica_address` would be:
|
||||
|
||||
```
|
||||
mysql_read_replica:
|
||||
address: 127.0.0.1:3307
|
||||
```
|
||||
|
||||
Basically, just capitalize the option and prepend `FLEET_` to it in order to get the environment variable. The conversion works the same the opposite way.
|
||||
|
||||
All duration-based settings accept valid time units of `s`, `m`, `h`.
|
||||
|
||||
##### MySQL
|
||||
|
||||
This section describes the configuration options for the primary - if you also want to setup a read replica, the options are the same, except that the yaml section is `mysql_read_replica`, and the flags have the `mysql_read_replica_` prefix instead of `mysql_` (the corresponding environment variables follow the same transformation). Note that there is no default value for `mysql_read_replica_address`, it must be set explicitly for fleet to use a read replica.
|
||||
|
||||
###### `mysql_address`
|
||||
|
||||
The address of the MySQL server which Fleet should connect to. Include the hostname and port.
|
||||
@ -1246,7 +1255,7 @@ AWS STS role ARN to use for S3 authentication.
|
||||
|
||||
The path specified needs to exist and fleet needs to be able to read and write to and from it. This is the only mandatory configuration needed for vulnerability processing to work.
|
||||
|
||||
When `current_instance_checks` is set to `auto` (the default), Fleet instances will try to create the `databases_path` if it doesn't exist.
|
||||
When `current_instance_checks` is set to `auto` (the default), Fleet instances will try to create the `databases_path` if it doesn't exist.
|
||||
|
||||
- Default value: none
|
||||
- Environment variable: `FLEET_VULNERABILITIES_DATABASES_PATH`
|
||||
@ -1395,15 +1404,15 @@ After supplying the above information, the IDP will generate an issuer URI and a
|
||||
|
||||
![Example Okta IDP Configuration](https://raw.githubusercontent.com/fleetdm/fleet/main/docs/images/okta-idp-setup.png)
|
||||
|
||||
> The names of the items required to configure an Identity Provider may vary from provider to provider and may not conform to the SAML spec.
|
||||
> The names of the items required to configure an Identity Provider may vary from provider to provider and may not conform to the SAML spec.
|
||||
|
||||
> Individual users must also be setup on the IDP before they can sign in to Fleet.
|
||||
> Individual users must also be setup on the IDP before they can sign in to Fleet.
|
||||
|
||||
### Fleet SSO Configuration
|
||||
|
||||
A Fleet user must be assigned the Admin role to configure Fleet for SSO. In Fleet, SSO configuration settings are located in **Settings > Organization settings > SAML Single Sign On Options**.
|
||||
A Fleet user must be assigned the Admin role to configure Fleet for SSO. In Fleet, SSO configuration settings are located in **Settings > Organization settings > SAML Single Sign On Options**.
|
||||
|
||||
If your IDP supports dynamic configuration, like Okta, you only need to provide an _Identity Provider Name_ and _Entity ID_, then paste a link in the metadata URL field.
|
||||
If your IDP supports dynamic configuration, like Okta, you only need to provide an _Identity Provider Name_ and _Entity ID_, then paste a link in the metadata URL field.
|
||||
|
||||
Otherwise, the following values are required:
|
||||
|
||||
@ -1430,7 +1439,7 @@ Otherwise, the following values are required:
|
||||
### Creating SSO users in Fleet
|
||||
|
||||
When an admin creates a new user to Fleet, they may select the `Enable Single Sign On` option. The
|
||||
SSO enabled users will not be able to sign in with a regular user ID and password.
|
||||
SSO enabled users will not be able to sign in with a regular user ID and password.
|
||||
|
||||
It is strongly recommended that at least one admin user is set up to use the traditional password
|
||||
based log in so that there is a fallback method for logging into Fleet in the event of SSO
|
||||
|
@ -176,53 +176,58 @@ type VulnerabilitiesConfig struct {
|
||||
// structs, Manager.addConfigs and Manager.LoadConfig should be
|
||||
// updated to set and retrieve the configurations as appropriate.
|
||||
type FleetConfig struct {
|
||||
Mysql MysqlConfig
|
||||
Redis RedisConfig
|
||||
Server ServerConfig
|
||||
Auth AuthConfig
|
||||
App AppConfig
|
||||
Session SessionConfig
|
||||
Osquery OsqueryConfig
|
||||
Logging LoggingConfig
|
||||
Firehose FirehoseConfig
|
||||
Kinesis KinesisConfig
|
||||
Lambda LambdaConfig
|
||||
S3 S3Config
|
||||
PubSub PubSubConfig
|
||||
Filesystem FilesystemConfig
|
||||
License LicenseConfig
|
||||
Vulnerabilities VulnerabilitiesConfig
|
||||
Mysql MysqlConfig
|
||||
MysqlReadReplica MysqlConfig `yaml:"mysql_read_replica"`
|
||||
Redis RedisConfig
|
||||
Server ServerConfig
|
||||
Auth AuthConfig
|
||||
App AppConfig
|
||||
Session SessionConfig
|
||||
Osquery OsqueryConfig
|
||||
Logging LoggingConfig
|
||||
Firehose FirehoseConfig
|
||||
Kinesis KinesisConfig
|
||||
Lambda LambdaConfig
|
||||
S3 S3Config
|
||||
PubSub PubSubConfig
|
||||
Filesystem FilesystemConfig
|
||||
License LicenseConfig
|
||||
Vulnerabilities VulnerabilitiesConfig
|
||||
}
|
||||
|
||||
// addConfigs adds the configuration keys and default values that will be
|
||||
// filled into the FleetConfig struct
|
||||
func (man Manager) addConfigs() {
|
||||
addMysqlConfig := func(prefix, defaultAddr, usageSuffix string) {
|
||||
man.addConfigString(prefix+".protocol", "tcp",
|
||||
"MySQL server communication protocol (tcp,unix,...)"+usageSuffix)
|
||||
man.addConfigString(prefix+".address", defaultAddr,
|
||||
"MySQL server address (host:port)"+usageSuffix)
|
||||
man.addConfigString(prefix+".username", "fleet",
|
||||
"MySQL server username"+usageSuffix)
|
||||
man.addConfigString(prefix+".password", "",
|
||||
"MySQL server password (prefer env variable for security)"+usageSuffix)
|
||||
man.addConfigString(prefix+".password_path", "",
|
||||
"Path to file containg MySQL server password"+usageSuffix)
|
||||
man.addConfigString(prefix+".database", "fleet",
|
||||
"MySQL database name"+usageSuffix)
|
||||
man.addConfigString(prefix+".tls_cert", "",
|
||||
"MySQL TLS client certificate path"+usageSuffix)
|
||||
man.addConfigString(prefix+".tls_key", "",
|
||||
"MySQL TLS client key path"+usageSuffix)
|
||||
man.addConfigString(prefix+".tls_ca", "",
|
||||
"MySQL TLS server CA"+usageSuffix)
|
||||
man.addConfigString(prefix+".tls_server_name", "",
|
||||
"MySQL TLS server name"+usageSuffix)
|
||||
man.addConfigString(prefix+".tls_config", "",
|
||||
"MySQL TLS config value"+usageSuffix+" Use skip-verify, true, false or custom key.")
|
||||
man.addConfigInt(prefix+".max_open_conns", 50, "MySQL maximum open connection handles"+usageSuffix)
|
||||
man.addConfigInt(prefix+".max_idle_conns", 50, "MySQL maximum idle connection handles"+usageSuffix)
|
||||
man.addConfigInt(prefix+".conn_max_lifetime", 0, "MySQL maximum amount of time a connection may be reused"+usageSuffix)
|
||||
}
|
||||
// MySQL
|
||||
man.addConfigString("mysql.protocol", "tcp",
|
||||
"MySQL server communication protocol (tcp,unix,...)")
|
||||
man.addConfigString("mysql.address", "localhost:3306",
|
||||
"MySQL server address (host:port)")
|
||||
man.addConfigString("mysql.username", "fleet",
|
||||
"MySQL server username")
|
||||
man.addConfigString("mysql.password", "",
|
||||
"MySQL server password (prefer env variable for security)")
|
||||
man.addConfigString("mysql.password_path", "",
|
||||
"Path to file containg MySQL server password")
|
||||
man.addConfigString("mysql.database", "fleet",
|
||||
"MySQL database name")
|
||||
man.addConfigString("mysql.tls_cert", "",
|
||||
"MySQL TLS client certificate path")
|
||||
man.addConfigString("mysql.tls_key", "",
|
||||
"MySQL TLS client key path")
|
||||
man.addConfigString("mysql.tls_ca", "",
|
||||
"MySQL TLS server CA")
|
||||
man.addConfigString("mysql.tls_server_name", "",
|
||||
"MySQL TLS server name")
|
||||
man.addConfigString("mysql.tls_config", "",
|
||||
"MySQL TLS config value. Use skip-verify, true, false or custom key.")
|
||||
man.addConfigInt("mysql.max_open_conns", 50, "MySQL maximum open connection handles.")
|
||||
man.addConfigInt("mysql.max_idle_conns", 50, "MySQL maximum idle connection handles.")
|
||||
man.addConfigInt("mysql.conn_max_lifetime", 0, "MySQL maximum amount of time a connection may be reused.")
|
||||
addMysqlConfig("mysql", "localhost:3306", ".")
|
||||
addMysqlConfig("mysql_read_replica", "", " for the read replica.")
|
||||
|
||||
// Redis
|
||||
man.addConfigString("redis.address", "localhost:6379",
|
||||
@ -382,23 +387,28 @@ func (man Manager) addConfigs() {
|
||||
func (man Manager) LoadConfig() FleetConfig {
|
||||
man.loadConfigFile()
|
||||
|
||||
loadMysqlConfig := func(prefix string) MysqlConfig {
|
||||
return MysqlConfig{
|
||||
Protocol: man.getConfigString(prefix + ".protocol"),
|
||||
Address: man.getConfigString(prefix + ".address"),
|
||||
Username: man.getConfigString(prefix + ".username"),
|
||||
Password: man.getConfigString(prefix + ".password"),
|
||||
PasswordPath: man.getConfigString(prefix + ".password_path"),
|
||||
Database: man.getConfigString(prefix + ".database"),
|
||||
TLSCert: man.getConfigString(prefix + ".tls_cert"),
|
||||
TLSKey: man.getConfigString(prefix + ".tls_key"),
|
||||
TLSCA: man.getConfigString(prefix + ".tls_ca"),
|
||||
TLSServerName: man.getConfigString(prefix + ".tls_server_name"),
|
||||
TLSConfig: man.getConfigString(prefix + ".tls_config"),
|
||||
MaxOpenConns: man.getConfigInt(prefix + ".max_open_conns"),
|
||||
MaxIdleConns: man.getConfigInt(prefix + ".max_idle_conns"),
|
||||
ConnMaxLifetime: man.getConfigInt(prefix + ".conn_max_lifetime"),
|
||||
}
|
||||
}
|
||||
|
||||
return FleetConfig{
|
||||
Mysql: MysqlConfig{
|
||||
Protocol: man.getConfigString("mysql.protocol"),
|
||||
Address: man.getConfigString("mysql.address"),
|
||||
Username: man.getConfigString("mysql.username"),
|
||||
Password: man.getConfigString("mysql.password"),
|
||||
PasswordPath: man.getConfigString("mysql.password_path"),
|
||||
Database: man.getConfigString("mysql.database"),
|
||||
TLSCert: man.getConfigString("mysql.tls_cert"),
|
||||
TLSKey: man.getConfigString("mysql.tls_key"),
|
||||
TLSCA: man.getConfigString("mysql.tls_ca"),
|
||||
TLSServerName: man.getConfigString("mysql.tls_server_name"),
|
||||
TLSConfig: man.getConfigString("mysql.tls_config"),
|
||||
MaxOpenConns: man.getConfigInt("mysql.max_open_conns"),
|
||||
MaxIdleConns: man.getConfigInt("mysql.max_idle_conns"),
|
||||
ConnMaxLifetime: man.getConfigInt("mysql.conn_max_lifetime"),
|
||||
},
|
||||
Mysql: loadMysqlConfig("mysql"),
|
||||
MysqlReadReplica: loadMysqlConfig("mysql_read_replica"),
|
||||
Redis: RedisConfig{
|
||||
Address: man.getConfigString("redis.address"),
|
||||
Password: man.getConfigString("redis.password"),
|
||||
|
@ -58,6 +58,7 @@ func TestConfigRoundtrip(t *testing.T) {
|
||||
// Marshal the generated config
|
||||
buf, err := yaml.Marshal(original)
|
||||
require.Nil(t, err)
|
||||
t.Log(string(buf))
|
||||
|
||||
// Manually load the serialized config
|
||||
man.viper.SetConfigType("yaml")
|
||||
|
@ -3,6 +3,7 @@ package mysql
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
@ -13,7 +14,7 @@ func (d *Datastore) NewActivity(user *fleet.User, activityType string, details *
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "marshaling activity details")
|
||||
}
|
||||
_, err = d.db.Exec(
|
||||
_, err = d.writer.Exec(
|
||||
`INSERT INTO activities (user_id, user_name, activity_type, details) VALUES(?,?,?,?)`,
|
||||
user.ID,
|
||||
user.Name,
|
||||
@ -29,12 +30,12 @@ func (d *Datastore) NewActivity(user *fleet.User, activityType string, details *
|
||||
// ListActivities returns a slice of activities performed across the organization
|
||||
func (d *Datastore) ListActivities(opt fleet.ListOptions) ([]*fleet.Activity, error) {
|
||||
activities := []*fleet.Activity{}
|
||||
query := `SELECT a.id, a.user_id, a.created_at, a.activity_type, a.details, coalesce(u.name, a.user_name) as name, u.gravatar_url, u.email
|
||||
query := `SELECT a.id, a.user_id, a.created_at, a.activity_type, a.details, coalesce(u.name, a.user_name) as name, u.gravatar_url, u.email
|
||||
FROM activities a LEFT JOIN users u ON (a.user_id=u.id)
|
||||
WHERE true`
|
||||
query = appendListOptionsToSQL(query, opt)
|
||||
|
||||
err := d.db.Select(&activities, query)
|
||||
err := d.reader.Select(&activities, query)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Activity")
|
||||
} else if err != nil {
|
||||
|
@ -25,7 +25,7 @@ func (d *Datastore) NewAppConfig(info *fleet.AppConfig) (*fleet.AppConfig, error
|
||||
func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
|
||||
info := &fleet.AppConfig{}
|
||||
var bytes []byte
|
||||
err := d.db.Get(&bytes, `SELECT json_value FROM app_config_json LIMIT 1`)
|
||||
err := d.reader.Get(&bytes, `SELECT json_value FROM app_config_json LIMIT 1`)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "selecting app config")
|
||||
}
|
||||
@ -43,7 +43,7 @@ func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
|
||||
}
|
||||
|
||||
func (d *Datastore) isEventSchedulerEnabled() (bool, error) {
|
||||
rows, err := d.db.Query("SELECT @@event_scheduler")
|
||||
rows, err := d.writer.Query("SELECT @@event_scheduler")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -136,7 +136,7 @@ func (d *Datastore) SaveAppConfig(info *fleet.AppConfig) error {
|
||||
|
||||
func (d *Datastore) VerifyEnrollSecret(secret string) (*fleet.EnrollSecret, error) {
|
||||
var s fleet.EnrollSecret
|
||||
err := d.db.Get(&s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
|
||||
err := d.reader.Get(&s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
|
||||
if err != nil {
|
||||
return nil, errors.New("no matching secret found")
|
||||
}
|
||||
@ -184,7 +184,7 @@ func (d *Datastore) GetEnrollSecrets(teamID *uint) ([]*fleet.EnrollSecret, error
|
||||
args = append(args, teamID)
|
||||
}
|
||||
var secrets []*fleet.EnrollSecret
|
||||
if err := d.db.Select(&secrets, sql, args...); err != nil {
|
||||
if err := d.reader.Select(&secrets, sql, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "get secrets")
|
||||
}
|
||||
return secrets, nil
|
||||
|
@ -259,7 +259,7 @@ func TestAppConfigDefaults(t *testing.T) {
|
||||
defer ds.Close()
|
||||
|
||||
insertAppConfigQuery := `INSERT INTO app_config_json(json_value) VALUES(?) ON DUPLICATE KEY UPDATE json_value = VALUES(json_value)`
|
||||
_, err := ds.db.Exec(insertAppConfigQuery, `{}`)
|
||||
_, err := ds.writer.Exec(insertAppConfigQuery, `{}`)
|
||||
require.NoError(t, err)
|
||||
|
||||
ac, err := ds.AppConfig()
|
||||
@ -269,7 +269,7 @@ func TestAppConfigDefaults(t *testing.T) {
|
||||
require.True(t, ac.HostSettings.EnableHostUsers)
|
||||
require.False(t, ac.HostSettings.EnableSoftwareInventory)
|
||||
|
||||
_, err = ds.db.Exec(
|
||||
_, err = ds.writer.Exec(
|
||||
insertAppConfigQuery,
|
||||
`{"webhook_settings": {"interval": "12h"}, "host_settings": {"enable_host_users": false}}`,
|
||||
)
|
||||
|
@ -17,7 +17,7 @@ func (d *Datastore) NewDistributedQueryCampaign(camp *fleet.DistributedQueryCamp
|
||||
)
|
||||
VALUES(?,?,?)
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID)
|
||||
result, err := d.writer.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "inserting distributed query campaign")
|
||||
}
|
||||
@ -32,7 +32,7 @@ func (d *Datastore) DistributedQueryCampaign(id uint) (*fleet.DistributedQueryCa
|
||||
SELECT * FROM distributed_query_campaigns WHERE id = ?
|
||||
`
|
||||
campaign := &fleet.DistributedQueryCampaign{}
|
||||
if err := d.db.Get(campaign, sql, id); err != nil {
|
||||
if err := d.reader.Get(campaign, sql, id); err != nil {
|
||||
return nil, errors.Wrap(err, "selecting distributed query campaign")
|
||||
}
|
||||
|
||||
@ -47,7 +47,7 @@ func (d *Datastore) SaveDistributedQueryCampaign(camp *fleet.DistributedQueryCam
|
||||
user_id = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID, camp.ID)
|
||||
result, err := d.writer.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID, camp.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "updating distributed query campaign")
|
||||
}
|
||||
@ -68,7 +68,7 @@ func (d *Datastore) DistributedQueryCampaignTargetIDs(id uint) (*fleet.HostTarge
|
||||
`
|
||||
targets := []fleet.DistributedQueryCampaignTarget{}
|
||||
|
||||
if err := d.db.Select(&targets, sqlStatement, id); err != nil {
|
||||
if err := d.reader.Select(&targets, sqlStatement, id); err != nil {
|
||||
return nil, errors.Wrap(err, "select distributed campaign target")
|
||||
}
|
||||
|
||||
@ -100,7 +100,7 @@ func (d *Datastore) NewDistributedQueryCampaignTarget(target *fleet.DistributedQ
|
||||
)
|
||||
VALUES (?,?,?)
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, target.Type, target.DistributedQueryCampaignID, target.TargetID)
|
||||
result, err := d.writer.Exec(sqlStatement, target.Type, target.DistributedQueryCampaignID, target.TargetID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "insert distributed campaign target")
|
||||
}
|
||||
@ -118,7 +118,7 @@ func (d *Datastore) CleanupDistributedQueryCampaigns(now time.Time) (expired uin
|
||||
WHERE (status = ? AND created_at < ?)
|
||||
OR (status = ? AND created_at < ?)
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, fleet.QueryComplete,
|
||||
result, err := d.writer.Exec(sqlStatement, fleet.QueryComplete,
|
||||
fleet.QueryWaiting, now.Add(-1*time.Minute),
|
||||
fleet.QueryRunning, now.Add(-24*time.Hour))
|
||||
if err != nil {
|
||||
|
@ -33,7 +33,7 @@ func (d *Datastore) NewCarve(metadata *fleet.CarveMetadata) (*fleet.CarveMetadat
|
||||
?
|
||||
)`
|
||||
|
||||
result, err := d.db.Exec(
|
||||
result, err := d.writer.Exec(
|
||||
stmt,
|
||||
metadata.HostId,
|
||||
metadata.CreatedAt.Format(mySQLTimestampFormat),
|
||||
@ -64,7 +64,7 @@ func (d *Datastore) UpdateCarve(metadata *fleet.CarveMetadata) error {
|
||||
expired = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := d.db.Exec(
|
||||
_, err := d.writer.Exec(
|
||||
stmt,
|
||||
metadata.MaxBlock,
|
||||
metadata.Expired,
|
||||
@ -163,7 +163,7 @@ func (d *Datastore) Carve(carveId int64) (*fleet.CarveMetadata, error) {
|
||||
)
|
||||
|
||||
var metadata fleet.CarveMetadata
|
||||
if err := d.db.Get(&metadata, stmt, carveId); err != nil {
|
||||
if err := d.reader.Get(&metadata, stmt, carveId); err != nil {
|
||||
return nil, errors.Wrap(err, "get carve by ID")
|
||||
}
|
||||
|
||||
@ -179,7 +179,7 @@ func (d *Datastore) CarveBySessionId(sessionId string) (*fleet.CarveMetadata, er
|
||||
)
|
||||
|
||||
var metadata fleet.CarveMetadata
|
||||
if err := d.db.Get(&metadata, stmt, sessionId); err != nil {
|
||||
if err := d.reader.Get(&metadata, stmt, sessionId); err != nil {
|
||||
return nil, errors.Wrap(err, "get carve by session ID")
|
||||
}
|
||||
|
||||
@ -195,7 +195,7 @@ func (d *Datastore) CarveByName(name string) (*fleet.CarveMetadata, error) {
|
||||
)
|
||||
|
||||
var metadata fleet.CarveMetadata
|
||||
if err := d.db.Get(&metadata, stmt, name); err != nil {
|
||||
if err := d.reader.Get(&metadata, stmt, name); err != nil {
|
||||
return nil, errors.Wrap(err, "get carve by name")
|
||||
}
|
||||
|
||||
@ -213,7 +213,7 @@ func (d *Datastore) ListCarves(opt fleet.CarveListOptions) ([]*fleet.CarveMetada
|
||||
}
|
||||
stmt = appendListOptionsToSQL(stmt, opt.ListOptions)
|
||||
carves := []*fleet.CarveMetadata{}
|
||||
if err := d.db.Select(&carves, stmt); err != nil && err != sql.ErrNoRows {
|
||||
if err := d.reader.Select(&carves, stmt); err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "list carves")
|
||||
}
|
||||
|
||||
@ -231,13 +231,14 @@ func (d *Datastore) NewBlock(metadata *fleet.CarveMetadata, blockId int64, data
|
||||
?,
|
||||
?
|
||||
)`
|
||||
if _, err := d.db.Exec(stmt, metadata.ID, blockId, data); err != nil {
|
||||
if _, err := d.writer.Exec(stmt, metadata.ID, blockId, data); err != nil {
|
||||
return errors.Wrap(err, "insert carve block")
|
||||
}
|
||||
|
||||
if metadata.MaxBlock < blockId {
|
||||
// Update max_block
|
||||
metadata.MaxBlock = blockId
|
||||
// TODO: looks like this should all be done in an atomic transaction?
|
||||
if err := d.UpdateCarve(metadata); err != nil {
|
||||
return errors.Wrap(err, "insert carve block")
|
||||
}
|
||||
@ -253,7 +254,7 @@ func (d *Datastore) GetBlock(metadata *fleet.CarveMetadata, blockId int64) ([]by
|
||||
WHERE metadata_id = ? AND block_id = ?
|
||||
`
|
||||
var data []byte
|
||||
if err := d.db.Get(&data, stmt, metadata.ID, blockId); err != nil {
|
||||
if err := d.reader.Get(&data, stmt, metadata.ID, blockId); err != nil {
|
||||
return nil, errors.Wrap(err, "select data")
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,9 @@
|
||||
package mysql
|
||||
|
||||
import "github.com/go-kit/kit/log"
|
||||
import (
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/go-kit/kit/log"
|
||||
)
|
||||
|
||||
const defaultMaxAttempts int = 15
|
||||
|
||||
@ -9,8 +12,9 @@ type DBOption func(o *dbOptions) error
|
||||
|
||||
type dbOptions struct {
|
||||
// maxAttempts configures the number of retries to connect to the DB
|
||||
maxAttempts int
|
||||
logger log.Logger
|
||||
maxAttempts int
|
||||
logger log.Logger
|
||||
replicaConfig *config.MysqlConfig
|
||||
}
|
||||
|
||||
// Logger adds a logger to the datastore
|
||||
@ -21,6 +25,14 @@ func Logger(l log.Logger) DBOption {
|
||||
}
|
||||
}
|
||||
|
||||
// Replica sets the configuration of the read replica for the datastore.
|
||||
func Replica(conf *config.MysqlConfig) DBOption {
|
||||
return func(o *dbOptions) error {
|
||||
o.replicaConfig = conf
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// LimitAttempts sets a the number of attempts
|
||||
// to try establishing a connection to the database backend
|
||||
// the default value is 15 attempts
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
// returning a notFound error if appropriate.
|
||||
func (d *Datastore) deleteEntity(dbTable string, id uint) error {
|
||||
deleteStmt := fmt.Sprintf(`DELETE FROM %s WHERE id = ?`, dbTable)
|
||||
result, err := d.db.Exec(deleteStmt, id)
|
||||
result, err := d.writer.Exec(deleteStmt, id)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "delete %s", dbTable)
|
||||
}
|
||||
@ -26,7 +26,7 @@ func (d *Datastore) deleteEntity(dbTable string, id uint) error {
|
||||
// table, returning a notFound error if appropriate.
|
||||
func (d *Datastore) deleteEntityByName(dbTable string, name string) error {
|
||||
deleteStmt := fmt.Sprintf("DELETE FROM %s WHERE name = ?", dbTable)
|
||||
result, err := d.db.Exec(deleteStmt, name)
|
||||
result, err := d.writer.Exec(deleteStmt, name)
|
||||
if err != nil {
|
||||
if isMySQLForeignKey(err) {
|
||||
return foreignKey(dbTable, name)
|
||||
@ -50,7 +50,7 @@ func (d *Datastore) deleteEntities(dbTable string, ids []uint) (uint, error) {
|
||||
return 0, errors.Wrapf(err, "building delete entities query %s", dbTable)
|
||||
}
|
||||
|
||||
result, err := d.db.Exec(query, args...)
|
||||
result, err := d.writer.Exec(query, args...)
|
||||
if err != nil {
|
||||
return 0, errors.Wrapf(err, "executing delete entities query %s", dbTable)
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ func (ds *Datastore) PendingEmailChange(uid uint, newEmail, token string) error
|
||||
new_email
|
||||
) VALUES( ?, ?, ? )
|
||||
`
|
||||
_, err := ds.db.Exec(sqlStatement, uid, token, newEmail)
|
||||
_, err := ds.writer.Exec(sqlStatement, uid, token, newEmail)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "inserting email change record")
|
||||
}
|
||||
@ -32,7 +32,7 @@ func (ds *Datastore) ConfirmPendingEmailChange(id uint, token string) (newEmail
|
||||
Token string
|
||||
NewEmail string `db:"new_email"`
|
||||
}{}
|
||||
err = ds.db.Get(&changeRecord, "SELECT * FROM email_changes WHERE token = ? AND user_id = ?", token, id)
|
||||
err = ds.writer.Get(&changeRecord, "SELECT * FROM email_changes WHERE token = ? AND user_id = ?", token, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", notFound("email change with token")
|
||||
|
@ -33,7 +33,7 @@ func (d *Datastore) NewHost(host *fleet.Host) (*fleet.Host, error) {
|
||||
)
|
||||
VALUES( ?,?,?,?,?,?,?,?,?,?,?,?,? )
|
||||
`
|
||||
result, err := d.db.Exec(
|
||||
result, err := d.writer.Exec(
|
||||
sqlStatement,
|
||||
host.OsqueryHostID,
|
||||
host.DetailUpdatedAt,
|
||||
@ -95,7 +95,7 @@ func (d *Datastore) SaveHost(host *fleet.Host) error {
|
||||
percent_disk_space_available = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := d.db.Exec(sqlStatement,
|
||||
_, err := d.writer.Exec(sqlStatement,
|
||||
host.DetailUpdatedAt,
|
||||
host.LabelUpdatedAt,
|
||||
host.NodeKey,
|
||||
@ -248,7 +248,7 @@ FROM scheduled_query_stats sqs
|
||||
WHERE host_id = ? AND p.pack_type IS NULL
|
||||
`
|
||||
var stats []fleet.ScheduledQueryStats
|
||||
if err := d.db.Select(&stats, sql, host.ID); err != nil {
|
||||
if err := d.reader.Select(&stats, sql, host.ID); err != nil {
|
||||
return errors.Wrap(err, "load pack stats")
|
||||
}
|
||||
|
||||
@ -268,9 +268,9 @@ WHERE host_id = ? AND p.pack_type IS NULL
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Datastore) loadHostUsers(host *fleet.Host) error {
|
||||
func (d *Datastore) loadHostUsers(db dbReader, host *fleet.Host) error {
|
||||
sql := `SELECT id, username, groupname, uid, user_type FROM host_users WHERE host_id = ? and removed_at IS NULL`
|
||||
if err := d.db.Select(&host.Users, sql, host.ID); err != nil {
|
||||
if err := db.Select(&host.Users, sql, host.ID); err != nil {
|
||||
return errors.Wrap(err, "load pack stats")
|
||||
}
|
||||
return nil
|
||||
@ -292,14 +292,14 @@ func (d *Datastore) Host(id uint) (*fleet.Host, error) {
|
||||
LIMIT 1
|
||||
`
|
||||
host := &fleet.Host{}
|
||||
err := d.db.Get(host, sqlStatement, id)
|
||||
err := d.reader.Get(host, sqlStatement, id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get host by id")
|
||||
}
|
||||
if err := d.loadHostPackStats(host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.loadHostUsers(host); err != nil {
|
||||
if err := d.loadHostUsers(d.reader, host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -308,7 +308,9 @@ func (d *Datastore) Host(id uint) (*fleet.Host, error) {
|
||||
|
||||
func (d *Datastore) amountEnrolledHosts() (int, error) {
|
||||
var amount int
|
||||
err := d.db.Get(&amount, `SELECT count(*) FROM hosts`)
|
||||
// uses the writer as it is called from ShouldSendStatistics, which is a
|
||||
// write method.
|
||||
err := d.writer.Get(&amount, `SELECT count(*) FROM hosts`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -363,7 +365,7 @@ func (d *Datastore) ListHosts(filter fleet.TeamFilter, opt fleet.HostListOptions
|
||||
sql = appendListOptionsToSQL(sql, opt.ListOptions)
|
||||
|
||||
hosts := []*fleet.Host{}
|
||||
if err := d.db.Select(&hosts, sql, params...); err != nil {
|
||||
if err := d.reader.Select(&hosts, sql, params...); err != nil {
|
||||
return nil, errors.Wrap(err, "list hosts")
|
||||
}
|
||||
|
||||
@ -413,7 +415,7 @@ func (d *Datastore) CleanupIncomingHosts(now time.Time) error {
|
||||
WHERE hostname = '' AND osquery_version = ''
|
||||
AND created_at < (? - INTERVAL 5 MINUTE)
|
||||
`
|
||||
if _, err := d.db.Exec(sqlStatement, now); err != nil {
|
||||
if _, err := d.writer.Exec(sqlStatement, now); err != nil {
|
||||
return errors.Wrap(err, "cleanup incoming hosts")
|
||||
}
|
||||
|
||||
@ -442,7 +444,7 @@ func (d *Datastore) GenerateHostStatusStatistics(filter fleet.TeamFilter, now ti
|
||||
Online uint `db:"online"`
|
||||
New uint `db:"new"`
|
||||
}{}
|
||||
err := d.db.Get(&counts, sqlStatement, now, now, now, now, now)
|
||||
err := d.reader.Get(&counts, sqlStatement, now, now, now, now, now)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
e = errors.Wrap(err, "generating host statistics")
|
||||
return
|
||||
@ -582,7 +584,7 @@ func (d *Datastore) AuthenticateHost(nodeKey string) (*fleet.Host, error) {
|
||||
`
|
||||
|
||||
host := &fleet.Host{}
|
||||
if err := d.db.Get(host, sqlStatement, nodeKey); err != nil {
|
||||
if err := d.reader.Get(host, sqlStatement, nodeKey); err != nil {
|
||||
switch err {
|
||||
case sql.ErrNoRows:
|
||||
return nil, notFound("Host")
|
||||
@ -601,7 +603,7 @@ func (d *Datastore) MarkHostSeen(host *fleet.Host, t time.Time) error {
|
||||
WHERE node_key=?
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(sqlStatement, t, host.NodeKey)
|
||||
_, err := d.writer.Exec(sqlStatement, t, host.NodeKey)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "marking host seen")
|
||||
}
|
||||
@ -659,11 +661,11 @@ func (d *Datastore) searchHostsWithOmits(filter fleet.TeamFilter, query string,
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching hosts")
|
||||
}
|
||||
sql = d.db.Rebind(sql)
|
||||
sql = d.reader.Rebind(sql)
|
||||
|
||||
hosts := []*fleet.Host{}
|
||||
|
||||
err = d.db.Select(&hosts, sql, args...)
|
||||
err = d.reader.Select(&hosts, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching hosts rebound")
|
||||
}
|
||||
@ -695,8 +697,8 @@ func (d *Datastore) searchHostsDefault(filter fleet.TeamFilter, omit ...uint) ([
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching default hosts")
|
||||
}
|
||||
sql = d.db.Rebind(sql)
|
||||
err = d.db.Select(&hosts, sql, args...)
|
||||
sql = d.reader.Rebind(sql)
|
||||
err = d.reader.Select(&hosts, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching default hosts rebound")
|
||||
}
|
||||
@ -730,7 +732,7 @@ func (d *Datastore) SearchHosts(filter fleet.TeamFilter, query string, omit ...u
|
||||
)
|
||||
|
||||
hosts := []*fleet.Host{}
|
||||
if err := d.db.Select(&hosts, sql, hostQuery, ipQuery); err != nil {
|
||||
if err := d.reader.Select(&hosts, sql, hostQuery, ipQuery); err != nil {
|
||||
return nil, errors.Wrap(err, "searching hosts")
|
||||
}
|
||||
|
||||
@ -755,7 +757,7 @@ func (d *Datastore) HostIDsByName(filter fleet.TeamFilter, hostnames []string) (
|
||||
}
|
||||
|
||||
var hostIDs []uint
|
||||
if err := d.db.Select(&hostIDs, sql, args...); err != nil {
|
||||
if err := d.reader.Select(&hostIDs, sql, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "get host IDs")
|
||||
}
|
||||
|
||||
@ -770,7 +772,7 @@ func (d *Datastore) HostByIdentifier(identifier string) (*fleet.Host, error) {
|
||||
LIMIT 1
|
||||
`
|
||||
host := &fleet.Host{}
|
||||
err := d.db.Get(host, sql, identifier)
|
||||
err := d.reader.Get(host, sql, identifier)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get host by identifier")
|
||||
}
|
||||
@ -796,7 +798,7 @@ func (d *Datastore) AddHostsToTeam(teamID *uint, hostIDs []uint) error {
|
||||
return errors.Wrap(err, "sqlx.In AddHostsToTeam")
|
||||
}
|
||||
|
||||
if _, err := d.db.Exec(sql, args...); err != nil {
|
||||
if _, err := d.writer.Exec(sql, args...); err != nil {
|
||||
return errors.Wrap(err, "exec AddHostsToTeam")
|
||||
}
|
||||
|
||||
@ -809,7 +811,7 @@ func (d *Datastore) SaveHostAdditional(host *fleet.Host) error {
|
||||
VALUES (?, ?)
|
||||
ON DUPLICATE KEY UPDATE additional = VALUES(additional)
|
||||
`
|
||||
if _, err := d.db.Exec(sql, host.ID, host.Additional); err != nil {
|
||||
if _, err := d.writer.Exec(sql, host.ID, host.Additional); err != nil {
|
||||
return errors.Wrap(err, "insert additional")
|
||||
}
|
||||
|
||||
@ -818,7 +820,7 @@ func (d *Datastore) SaveHostAdditional(host *fleet.Host) error {
|
||||
|
||||
func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||
if len(host.Users) == 0 {
|
||||
if _, err := d.db.Exec(
|
||||
if _, err := d.writer.Exec(
|
||||
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE host_id = ?`,
|
||||
host.ID,
|
||||
); err != nil {
|
||||
@ -829,7 +831,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||
}
|
||||
|
||||
currentHost := &fleet.Host{ID: host.ID}
|
||||
if err := d.loadHostUsers(currentHost); err != nil {
|
||||
if err := d.loadHostUsers(d.writer, currentHost); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -837,7 +839,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||
for _, u := range host.Users {
|
||||
incomingUsers[u.Uid] = true
|
||||
|
||||
if _, err := d.db.Exec(
|
||||
if _, err := d.writer.Exec(
|
||||
`INSERT IGNORE INTO host_users (host_id, uid, username, user_type, groupname) VALUES (?, ?, ?, ?, ?)`,
|
||||
host.ID, u.Uid, u.Username, u.Type, u.GroupName,
|
||||
); err != nil {
|
||||
@ -860,7 +862,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE id IN (%s)`,
|
||||
removedValues,
|
||||
)
|
||||
if _, err := d.db.Exec(removedSql, removedArgs...); err != nil {
|
||||
if _, err := d.writer.Exec(removedSql, removedArgs...); err != nil {
|
||||
return errors.Wrap(err, "mark users as removed")
|
||||
}
|
||||
|
||||
@ -869,15 +871,15 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||
|
||||
func (d *Datastore) TotalAndUnseenHostsSince(daysCount int) (int, int, error) {
|
||||
var totalCount, unseenCount int
|
||||
err := d.db.QueryRow("SELECT count(*) FROM hosts").Scan(&totalCount)
|
||||
err := d.reader.Get(&totalCount, "SELECT count(*) FROM hosts")
|
||||
if err != nil {
|
||||
return 0, 0, errors.Wrap(err, "getting total host count")
|
||||
}
|
||||
|
||||
err = d.db.QueryRow(
|
||||
err = d.reader.Get(&unseenCount,
|
||||
"SELECT count(*) FROM hosts WHERE DATEDIFF(CURRENT_DATE, seen_time) >= ?",
|
||||
daysCount,
|
||||
).Scan(&unseenCount)
|
||||
)
|
||||
if err != nil {
|
||||
return 0, 0, errors.Wrap(err, "getting unseen host count")
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ func (d *Datastore) NewInvite(i *fleet.Invite) (*fleet.Invite, error) {
|
||||
VALUES ( ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := d.db.Exec(sqlStmt, i.InvitedBy, i.Email,
|
||||
result, err := d.writer.Exec(sqlStmt, i.InvitedBy, i.Email,
|
||||
i.Name, i.Position, i.Token, i.SSOEnabled, i.GlobalRole)
|
||||
if err != nil && isDuplicate(err) {
|
||||
return nil, alreadyExists("Invite", i.Email)
|
||||
@ -45,10 +45,11 @@ func (d *Datastore) NewInvite(i *fleet.Invite) (*fleet.Invite, error) {
|
||||
for _, userTeam := range i.Teams {
|
||||
args = append(args, i.ID, userTeam.Team.ID, userTeam.Role)
|
||||
}
|
||||
// TODO: seems like this should be in a transaction?
|
||||
sql := "INSERT INTO invite_teams (invite_id, team_id, role) VALUES " +
|
||||
strings.Repeat(valueStr, len(i.Teams))
|
||||
sql = strings.TrimSuffix(sql, ",")
|
||||
if _, err := d.db.Exec(sql, args...); err != nil {
|
||||
if _, err := d.writer.Exec(sql, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "insert teams")
|
||||
}
|
||||
|
||||
@ -63,7 +64,7 @@ func (d *Datastore) ListInvites(opt fleet.ListOptions) ([]*fleet.Invite, error)
|
||||
query, params := searchLike(query, nil, opt.MatchQuery, inviteSearchColumns...)
|
||||
query = appendListOptionsToSQL(query, opt)
|
||||
|
||||
err := d.db.Select(&invites, query, params...)
|
||||
err := d.reader.Select(&invites, query, params...)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Invite")
|
||||
} else if err != nil {
|
||||
@ -80,7 +81,7 @@ func (d *Datastore) ListInvites(opt fleet.ListOptions) ([]*fleet.Invite, error)
|
||||
// Invite returns Invite identified by id.
|
||||
func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
|
||||
var invite fleet.Invite
|
||||
err := d.db.Get(&invite, "SELECT * FROM invites WHERE id = ?", id)
|
||||
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE id = ?", id)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Invite").WithID(id)
|
||||
} else if err != nil {
|
||||
@ -97,7 +98,7 @@ func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
|
||||
// InviteByEmail finds an Invite with a particular email, if one exists.
|
||||
func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
|
||||
var invite fleet.Invite
|
||||
err := d.db.Get(&invite, "SELECT * FROM invites WHERE email = ?", email)
|
||||
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE email = ?", email)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Invite").
|
||||
WithMessage(fmt.Sprintf("with email %s", email))
|
||||
@ -115,7 +116,7 @@ func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
|
||||
// InviteByToken finds an Invite with a particular token, if one exists.
|
||||
func (d *Datastore) InviteByToken(token string) (*fleet.Invite, error) {
|
||||
var invite fleet.Invite
|
||||
err := d.db.Get(&invite, "SELECT * FROM invites WHERE token = ?", token)
|
||||
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE token = ?", token)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Invite").
|
||||
WithMessage(fmt.Sprintf("with token %s", token))
|
||||
@ -163,7 +164,7 @@ func (d *Datastore) loadTeamsForInvites(invites []*fleet.Invite) error {
|
||||
fleet.UserTeam
|
||||
InviteID uint `db:"invite_id"`
|
||||
}
|
||||
if err := d.db.Select(&rows, sql, args...); err != nil {
|
||||
if err := d.reader.Select(&rows, sql, args...); err != nil {
|
||||
return errors.Wrap(err, "get loadTeamsForInvites")
|
||||
}
|
||||
|
||||
|
@ -115,7 +115,7 @@ func (d *Datastore) GetLabelSpecs() ([]*fleet.LabelSpec, error) {
|
||||
var specs []*fleet.LabelSpec
|
||||
// Get basic specs
|
||||
query := "SELECT id, name, description, query, platform, label_type, label_membership_type FROM labels"
|
||||
if err := d.db.Select(&specs, query); err != nil {
|
||||
if err := d.reader.Select(&specs, query); err != nil {
|
||||
return nil, errors.Wrap(err, "get labels")
|
||||
}
|
||||
|
||||
@ -138,7 +138,7 @@ SELECT name, description, query, platform, label_type, label_membership_type
|
||||
FROM labels
|
||||
WHERE name = ?
|
||||
`
|
||||
if err := d.db.Select(&specs, query, name); err != nil {
|
||||
if err := d.reader.Select(&specs, query, name); err != nil {
|
||||
return nil, errors.Wrap(err, "get label")
|
||||
}
|
||||
if len(specs) == 0 {
|
||||
@ -171,7 +171,7 @@ func (d *Datastore) getLabelHostnames(label *fleet.LabelSpec) error {
|
||||
WHERE label_id = (SELECT id FROM labels WHERE name = ?)
|
||||
)
|
||||
`
|
||||
err := d.db.Select(&label.Hosts, sql, label.Name)
|
||||
err := d.reader.Select(&label.Hosts, sql, label.Name)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get hostnames for label")
|
||||
}
|
||||
@ -190,7 +190,7 @@ func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fl
|
||||
label_membership_type
|
||||
) VALUES ( ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
result, err := d.db.Exec(
|
||||
result, err := d.writer.Exec(
|
||||
query,
|
||||
label.Name,
|
||||
label.Description,
|
||||
@ -216,7 +216,7 @@ func (d *Datastore) SaveLabel(label *fleet.Label) (*fleet.Label, error) {
|
||||
description = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := d.db.Exec(query, label.Name, label.Description, label.ID)
|
||||
_, err := d.writer.Exec(query, label.Name, label.Description, label.ID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "saving label")
|
||||
}
|
||||
@ -236,7 +236,7 @@ func (d *Datastore) Label(lid uint) (*fleet.Label, error) {
|
||||
`
|
||||
label := &fleet.Label{}
|
||||
|
||||
if err := d.db.Get(label, sql, lid); err != nil {
|
||||
if err := d.reader.Get(label, sql, lid); err != nil {
|
||||
return nil, errors.Wrap(err, "selecting label")
|
||||
}
|
||||
|
||||
@ -255,7 +255,7 @@ func (d *Datastore) ListLabels(filter fleet.TeamFilter, opt fleet.ListOptions) (
|
||||
query = appendListOptionsToSQL(query, opt)
|
||||
labels := []*fleet.Label{}
|
||||
|
||||
if err := d.db.Select(&labels, query); err != nil {
|
||||
if err := d.reader.Select(&labels, query); err != nil {
|
||||
// it's ok if no labels exist
|
||||
if err == sql.ErrNoRows {
|
||||
return labels, nil
|
||||
@ -288,7 +288,7 @@ func (d *Datastore) LabelQueriesForHost(host *fleet.Host, cutoff time.Time) (map
|
||||
WHERE platform = ? OR platform = ''
|
||||
AND label_membership_type = ?
|
||||
`
|
||||
rows, err = d.db.Query(sql, platform, fleet.LabelMembershipTypeDynamic)
|
||||
rows, err = d.reader.Query(sql, platform, fleet.LabelMembershipTypeDynamic)
|
||||
} else {
|
||||
// Retrieve all labels (with matching platform) iff there is a label
|
||||
// that has been created since this host last reported label query
|
||||
@ -300,7 +300,7 @@ func (d *Datastore) LabelQueriesForHost(host *fleet.Host, cutoff time.Time) (map
|
||||
AND (platform = ? OR platform = '')
|
||||
AND label_membership_type = ?
|
||||
`
|
||||
rows, err = d.db.Query(
|
||||
rows, err = d.reader.Query(
|
||||
sql,
|
||||
platform,
|
||||
host.LabelUpdatedAt,
|
||||
@ -369,7 +369,7 @@ func (d *Datastore) RecordLabelQueryExecutions(host *fleet.Host, results map[uin
|
||||
updated_at = VALUES(updated_at)
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(sql, vals...)
|
||||
_, err := d.writer.Exec(sql, vals...)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "insert label query executions (%v)", vals)
|
||||
}
|
||||
@ -384,8 +384,8 @@ func (d *Datastore) RecordLabelQueryExecutions(host *fleet.Host, results map[uin
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "IN for DELETE FROM label_membership")
|
||||
}
|
||||
query = d.db.Rebind(query)
|
||||
_, err = d.db.Exec(query, args...)
|
||||
query = d.writer.Rebind(query)
|
||||
_, err = d.writer.Exec(query, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "delete label query executions")
|
||||
}
|
||||
@ -403,7 +403,7 @@ func (d *Datastore) ListLabelsForHost(hid uint) ([]*fleet.Label, error) {
|
||||
`
|
||||
|
||||
labels := []*fleet.Label{}
|
||||
err := d.db.Select(&labels, sqlStatement, hid)
|
||||
err := d.reader.Select(&labels, sqlStatement, hid)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting host labels")
|
||||
}
|
||||
@ -432,7 +432,7 @@ func (d *Datastore) ListHostsInLabel(filter fleet.TeamFilter, lid uint, opt flee
|
||||
|
||||
sql = appendListOptionsToSQL(sql, opt.ListOptions)
|
||||
hosts := []*fleet.Host{}
|
||||
err := d.db.Select(&hosts, sql, params...)
|
||||
err := d.reader.Select(&hosts, sql, params...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting label query executions")
|
||||
}
|
||||
@ -458,9 +458,9 @@ func (d *Datastore) ListUniqueHostsInLabels(filter fleet.TeamFilter, labels []ui
|
||||
return nil, errors.Wrap(err, "building query listing unique hosts in labels")
|
||||
}
|
||||
|
||||
query = d.db.Rebind(query)
|
||||
query = d.reader.Rebind(query)
|
||||
hosts := []*fleet.Host{}
|
||||
err = d.db.Select(&hosts, query, args...)
|
||||
err = d.reader.Select(&hosts, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "listing unique hosts in labels")
|
||||
}
|
||||
@ -492,10 +492,10 @@ func (d *Datastore) searchLabelsWithOmits(filter fleet.TeamFilter, query string,
|
||||
return nil, errors.Wrap(err, "building query for labels with omits")
|
||||
}
|
||||
|
||||
sql = d.db.Rebind(sql)
|
||||
sql = d.reader.Rebind(sql)
|
||||
|
||||
matches := []*fleet.Label{}
|
||||
err = d.db.Select(&matches, sql, args...)
|
||||
err = d.reader.Select(&matches, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting labels with omits")
|
||||
}
|
||||
@ -527,7 +527,7 @@ func (d *Datastore) addAllHostsLabelToList(filter fleet.TeamFilter, labels []*fl
|
||||
)
|
||||
|
||||
var allHosts fleet.Label
|
||||
if err := d.db.Get(&allHosts, sql, fleet.LabelTypeBuiltIn); err != nil {
|
||||
if err := d.reader.Get(&allHosts, sql, fleet.LabelTypeBuiltIn); err != nil {
|
||||
return nil, errors.Wrap(err, "get all hosts label")
|
||||
}
|
||||
|
||||
@ -575,8 +575,8 @@ func (d *Datastore) searchLabelsDefault(filter fleet.TeamFilter, omit ...uint) (
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "searching default labels")
|
||||
}
|
||||
sql = d.db.Rebind(sql)
|
||||
if err := d.db.Select(&labels, sql, args...); err != nil {
|
||||
sql = d.reader.Rebind(sql)
|
||||
if err := d.reader.Select(&labels, sql, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "searching default labels rebound")
|
||||
}
|
||||
|
||||
@ -617,7 +617,7 @@ func (d *Datastore) SearchLabels(filter fleet.TeamFilter, query string, omit ...
|
||||
)
|
||||
|
||||
matches := []*fleet.Label{}
|
||||
if err := d.db.Select(&matches, sql, transformedQuery); err != nil {
|
||||
if err := d.reader.Select(&matches, sql, transformedQuery); err != nil {
|
||||
return nil, errors.Wrap(err, "selecting labels for search")
|
||||
}
|
||||
|
||||
@ -645,7 +645,7 @@ func (d *Datastore) LabelIDsByName(labels []string) ([]uint, error) {
|
||||
}
|
||||
|
||||
var labelIDs []uint
|
||||
if err := d.db.Select(&labelIDs, sql, args...); err != nil {
|
||||
if err := d.reader.Select(&labelIDs, sql, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "get label IDs")
|
||||
}
|
||||
|
||||
|
@ -29,27 +29,27 @@ func (d *Datastore) Lock(name string, owner string, expiration time.Duration) (b
|
||||
}
|
||||
|
||||
func (d *Datastore) createLock(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
||||
return d.db.Exec(
|
||||
return d.writer.Exec(
|
||||
`INSERT IGNORE INTO locks (name, owner, expires_at) VALUES (?, ?, ?)`,
|
||||
name, owner, time.Now().Add(expiration),
|
||||
)
|
||||
}
|
||||
|
||||
func (d *Datastore) extendLockIfAlreadyAcquired(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
||||
return d.db.Exec(
|
||||
return d.writer.Exec(
|
||||
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE name = ? and owner = ?`,
|
||||
name, owner, time.Now().Add(expiration), name, owner,
|
||||
)
|
||||
}
|
||||
|
||||
func (d *Datastore) overwriteLockIfExpired(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
||||
return d.db.Exec(
|
||||
return d.writer.Exec(
|
||||
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE expires_at < CURRENT_TIMESTAMP and name = ?`,
|
||||
name, owner, time.Now().Add(expiration), name,
|
||||
)
|
||||
}
|
||||
|
||||
func (d *Datastore) Unlock(name string, owner string) error {
|
||||
_, err := d.db.Exec(`DELETE FROM locks WHERE name = ? and owner = ?`, name, owner)
|
||||
_, err := d.writer.Exec(`DELETE FROM locks WHERE name = ? and owner = ?`, name, owner)
|
||||
return err
|
||||
}
|
||||
|
@ -85,14 +85,14 @@ func Test20210819131107_AddCascadeToHostSoftware(t *testing.T) {
|
||||
defer ds.Close()
|
||||
|
||||
for {
|
||||
version, err := tables.MigrationClient.GetDBVersion(ds.db.DB)
|
||||
version, err := tables.MigrationClient.GetDBVersion(ds.writer.DB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// break right before the the constraint migration
|
||||
if version == 20210818182258 {
|
||||
break
|
||||
}
|
||||
require.NoError(t, tables.MigrationClient.UpByOne(ds.db.DB, ""))
|
||||
require.NoError(t, tables.MigrationClient.UpByOne(ds.writer.DB, ""))
|
||||
}
|
||||
|
||||
host1 := test.NewHost(t, ds, "host1", "", "host1key", "host1uuid", time.Now())
|
||||
@ -122,7 +122,7 @@ func Test20210819131107_AddCascadeToHostSoftware(t *testing.T) {
|
||||
|
||||
require.NoError(t, ds.DeleteHost(host1.ID))
|
||||
|
||||
require.NoError(t, tables.MigrationClient.UpByOne(ds.db.DB, ""))
|
||||
require.NoError(t, tables.MigrationClient.UpByOne(ds.writer.DB, ""))
|
||||
|
||||
// Make sure we don't delete more than we need
|
||||
hostCheck, err := ds.Host(host2.ID)
|
||||
|
@ -37,13 +37,28 @@ var (
|
||||
columnCharsRegexp = regexp.MustCompile(`[^\w-]`)
|
||||
)
|
||||
|
||||
// dbReader is an interface that defines the methods required for reads.
|
||||
type dbReader interface {
|
||||
sqlx.Queryer
|
||||
|
||||
Close() error
|
||||
Rebind(string) string
|
||||
Select(interface{}, string, ...interface{}) error
|
||||
Get(interface{}, string, ...interface{}) error
|
||||
}
|
||||
|
||||
// Datastore is an implementation of fleet.Datastore interface backed by
|
||||
// MySQL
|
||||
type Datastore struct {
|
||||
db *sqlx.DB
|
||||
reader dbReader // so it cannot be used to perform writes
|
||||
writer *sqlx.DB
|
||||
|
||||
logger log.Logger
|
||||
clock clock.Clock
|
||||
config config.MysqlConfig
|
||||
|
||||
// nil if no read replica
|
||||
readReplicaConfig *config.MysqlConfig
|
||||
}
|
||||
|
||||
type txFn func(*sqlx.Tx) error
|
||||
@ -67,7 +82,7 @@ func retryableError(err error) bool {
|
||||
// withRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff
|
||||
func (d *Datastore) withRetryTxx(fn txFn) (err error) {
|
||||
operation := func() error {
|
||||
tx, err := d.db.Beginx()
|
||||
tx, err := d.writer.Beginx()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create transaction")
|
||||
}
|
||||
@ -116,7 +131,7 @@ func (d *Datastore) withRetryTxx(fn txFn) (err error) {
|
||||
|
||||
// withTx provides a common way to commit/rollback a txFn
|
||||
func (d *Datastore) withTx(fn txFn) (err error) {
|
||||
tx, err := d.db.Beginx()
|
||||
tx, err := d.writer.Beginx()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create transaction")
|
||||
}
|
||||
@ -153,51 +168,64 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
|
||||
}
|
||||
|
||||
for _, setOpt := range opts {
|
||||
setOpt(options)
|
||||
if setOpt != nil {
|
||||
setOpt(options)
|
||||
}
|
||||
}
|
||||
|
||||
if config.PasswordPath != "" && config.Password != "" {
|
||||
return nil, errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
|
||||
if err := checkConfig(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if options.replicaConfig != nil {
|
||||
if err := checkConfig(options.replicaConfig); err != nil {
|
||||
return nil, errors.Wrap(err, "replica")
|
||||
}
|
||||
}
|
||||
|
||||
// Check to see if the flag is populated
|
||||
// Check if file exists on disk
|
||||
// If file exists read contents
|
||||
if config.PasswordPath != "" {
|
||||
fileContents, err := ioutil.ReadFile(config.PasswordPath)
|
||||
dbWriter, err := newDB(&config, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbReader := dbWriter
|
||||
if options.replicaConfig != nil {
|
||||
dbReader, err = newDB(options.replicaConfig, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.Password = strings.TrimSpace(string(fileContents))
|
||||
}
|
||||
|
||||
if config.TLSCA != "" {
|
||||
config.TLSConfig = "custom"
|
||||
err := registerTLS(config)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "register TLS config for mysql")
|
||||
}
|
||||
ds := &Datastore{
|
||||
writer: dbWriter,
|
||||
reader: dbReader,
|
||||
logger: options.logger,
|
||||
clock: c,
|
||||
config: config,
|
||||
readReplicaConfig: options.replicaConfig,
|
||||
}
|
||||
|
||||
dsn := generateMysqlConnectionString(config)
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
func newDB(conf *config.MysqlConfig, opts *dbOptions) (*sqlx.DB, error) {
|
||||
dsn := generateMysqlConnectionString(*conf)
|
||||
db, err := sqlx.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.SetMaxIdleConns(config.MaxIdleConns)
|
||||
db.SetMaxOpenConns(config.MaxOpenConns)
|
||||
db.SetConnMaxLifetime(time.Second * time.Duration(config.ConnMaxLifetime))
|
||||
db.SetMaxIdleConns(conf.MaxIdleConns)
|
||||
db.SetMaxOpenConns(conf.MaxOpenConns)
|
||||
db.SetConnMaxLifetime(time.Second * time.Duration(conf.ConnMaxLifetime))
|
||||
|
||||
var dbError error
|
||||
for attempt := 0; attempt < options.maxAttempts; attempt++ {
|
||||
for attempt := 0; attempt < opts.maxAttempts; attempt++ {
|
||||
dbError = db.Ping()
|
||||
if dbError == nil {
|
||||
// we're connected!
|
||||
break
|
||||
}
|
||||
interval := time.Duration(attempt) * time.Second
|
||||
options.logger.Log("mysql", fmt.Sprintf(
|
||||
opts.logger.Log("mysql", fmt.Sprintf(
|
||||
"could not connect to db: %v, sleeping %v", dbError, interval))
|
||||
time.Sleep(interval)
|
||||
}
|
||||
@ -205,20 +233,37 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
|
||||
if dbError != nil {
|
||||
return nil, dbError
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
ds := &Datastore{
|
||||
db: db,
|
||||
logger: options.logger,
|
||||
clock: c,
|
||||
config: config,
|
||||
func checkConfig(conf *config.MysqlConfig) error {
|
||||
if conf.PasswordPath != "" && conf.Password != "" {
|
||||
return errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
|
||||
}
|
||||
|
||||
return ds, nil
|
||||
// Check to see if the flag is populated
|
||||
// Check if file exists on disk
|
||||
// If file exists read contents
|
||||
if conf.PasswordPath != "" {
|
||||
fileContents, err := ioutil.ReadFile(conf.PasswordPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conf.Password = strings.TrimSpace(string(fileContents))
|
||||
}
|
||||
|
||||
if conf.TLSCA != "" {
|
||||
conf.TLSConfig = "custom"
|
||||
err := registerTLS(*conf)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "register TLS config for mysql")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Datastore) Begin() (fleet.Transaction, error) {
|
||||
return d.db.Beginx()
|
||||
return d.writer.Beginx()
|
||||
}
|
||||
|
||||
func (d *Datastore) Name() string {
|
||||
@ -226,11 +271,11 @@ func (d *Datastore) Name() string {
|
||||
}
|
||||
|
||||
func (d *Datastore) MigrateTables() error {
|
||||
return tables.MigrationClient.Up(d.db.DB, "")
|
||||
return tables.MigrationClient.Up(d.writer.DB, "")
|
||||
}
|
||||
|
||||
func (d *Datastore) MigrateData() error {
|
||||
return data.MigrationClient.Up(d.db.DB, "")
|
||||
return data.MigrationClient.Up(d.writer.DB, "")
|
||||
}
|
||||
|
||||
func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
|
||||
@ -243,7 +288,7 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
|
||||
return 0, errors.Wrap(err, "missing tables migrations")
|
||||
}
|
||||
|
||||
currentTablesVersion, err := tables.MigrationClient.GetDBVersion(d.db.DB)
|
||||
currentTablesVersion, err := tables.MigrationClient.GetDBVersion(d.writer.DB)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "cannot get table migration status")
|
||||
}
|
||||
@ -253,7 +298,7 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
|
||||
return 0, errors.Wrap(err, "missing data migrations")
|
||||
}
|
||||
|
||||
currentDataVersion, err := data.MigrationClient.GetDBVersion(d.db.DB)
|
||||
currentDataVersion, err := data.MigrationClient.GetDBVersion(d.writer.DB)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "cannot get data migration status")
|
||||
}
|
||||
@ -283,11 +328,11 @@ func (d *Datastore) Drop() error {
|
||||
WHERE TABLE_SCHEMA = ?;
|
||||
`
|
||||
|
||||
if err := d.db.Select(&tables, sql, d.config.Database); err != nil {
|
||||
if err := d.writer.Select(&tables, sql, d.config.Database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := d.db.Begin()
|
||||
tx, err := d.writer.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -312,13 +357,28 @@ func (d *Datastore) Drop() error {
|
||||
|
||||
// HealthCheck returns an error if the MySQL backend is not healthy.
|
||||
func (d *Datastore) HealthCheck() error {
|
||||
_, err := d.db.Exec("select 1")
|
||||
return err
|
||||
if _, err := d.writer.Exec("select 1"); err != nil {
|
||||
return err
|
||||
}
|
||||
if d.readReplicaConfig != nil {
|
||||
var dst int
|
||||
if err := d.reader.Get(&dst, "select 1"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close frees resources associated with underlying mysql connection
|
||||
func (d *Datastore) Close() error {
|
||||
return d.db.Close()
|
||||
err := d.writer.Close()
|
||||
if d.readReplicaConfig != nil {
|
||||
errRead := d.reader.Close()
|
||||
if err == nil {
|
||||
err = errRead
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sanitizeColumn(col string) string {
|
||||
|
@ -1,7 +1,9 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/VividCortex/mysqlerr"
|
||||
@ -15,6 +17,51 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDatastoreReplica(t *testing.T) {
|
||||
// a bit unfortunate to create temp databases just for this - could be mixed
|
||||
// with other tests when/if we move to subtests to minimize the number of
|
||||
// databases created for tests (see #1805).
|
||||
|
||||
t.Run("noreplica", func(t *testing.T) {
|
||||
ds := CreateMySQLDSWithOptions(t, nil)
|
||||
defer ds.Close()
|
||||
require.Equal(t, ds.reader, ds.writer)
|
||||
})
|
||||
|
||||
t.Run("replica", func(t *testing.T) {
|
||||
opts := &DatastoreTestOptions{Replica: true}
|
||||
ds := CreateMySQLDSWithOptions(t, opts)
|
||||
defer ds.Close()
|
||||
require.NotEqual(t, ds.reader, ds.writer)
|
||||
|
||||
// create a new host
|
||||
host, err := ds.NewHost(&fleet.Host{
|
||||
DetailUpdatedAt: time.Now(),
|
||||
LabelUpdatedAt: time.Now(),
|
||||
SeenTime: time.Now(),
|
||||
NodeKey: "1",
|
||||
UUID: "1",
|
||||
Hostname: "foo.local",
|
||||
PrimaryIP: "192.168.1.1",
|
||||
PrimaryMac: "30-65-EC-6F-C4-58",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, host)
|
||||
|
||||
// trying to read it fails, not replicated yet
|
||||
_, err = ds.Host(host.ID)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, sql.ErrNoRows))
|
||||
|
||||
opts.RunReplication()
|
||||
|
||||
// now it can read it
|
||||
host2, err := ds.Host(host.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, host.ID, host2.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizeColumn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -121,8 +168,10 @@ func TestSearchLike(t *testing.T) {
|
||||
func mockDatastore(t *testing.T) (sqlmock.Sqlmock, *Datastore) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
dbmock := sqlx.NewDb(db, "sqlmock")
|
||||
ds := &Datastore{
|
||||
db: sqlx.NewDb(db, "sqlmock"),
|
||||
writer: dbmock,
|
||||
reader: dbmock,
|
||||
logger: log.NewNopLogger(),
|
||||
}
|
||||
|
||||
|
@ -205,7 +205,7 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
|
||||
WHERE name = ?
|
||||
`
|
||||
var pack fleet.Pack
|
||||
err := d.db.Get(&pack, sqlStatement, name)
|
||||
err := d.reader.Get(&pack, sqlStatement, name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, false, nil
|
||||
@ -213,7 +213,7 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
|
||||
return nil, false, errors.Wrap(err, "fetch pack by name")
|
||||
}
|
||||
|
||||
if err := d.loadPackTargets(&pack); err != nil {
|
||||
if err := d.loadPackTargets(d.reader, &pack); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
@ -313,10 +313,10 @@ func (d *Datastore) replacePackTargets(tx *sqlx.Tx, pack *fleet.Pack) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Datastore) loadPackTargets(pack *fleet.Pack) error {
|
||||
func (d *Datastore) loadPackTargets(db dbReader, pack *fleet.Pack) error {
|
||||
var targets []fleet.PackTarget
|
||||
sql := `SELECT * FROM pack_targets WHERE pack_id = ?`
|
||||
if err := d.db.Select(&targets, sql, pack.ID); err != nil {
|
||||
if err := db.Select(&targets, sql, pack.ID); err != nil {
|
||||
return errors.Wrap(err, "select pack targets")
|
||||
}
|
||||
|
||||
@ -371,14 +371,14 @@ func (d *Datastore) DeletePack(name string) error {
|
||||
func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
|
||||
query := `SELECT * FROM packs WHERE id = ?`
|
||||
pack := &fleet.Pack{}
|
||||
err := d.db.Get(pack, query, pid)
|
||||
err := d.reader.Get(pack, query, pid)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Pack").WithID(pid)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrap(err, "get pack")
|
||||
}
|
||||
|
||||
if err := d.loadPackTargets(pack); err != nil {
|
||||
if err := d.loadPackTargets(d.reader, pack); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -388,14 +388,15 @@ func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
|
||||
// EnsureGlobalPack gets or inserts a pack with type global
|
||||
func (d *Datastore) EnsureGlobalPack() (*fleet.Pack, error) {
|
||||
pack := &fleet.Pack{}
|
||||
err := d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = 'global'`)
|
||||
// read from primary as we will create the pack if it doesn't exist
|
||||
err := d.writer.Get(pack, `SELECT * FROM packs WHERE pack_type = 'global'`)
|
||||
if err == sql.ErrNoRows {
|
||||
return d.insertNewGlobalPack()
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrap(err, "get pack")
|
||||
}
|
||||
|
||||
if err := d.loadPackTargets(pack); err != nil {
|
||||
if err := d.loadPackTargets(d.writer, pack); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -439,14 +440,15 @@ func (d *Datastore) EnsureTeamPack(teamID uint) (*fleet.Pack, error) {
|
||||
}
|
||||
|
||||
teamType := fmt.Sprintf("team-%d", teamID)
|
||||
err = d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType)
|
||||
// read from primary as we will create the team pack if it doesn't exist
|
||||
err = d.writer.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType)
|
||||
if err == sql.ErrNoRows {
|
||||
return d.insertNewTeamPack(t)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrap(err, "get pack")
|
||||
}
|
||||
|
||||
if err := d.loadPackTargets(pack); err != nil {
|
||||
if err := d.loadPackTargets(d.writer, pack); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -465,7 +467,7 @@ func (d *Datastore) insertNewTeamPack(team *fleet.Team) (*fleet.Pack, error) {
|
||||
var packID uint
|
||||
err := d.withTx(func(tx *sqlx.Tx) error {
|
||||
res, err := tx.Exec(
|
||||
`INSERT INTO packs (name, description, platform, pack_type)
|
||||
`INSERT INTO packs (name, description, platform, pack_type)
|
||||
VALUES (?, 'Schedule additional queries for all hosts assigned to this team.', '',?)`,
|
||||
teamScheduleName(team), teamSchedulePackType(team),
|
||||
)
|
||||
@ -499,13 +501,13 @@ func (d *Datastore) ListPacks(opt fleet.PackListOptions) ([]*fleet.Pack, error)
|
||||
query = `SELECT * FROM packs`
|
||||
}
|
||||
var packs []*fleet.Pack
|
||||
err := d.db.Select(&packs, appendListOptionsToSQL(query, opt.ListOptions))
|
||||
err := d.reader.Select(&packs, appendListOptionsToSQL(query, opt.ListOptions))
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "listing packs")
|
||||
}
|
||||
|
||||
for _, pack := range packs {
|
||||
if err := d.loadPackTargets(pack); err != nil {
|
||||
if err := d.loadPackTargets(d.reader, pack); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@ -541,7 +543,7 @@ func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) {
|
||||
`
|
||||
|
||||
packs := []*fleet.Pack{}
|
||||
if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
|
||||
if err := d.reader.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "listing hosts in pack")
|
||||
}
|
||||
return packs, nil
|
||||
|
@ -524,7 +524,7 @@ func TestTeamScheduleNamesMigrateToNewFormat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// insert team pack by hand with the old naming scheme
|
||||
_, err = ds.db.Exec(
|
||||
_, err = ds.writer.Exec(
|
||||
"INSERT INTO packs(name, description, platform, disabled, pack_type) VALUES (?, ?, ?, ?, ?)",
|
||||
teamSchedulePackType(team1), "desc", "windows", false, teamSchedulePackType(team1),
|
||||
)
|
||||
|
@ -11,7 +11,7 @@ func (d *Datastore) NewPasswordResetRequest(req *fleet.PasswordResetRequest) (*f
|
||||
( user_id, token, expires_at)
|
||||
VALUES (?,?, NOW())
|
||||
`
|
||||
response, err := d.db.Exec(sqlStatement, req.UserID, req.Token)
|
||||
response, err := d.writer.Exec(sqlStatement, req.UserID, req.Token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "inserting password reset requests")
|
||||
}
|
||||
@ -30,7 +30,7 @@ func (d *Datastore) SavePasswordResetRequest(req *fleet.PasswordResetRequest) er
|
||||
token = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID)
|
||||
result, err := d.writer.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "updating password reset requests")
|
||||
}
|
||||
@ -58,7 +58,7 @@ func (d *Datastore) DeletePasswordResetRequestsForUser(userID uint) error {
|
||||
sqlStatement := `
|
||||
DELETE FROM password_reset_requests WHERE user_id = ?
|
||||
`
|
||||
_, err := d.db.Exec(sqlStatement, userID)
|
||||
_, err := d.writer.Exec(sqlStatement, userID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "deleting password reset request by user")
|
||||
}
|
||||
@ -72,7 +72,7 @@ func (d *Datastore) FindPassswordResetByID(id uint) (*fleet.PasswordResetRequest
|
||||
WHERE id = ? LIMIT 1
|
||||
`
|
||||
passwordResetRequest := &fleet.PasswordResetRequest{}
|
||||
err := d.db.Get(&passwordResetRequest, sqlStatement, id)
|
||||
err := d.reader.Get(&passwordResetRequest, sqlStatement, id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting password reset by id")
|
||||
}
|
||||
@ -87,7 +87,7 @@ func (d *Datastore) FindPassswordResetsByUserID(id uint) ([]*fleet.PasswordReset
|
||||
`
|
||||
|
||||
passwordResetRequests := []*fleet.PasswordResetRequest{}
|
||||
err := d.db.Select(&passwordResetRequests, sqlStatement, id)
|
||||
err := d.reader.Select(&passwordResetRequests, sqlStatement, id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "finding password resets by user id")
|
||||
}
|
||||
@ -102,7 +102,7 @@ func (d *Datastore) FindPassswordResetByToken(token string) (*fleet.PasswordRese
|
||||
WHERE token = ? LIMIT 1
|
||||
`
|
||||
passwordResetRequest := &fleet.PasswordResetRequest{}
|
||||
err := d.db.Get(passwordResetRequest, sqlStatement, token)
|
||||
err := d.reader.Get(passwordResetRequest, sqlStatement, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting password reset requests")
|
||||
}
|
||||
@ -118,7 +118,7 @@ func (d *Datastore) FindPassswordResetByTokenAndUserID(token string, id uint) (*
|
||||
LIMIT 1
|
||||
`
|
||||
passwordResetRequest := &fleet.PasswordResetRequest{}
|
||||
err := d.db.Get(passwordResetRequest, sqlStatement, id, token)
|
||||
err := d.reader.Get(passwordResetRequest, sqlStatement, id, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting password reset by token and user id")
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
|
||||
res, err := ds.db.Exec(`INSERT INTO policies (query_id) VALUES (?)`, queryID)
|
||||
res, err := ds.writer.Exec(`INSERT INTO policies (query_id) VALUES (?)`, queryID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "inserting new policy")
|
||||
}
|
||||
@ -26,12 +26,12 @@ func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
|
||||
|
||||
func (ds *Datastore) Policy(id uint) (*fleet.Policy, error) {
|
||||
var policy fleet.Policy
|
||||
err := ds.db.Get(
|
||||
err := ds.reader.Get(
|
||||
&policy,
|
||||
`SELECT
|
||||
p.*,
|
||||
q.name as query_name,
|
||||
(select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count,
|
||||
`SELECT
|
||||
p.*,
|
||||
q.name as query_name,
|
||||
(select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count,
|
||||
(select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count
|
||||
FROM policies p JOIN queries q ON (p.query_id=q.id) WHERE p.id=?`,
|
||||
id,
|
||||
@ -66,7 +66,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(host *fleet.Host, results map[u
|
||||
strings.Join(bindvars, ","),
|
||||
)
|
||||
|
||||
_, err := ds.db.Exec(query, vals...)
|
||||
_, err := ds.writer.Exec(query, vals...)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "insert policy_membership (%v)", vals)
|
||||
}
|
||||
@ -76,13 +76,13 @@ func (ds *Datastore) RecordPolicyQueryExecutions(host *fleet.Host, results map[u
|
||||
|
||||
func (ds *Datastore) ListGlobalPolicies() ([]*fleet.Policy, error) {
|
||||
var policies []*fleet.Policy
|
||||
err := ds.db.Select(
|
||||
err := ds.reader.Select(
|
||||
&policies,
|
||||
`SELECT
|
||||
p.*,
|
||||
q.name as query_name,
|
||||
(select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count,
|
||||
(select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count
|
||||
`SELECT
|
||||
p.*,
|
||||
q.name as query_name,
|
||||
(select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count,
|
||||
(select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count
|
||||
FROM policies p JOIN queries q ON (p.query_id=q.id)`,
|
||||
)
|
||||
if err != nil {
|
||||
@ -97,8 +97,8 @@ func (ds *Datastore) DeleteGlobalPolicies(ids []uint) ([]uint, error) {
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "IN for DELETE FROM policies")
|
||||
}
|
||||
stmt = ds.db.Rebind(stmt)
|
||||
if _, err := ds.db.Exec(stmt, args...); err != nil {
|
||||
stmt = ds.writer.Rebind(stmt)
|
||||
if _, err := ds.writer.Exec(stmt, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "delete policies")
|
||||
}
|
||||
return ids, nil
|
||||
@ -109,7 +109,7 @@ func (ds *Datastore) PolicyQueriesForHost(_ *fleet.Host) (map[string]string, err
|
||||
Id string `db:"id"`
|
||||
Query string `db:"query"`
|
||||
}
|
||||
err := ds.db.Select(&rows, `SELECT p.id, q.query FROM policies p JOIN queries q ON (p.query_id=q.id)`)
|
||||
err := ds.reader.Select(&rows, `SELECT p.id, q.query FROM policies p JOIN queries q ON (p.query_id=q.id)`)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting policies for host")
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func (d *Datastore) ApplyQueries(authorID uint, queries []*fleet.Query) (err error) {
|
||||
tx, err := d.db.Begin()
|
||||
tx, err := d.writer.Begin()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "begin ApplyQueries transaction")
|
||||
}
|
||||
@ -72,7 +72,7 @@ func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet.
|
||||
WHERE name = ?
|
||||
`
|
||||
var query fleet.Query
|
||||
err := d.db.Get(&query, sqlStatement, name)
|
||||
err := d.reader.Get(&query, sqlStatement, name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, notFound("Query").WithName(name)
|
||||
@ -99,7 +99,7 @@ func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fl
|
||||
observer_can_run
|
||||
) VALUES ( ?, ?, ?, ?, ?, ? )
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
|
||||
result, err := d.writer.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
|
||||
|
||||
if err != nil && isDuplicate(err) {
|
||||
return nil, alreadyExists("Query", 0)
|
||||
@ -120,7 +120,7 @@ func (d *Datastore) SaveQuery(q *fleet.Query) error {
|
||||
SET name = ?, description = ?, query = ?, author_id = ?, saved = ?, observer_can_run = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := d.db.Exec(sql, q.Name, q.Description, q.Query, q.AuthorID, q.Saved, q.ObserverCanRun, q.ID)
|
||||
result, err := d.writer.Exec(sql, q.Name, q.Description, q.Query, q.AuthorID, q.Saved, q.ObserverCanRun, q.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "updating query")
|
||||
}
|
||||
@ -156,7 +156,7 @@ func (d *Datastore) Query(id uint) (*fleet.Query, error) {
|
||||
WHERE q.id = ?
|
||||
`
|
||||
query := &fleet.Query{}
|
||||
if err := d.db.Get(query, sql, id); err != nil {
|
||||
if err := d.reader.Get(query, sql, id); err != nil {
|
||||
return nil, errors.Wrap(err, "selecting query")
|
||||
}
|
||||
|
||||
@ -180,7 +180,7 @@ func (d *Datastore) ListQueries(opt fleet.ListOptions) ([]*fleet.Query, error) {
|
||||
sql = appendListOptionsToSQL(sql, opt)
|
||||
results := []*fleet.Query{}
|
||||
|
||||
if err := d.db.Select(&results, sql); err != nil {
|
||||
if err := d.reader.Select(&results, sql); err != nil {
|
||||
return nil, errors.Wrap(err, "listing queries")
|
||||
}
|
||||
|
||||
@ -225,7 +225,7 @@ func (d *Datastore) loadPacksForQueries(queries []*fleet.Query) error {
|
||||
fleet.Pack
|
||||
}{}
|
||||
|
||||
err = d.db.Select(&rows, query, args...)
|
||||
err = d.reader.Select(&rows, query, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "selecting load packs for queries")
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ func (d *Datastore) ListScheduledQueriesInPack(id uint, opts fleet.ListOptions)
|
||||
query = appendListOptionsToSQL(query, opts)
|
||||
results := []*fleet.ScheduledQuery{}
|
||||
|
||||
if err := d.db.Select(&results, query, id); err != nil {
|
||||
if err := d.reader.Select(&results, query, id); err != nil {
|
||||
return nil, errors.Wrap(err, "listing scheduled queries")
|
||||
}
|
||||
|
||||
@ -46,8 +46,8 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op
|
||||
}
|
||||
|
||||
func (d *Datastore) insertScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
|
||||
selectFunc := d.db.Select
|
||||
execFunc := d.db.Exec
|
||||
selectFunc := d.writer.Select
|
||||
execFunc := d.writer.Exec
|
||||
if tx != nil {
|
||||
selectFunc = tx.Select
|
||||
execFunc = tx.Exec
|
||||
@ -108,7 +108,7 @@ func (d *Datastore) SaveScheduledQuery(sq *fleet.ScheduledQuery) (*fleet.Schedul
|
||||
}
|
||||
|
||||
func (d *Datastore) saveScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
|
||||
updateFunc := d.db.Exec
|
||||
updateFunc := d.writer.Exec
|
||||
if tx != nil {
|
||||
updateFunc = tx.Exec
|
||||
}
|
||||
@ -160,7 +160,7 @@ func (d *Datastore) ScheduledQuery(id uint) (*fleet.ScheduledQuery, error) {
|
||||
WHERE sq.id = ?
|
||||
`
|
||||
sq := &fleet.ScheduledQuery{}
|
||||
if err := d.db.Get(sq, query, id); err != nil {
|
||||
if err := d.reader.Get(sq, query, id); err != nil {
|
||||
return nil, errors.Wrap(err, "select scheduled query")
|
||||
}
|
||||
|
||||
@ -168,11 +168,11 @@ func (d *Datastore) ScheduledQuery(id uint) (*fleet.ScheduledQuery, error) {
|
||||
}
|
||||
|
||||
func (d *Datastore) CleanupOrphanScheduledQueryStats() error {
|
||||
_, err := d.db.Exec(`DELETE FROM scheduled_query_stats where scheduled_query_id not in (select id from scheduled_queries where id=scheduled_query_id)`)
|
||||
_, err := d.writer.Exec(`DELETE FROM scheduled_query_stats where scheduled_query_id not in (select id from scheduled_queries where id=scheduled_query_id)`)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by scheduled_query")
|
||||
}
|
||||
_, err = d.db.Exec(`DELETE FROM scheduled_query_stats where host_id not in (select id from hosts where id=host_id)`)
|
||||
_, err = d.writer.Exec(`DELETE FROM scheduled_query_stats where host_id not in (select id from hosts where id=host_id)`)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by host")
|
||||
}
|
||||
|
@ -209,9 +209,9 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
|
||||
test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "1")
|
||||
sq1 := test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "2")
|
||||
|
||||
_, err := ds.db.Exec(`INSERT INTO scheduled_query_stats (
|
||||
host_id, scheduled_query_id, average_memory, denylisted,
|
||||
executions, schedule_interval, output_size, system_time,
|
||||
_, err := ds.writer.Exec(`INSERT INTO scheduled_query_stats (
|
||||
host_id, scheduled_query_id, average_memory, denylisted,
|
||||
executions, schedule_interval, output_size, system_time,
|
||||
user_time, wall_time
|
||||
) VALUES (?, ?, 32, false, 4, 4, 4, 4, 4, 4);`, h1.ID, sq1.ID)
|
||||
require.NoError(t, err)
|
||||
@ -224,12 +224,12 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
|
||||
require.Len(t, h1.PackStats, 1)
|
||||
|
||||
// now we insert a bogus stat
|
||||
_, err = ds.db.Exec(`INSERT INTO scheduled_query_stats (
|
||||
_, err = ds.writer.Exec(`INSERT INTO scheduled_query_stats (
|
||||
host_id, scheduled_query_id, average_memory, denylisted, executions
|
||||
) VALUES (?, 999, 32, false, 2);`, h1.ID)
|
||||
require.NoError(t, err)
|
||||
// and also for an unknown host
|
||||
_, err = ds.db.Exec(`INSERT INTO scheduled_query_stats (
|
||||
_, err = ds.writer.Exec(`INSERT INTO scheduled_query_stats (
|
||||
host_id, scheduled_query_id, average_memory, denylisted, executions
|
||||
) VALUES (888, 999, 32, true, 4);`)
|
||||
require.NoError(t, err)
|
||||
@ -241,14 +241,14 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
|
||||
|
||||
// but there are definitely there
|
||||
var count int
|
||||
err = ds.db.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
||||
err = ds.writer.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, count)
|
||||
|
||||
// now we clean it up
|
||||
require.NoError(t, ds.CleanupOrphanScheduledQueryStats())
|
||||
|
||||
err = ds.db.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
||||
err = ds.writer.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
|
||||
|
@ -11,7 +11,7 @@ func (d *Datastore) SessionByKey(key string) (*fleet.Session, error) {
|
||||
WHERE ` + "`key`" + ` = ? LIMIT 1
|
||||
`
|
||||
session := &fleet.Session{}
|
||||
err := d.db.Get(session, sqlStatement, key)
|
||||
err := d.reader.Get(session, sqlStatement, key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting sessions")
|
||||
}
|
||||
@ -26,7 +26,7 @@ func (d *Datastore) SessionByID(id uint) (*fleet.Session, error) {
|
||||
LIMIT 1
|
||||
`
|
||||
session := &fleet.Session{}
|
||||
err := d.db.Get(session, sqlStatement, id)
|
||||
err := d.reader.Get(session, sqlStatement, id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting session by id")
|
||||
}
|
||||
@ -40,7 +40,7 @@ func (d *Datastore) ListSessionsForUser(id uint) ([]*fleet.Session, error) {
|
||||
WHERE user_id = ?
|
||||
`
|
||||
sessions := []*fleet.Session{}
|
||||
err := d.db.Select(&sessions, sqlStatement, id)
|
||||
err := d.reader.Select(&sessions, sqlStatement, id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "selecting sessions for user")
|
||||
}
|
||||
@ -57,7 +57,7 @@ func (d *Datastore) NewSession(session *fleet.Session) (*fleet.Session, error) {
|
||||
)
|
||||
VALUES(?,?)
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement, session.UserID, session.Key)
|
||||
result, err := d.writer.Exec(sqlStatement, session.UserID, session.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "inserting session")
|
||||
}
|
||||
@ -80,7 +80,7 @@ func (d *Datastore) DestroyAllSessionsForUser(id uint) error {
|
||||
sqlStatement := `
|
||||
DELETE FROM sessions WHERE user_id = ?
|
||||
`
|
||||
_, err := d.db.Exec(sqlStatement, id)
|
||||
_, err := d.writer.Exec(sqlStatement, id)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "deleting sessions for user")
|
||||
}
|
||||
@ -94,7 +94,7 @@ func (d *Datastore) MarkSessionAccessed(session *fleet.Session) error {
|
||||
accessed_at = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
results, err := d.db.Exec(sqlStatement, d.clock.Now(), session.ID)
|
||||
results, err := d.writer.Exec(sqlStatement, d.clock.Now(), session.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "updating mark session as accessed")
|
||||
}
|
||||
|
@ -206,7 +206,7 @@ func (d *Datastore) insertNewInstalledHostSoftware(
|
||||
}
|
||||
|
||||
func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Software, error) {
|
||||
selectFunc := d.db.Select
|
||||
selectFunc := d.reader.Select
|
||||
if tx != nil {
|
||||
selectFunc = tx.Select
|
||||
}
|
||||
@ -231,7 +231,7 @@ func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Softwa
|
||||
WHERE s.id IN
|
||||
(SELECT software_id FROM host_software WHERE host_id = ?)
|
||||
`
|
||||
queryFunc := d.db.Queryx
|
||||
queryFunc := d.reader.Queryx
|
||||
if tx != nil {
|
||||
queryFunc = tx.Queryx
|
||||
}
|
||||
@ -305,7 +305,7 @@ func (d *Datastore) AllSoftwareWithoutCPEIterator() (fleet.SoftwareIterator, err
|
||||
sql := `SELECT s.* FROM software s LEFT JOIN software_cpe sc on (s.id=sc.software_id) WHERE sc.id is null`
|
||||
// The rows.Close call is done by the caller once iteration using the
|
||||
// returned fleet.SoftwareIterator is done.
|
||||
rows, err := d.db.Queryx(sql) //nolint:sqlclosecheck
|
||||
rows, err := d.reader.Queryx(sql) //nolint:sqlclosecheck
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "load host software")
|
||||
}
|
||||
@ -314,7 +314,7 @@ func (d *Datastore) AllSoftwareWithoutCPEIterator() (fleet.SoftwareIterator, err
|
||||
|
||||
func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error {
|
||||
sql := `INSERT INTO software_cpe (software_id, cpe) VALUES (?, ?)`
|
||||
if _, err := d.db.Exec(sql, software.ID, cpe); err != nil {
|
||||
if _, err := d.writer.Exec(sql, software.ID, cpe); err != nil {
|
||||
return errors.Wrap(err, "insert software cpe")
|
||||
}
|
||||
return nil
|
||||
@ -323,7 +323,7 @@ func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error
|
||||
func (d *Datastore) AllCPEs() ([]string, error) {
|
||||
sql := `SELECT cpe FROM software_cpe`
|
||||
var cpes []string
|
||||
err := d.db.Select(&cpes, sql)
|
||||
err := d.reader.Select(&cpes, sql)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "loads cpes")
|
||||
}
|
||||
@ -337,7 +337,7 @@ func (d *Datastore) InsertCVEForCPE(cve string, cpes []string) error {
|
||||
for _, cpe := range cpes {
|
||||
args = append(args, cpe, cve)
|
||||
}
|
||||
_, err := d.db.Exec(sql, args...)
|
||||
_, err := d.writer.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "insert software cve")
|
||||
}
|
||||
|
@ -207,7 +207,7 @@ func TestHostSoftwareDuplicates(t *testing.T) {
|
||||
})
|
||||
incoming[soft2Key] = true
|
||||
|
||||
tx, err := ds.db.Beginx()
|
||||
tx, err := ds.writer.Beginx()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
|
||||
require.NoError(t, tx.Commit())
|
||||
@ -220,7 +220,7 @@ func TestHostSoftwareDuplicates(t *testing.T) {
|
||||
})
|
||||
incoming[soft3Key] = true
|
||||
|
||||
tx, err = ds.db.Beginx()
|
||||
tx, err = ds.writer.Beginx()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
|
||||
require.NoError(t, tx.Commit())
|
||||
|
@ -21,14 +21,14 @@ func (d *Datastore) ShouldSendStatistics(frequency time.Duration) (fleet.Statist
|
||||
}
|
||||
|
||||
dest := statistics{}
|
||||
err = d.db.Get(&dest, `SELECT created_at, updated_at, anonymous_identifier FROM statistics LIMIT 1`)
|
||||
err = d.writer.Get(&dest, `SELECT created_at, updated_at, anonymous_identifier FROM statistics LIMIT 1`)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
anonIdentifier, err := server.GenerateRandomText(64)
|
||||
if err != nil {
|
||||
return fleet.StatisticsPayload{}, false, err
|
||||
}
|
||||
_, err = d.db.Exec(`INSERT INTO statistics(anonymous_identifier) VALUES (?)`, anonIdentifier)
|
||||
_, err = d.writer.Exec(`INSERT INTO statistics(anonymous_identifier) VALUES (?)`, anonIdentifier)
|
||||
if err != nil {
|
||||
return fleet.StatisticsPayload{}, false, err
|
||||
}
|
||||
@ -55,6 +55,6 @@ func (d *Datastore) ShouldSendStatistics(frequency time.Duration) (fleet.Statist
|
||||
}
|
||||
|
||||
func (d *Datastore) RecordStatisticsSent() error {
|
||||
_, err := d.db.Exec(`UPDATE statistics SET updated_at = CURRENT_TIMESTAMP LIMIT 1`)
|
||||
_, err := d.writer.Exec(`UPDATE statistics SET updated_at = CURRENT_TIMESTAMP LIMIT 1`)
|
||||
return err
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ func (d *Datastore) CountHostsInTargets(filter fleet.TeamFilter, targets fleet.H
|
||||
}
|
||||
|
||||
res := fleet.TargetMetrics{}
|
||||
err = d.db.Get(&res, query, args...)
|
||||
err = d.reader.Get(&res, query, args...)
|
||||
if err != nil {
|
||||
return fleet.TargetMetrics{}, errors.Wrap(err, "sqlx.Get CountHostsInTargets")
|
||||
}
|
||||
@ -98,7 +98,7 @@ func (d *Datastore) HostIDsInTargets(filter fleet.TeamFilter, targets fleet.Host
|
||||
}
|
||||
|
||||
var res []uint
|
||||
err = d.db.Select(&res, query, args...)
|
||||
err = d.reader.Select(&res, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sqlx.Get HostIDsInTargets")
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ func (d *Datastore) NewTeam(team *fleet.Team) (*fleet.Team, error) {
|
||||
description
|
||||
) VALUES ( ?, ?, ? )
|
||||
`
|
||||
result, err := d.db.Exec(
|
||||
result, err := d.writer.Exec(
|
||||
query,
|
||||
team.Name,
|
||||
team.AgentOptions,
|
||||
@ -47,7 +47,7 @@ func (d *Datastore) Team(tid uint) (*fleet.Team, error) {
|
||||
`
|
||||
team := &fleet.Team{}
|
||||
|
||||
if err := d.db.Get(team, sql, tid); err != nil {
|
||||
if err := d.reader.Get(team, sql, tid); err != nil {
|
||||
return nil, errors.Wrap(err, "select team")
|
||||
}
|
||||
|
||||
@ -84,7 +84,7 @@ func (d *Datastore) TeamByName(name string) (*fleet.Team, error) {
|
||||
`
|
||||
team := &fleet.Team{}
|
||||
|
||||
if err := d.db.Get(team, sql, name); err != nil {
|
||||
if err := d.reader.Get(team, sql, name); err != nil {
|
||||
return nil, errors.Wrap(err, "select team")
|
||||
}
|
||||
|
||||
@ -106,7 +106,7 @@ func (d *Datastore) loadUsersForTeam(team *fleet.Team) error {
|
||||
WHERE ut.team_id = ?
|
||||
`
|
||||
rows := []fleet.TeamUser{}
|
||||
if err := d.db.Select(&rows, sql, team.ID); err != nil {
|
||||
if err := d.reader.Select(&rows, sql, team.ID); err != nil {
|
||||
return errors.Wrap(err, "load users for team")
|
||||
}
|
||||
|
||||
@ -157,7 +157,7 @@ func (d *Datastore) SaveTeam(team *fleet.Team) (*fleet.Team, error) {
|
||||
description = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
_, err := d.db.Exec(query, team.Name, team.AgentOptions, team.Description, team.ID)
|
||||
_, err := d.writer.Exec(query, team.Name, team.AgentOptions, team.Description, team.ID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "saving team")
|
||||
}
|
||||
@ -174,7 +174,7 @@ func (d *Datastore) SaveTeam(team *fleet.Team) (*fleet.Team, error) {
|
||||
}
|
||||
|
||||
func (d *Datastore) updateTeamSchedule(team *fleet.Team) error {
|
||||
_, err := d.db.Exec(
|
||||
_, err := d.writer.Exec(
|
||||
`UPDATE packs SET name = ? WHERE pack_type = ?`, teamScheduleName(team), teamSchedulePackType(team),
|
||||
)
|
||||
return err
|
||||
@ -195,7 +195,7 @@ func (d *Datastore) ListTeams(filter fleet.TeamFilter, opt fleet.ListOptions) ([
|
||||
query, params := searchLike(query, nil, opt.MatchQuery, teamSearchColumns...)
|
||||
query = appendListOptionsToSQL(query, opt)
|
||||
teams := []*fleet.Team{}
|
||||
if err := d.db.Select(&teams, query, params...); err != nil {
|
||||
if err := d.reader.Select(&teams, query, params...); err != nil {
|
||||
return nil, errors.Wrap(err, "list teams")
|
||||
}
|
||||
if err := d.loadSecretsForTeams(teams); err != nil {
|
||||
@ -229,7 +229,7 @@ func (d *Datastore) SearchTeams(filter fleet.TeamFilter, matchQuery string, omit
|
||||
sql, params := searchLike(sql, nil, matchQuery, teamSearchColumns...)
|
||||
sql += "\nLIMIT 5"
|
||||
teams := []*fleet.Team{}
|
||||
if err := d.db.Select(&teams, sql, params...); err != nil {
|
||||
if err := d.reader.Select(&teams, sql, params...); err != nil {
|
||||
return nil, errors.Wrap(err, "search teams")
|
||||
}
|
||||
if err := d.loadSecretsForTeams(teams); err != nil {
|
||||
@ -244,7 +244,7 @@ func (d *Datastore) TeamEnrollSecrets(teamID uint) ([]*fleet.EnrollSecret, error
|
||||
WHERE team_id = ?
|
||||
`
|
||||
var secrets []*fleet.EnrollSecret
|
||||
if err := d.db.Select(&secrets, sql, teamID); err != nil {
|
||||
if err := d.reader.Select(&secrets, sql, teamID); err != nil {
|
||||
return nil, errors.Wrap(err, "get secrets")
|
||||
}
|
||||
return secrets, nil
|
||||
|
@ -1,6 +1,7 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@ -9,20 +10,23 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/WatchBeam/clock"
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
testUsername = "root"
|
||||
testPassword = "toor"
|
||||
testAddress = "localhost:3307"
|
||||
testUsername = "root"
|
||||
testPassword = "toor"
|
||||
testAddress = "localhost:3307"
|
||||
testReplicaDatabaseSuffix = "_replica"
|
||||
)
|
||||
|
||||
func connectMySQL(t *testing.T, testName string) *Datastore {
|
||||
func connectMySQL(t *testing.T, testName string, opts *DatastoreTestOptions) *Datastore {
|
||||
config := config.MysqlConfig{
|
||||
Username: testUsername,
|
||||
Password: testPassword,
|
||||
@ -31,15 +35,134 @@ func connectMySQL(t *testing.T, testName string) *Datastore {
|
||||
}
|
||||
|
||||
// Create datastore client
|
||||
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1))
|
||||
var replicaOpt DBOption
|
||||
if opts.Replica {
|
||||
replicaConf := config
|
||||
replicaConf.Database += testReplicaDatabaseSuffix
|
||||
replicaOpt = Replica(&replicaConf)
|
||||
}
|
||||
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1), replicaOpt)
|
||||
require.Nil(t, err)
|
||||
|
||||
if opts.Replica {
|
||||
setupReadReplica(t, testName, ds, opts)
|
||||
}
|
||||
|
||||
return ds
|
||||
}
|
||||
|
||||
func setupReadReplica(t *testing.T, testName string, ds *Datastore, opts *DatastoreTestOptions) {
|
||||
// create the context that will cancel the replication goroutine on test exit
|
||||
var cancel func()
|
||||
ctx := context.Background()
|
||||
if dl, ok := t.Deadline(); ok {
|
||||
ctx, cancel = context.WithDeadline(ctx, dl)
|
||||
} else {
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
}
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// start the replication goroutine that runs when signalled through a
|
||||
// channel, the replication runs in lock-step - the test is in control of
|
||||
// when the replication happens, by calling opts.RunReplication(), and when
|
||||
// that call returns, the replication is guaranteed to be done. This supports
|
||||
// simulating all kinds of replica lag.
|
||||
ch := make(chan chan struct{})
|
||||
go func() {
|
||||
// if it exits because of a panic/failed replication, cancel the context
|
||||
// immediately so that RunReplication is unblocked too.
|
||||
defer cancel()
|
||||
|
||||
primary := ds.writer
|
||||
replica := ds.reader.(*sqlx.DB)
|
||||
replicaDB := testName + testReplicaDatabaseSuffix
|
||||
last := time.Now().Add(-time.Minute)
|
||||
|
||||
// drop all foreign keys in the replica, as that causes issues even with
|
||||
// FOREIGN_KEY_CHECKS=0
|
||||
var fks []struct {
|
||||
TableName string `db:"table_name"`
|
||||
ConstraintName string `db:"constraint_name"`
|
||||
}
|
||||
err := primary.SelectContext(ctx, &fks, `
|
||||
SELECT
|
||||
table_name, constraint_name
|
||||
FROM
|
||||
information_schema.key_column_usage
|
||||
WHERE
|
||||
table_schema = ? AND
|
||||
referenced_table_name IS NOT NULL`, testName)
|
||||
require.NoError(t, err)
|
||||
for _, fk := range fks {
|
||||
stmt := fmt.Sprintf(`ALTER TABLE %s.%s DROP FOREIGN KEY %s`, replicaDB, fk.TableName, fk.ConstraintName)
|
||||
_, err := replica.ExecContext(ctx, stmt)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case out := <-ch:
|
||||
// identify tables with changes since the last call
|
||||
var tables []string
|
||||
err := primary.SelectContext(ctx, &tables, `
|
||||
SELECT
|
||||
table_name
|
||||
FROM
|
||||
information_schema.tables
|
||||
WHERE
|
||||
table_schema = ? AND
|
||||
table_type = 'BASE TABLE' AND
|
||||
update_time >= ?`, testName, last)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = primary.GetContext(ctx, &last, `
|
||||
SELECT
|
||||
MAX(update_time)
|
||||
FROM
|
||||
information_schema.tables
|
||||
WHERE
|
||||
table_schema = ? AND
|
||||
table_type = 'BASE TABLE'`, testName)
|
||||
require.NoError(t, err)
|
||||
|
||||
// replicate by dropping the existing table and re-creating it from
|
||||
// the primary.
|
||||
for _, tbl := range tables {
|
||||
stmt := fmt.Sprintf(`DROP TABLE IF EXISTS %s.%s`, replicaDB, tbl)
|
||||
t.Log(stmt)
|
||||
_, err = replica.ExecContext(ctx, stmt)
|
||||
require.NoError(t, err)
|
||||
stmt = fmt.Sprintf(`CREATE TABLE %s.%s SELECT * FROM %s.%s`, replicaDB, tbl, testName, tbl)
|
||||
t.Log(stmt)
|
||||
_, err = replica.ExecContext(ctx, stmt)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
out <- struct{}{}
|
||||
t.Logf("replication step executed, next will consider updates since %s", last)
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// set RunReplication to a function that triggers the replication and waits
|
||||
// for it to complete.
|
||||
opts.RunReplication = func() {
|
||||
done := make(chan struct{})
|
||||
ch <- done
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initializeDatabase loads the dumped schema into a newly created database in
|
||||
// MySQL. This is much faster than running the full set of migrations on each
|
||||
// test.
|
||||
func initializeDatabase(t *testing.T, testName string) *Datastore {
|
||||
func initializeDatabase(t *testing.T, testName string, opts *DatastoreTestOptions) *Datastore {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
base := path.Dir(filename)
|
||||
schema, err := ioutil.ReadFile(path.Join(base, "schema.sql"))
|
||||
@ -47,33 +170,60 @@ func initializeDatabase(t *testing.T, testName string) *Datastore {
|
||||
t.Error(err)
|
||||
t.FailNow()
|
||||
}
|
||||
// Load schema from dumpfile
|
||||
if out, err := exec.Command(
|
||||
"docker-compose", "exec", "-T", "mysql_test",
|
||||
// Command run inside container
|
||||
"mysql",
|
||||
"-u"+testUsername, "-p"+testPassword,
|
||||
"-e",
|
||||
fmt.Sprintf(
|
||||
"DROP DATABASE IF EXISTS %s; CREATE DATABASE %s; USE %s; SET FOREIGN_KEY_CHECKS=0; %s;",
|
||||
testName, testName, testName, schema,
|
||||
),
|
||||
).CombinedOutput(); err != nil {
|
||||
t.Error(err)
|
||||
t.Error(string(out))
|
||||
t.FailNow()
|
||||
|
||||
// execute the schema for the test db, and once more for the replica db if
|
||||
// that option is set.
|
||||
dbs := []string{testName}
|
||||
if opts.Replica {
|
||||
dbs = append(dbs, testName+testReplicaDatabaseSuffix)
|
||||
}
|
||||
return connectMySQL(t, testName)
|
||||
for _, dbName := range dbs {
|
||||
// Load schema from dumpfile
|
||||
if out, err := exec.Command(
|
||||
"docker-compose", "exec", "-T", "mysql_test",
|
||||
// Command run inside container
|
||||
"mysql",
|
||||
"-u"+testUsername, "-p"+testPassword,
|
||||
"-e",
|
||||
fmt.Sprintf(
|
||||
"DROP DATABASE IF EXISTS %s; CREATE DATABASE %s; USE %s; SET FOREIGN_KEY_CHECKS=0; %s;",
|
||||
dbName, dbName, dbName, schema,
|
||||
),
|
||||
).CombinedOutput(); err != nil {
|
||||
t.Error(err)
|
||||
t.Error(string(out))
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
return connectMySQL(t, testName, opts)
|
||||
}
|
||||
|
||||
func CreateMySQLDS(t *testing.T) *Datastore {
|
||||
// DatastoreTestOptions configures how the test datastore is created
|
||||
// by CreateMySQLDSWithOptions.
|
||||
type DatastoreTestOptions struct {
|
||||
// Replica indicates that a read replica test database should be created.
|
||||
Replica bool
|
||||
|
||||
// RunReplication is the function to call to execute the replication of all
|
||||
// missing changes from the primary to the replica. The function is created
|
||||
// and set automatically by CreateMySQLDSWithOptions. The test is in full
|
||||
// control of when the replication is executed.
|
||||
RunReplication func()
|
||||
}
|
||||
|
||||
func createMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datastore {
|
||||
if _, ok := os.LookupEnv("MYSQL_TEST"); !ok {
|
||||
t.Skip("MySQL tests are disabled")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
pc, _, _, ok := runtime.Caller(1)
|
||||
if opts == nil {
|
||||
// so it is never nil in internal helper functions
|
||||
opts = new(DatastoreTestOptions)
|
||||
}
|
||||
|
||||
pc, _, _, ok := runtime.Caller(2)
|
||||
details := runtime.FuncForPC(pc)
|
||||
if !ok || details == nil {
|
||||
t.FailNow()
|
||||
@ -83,5 +233,13 @@ func CreateMySQLDS(t *testing.T) *Datastore {
|
||||
strings.TrimPrefix(details.Name(), "github.com/fleetdm/fleet/v4/"), "/", "_",
|
||||
)
|
||||
cleanName = strings.ReplaceAll(cleanName, ".", "_")
|
||||
return initializeDatabase(t, cleanName)
|
||||
return initializeDatabase(t, cleanName, opts)
|
||||
}
|
||||
|
||||
func CreateMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datastore {
|
||||
return createMySQLDSWithOptions(t, opts)
|
||||
}
|
||||
|
||||
func CreateMySQLDS(t *testing.T) *Datastore {
|
||||
return createMySQLDSWithOptions(t, nil)
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ func (d *Datastore) findUser(searchCol string, searchVal interface{}) (*fleet.Us
|
||||
|
||||
user := &fleet.User{}
|
||||
|
||||
err := d.db.Get(user, sqlStatement, searchVal)
|
||||
err := d.reader.Get(user, sqlStatement, searchVal)
|
||||
if err != nil && err == sql.ErrNoRows {
|
||||
return nil, notFound("User").
|
||||
WithMessage(fmt.Sprintf("with %s=%v", searchCol, searchVal))
|
||||
@ -104,7 +104,7 @@ func (d *Datastore) ListUsers(opt fleet.UserListOptions) ([]*fleet.User, error)
|
||||
sqlStatement = appendListOptionsToSQL(sqlStatement, opt.ListOptions)
|
||||
users := []*fleet.User{}
|
||||
|
||||
if err := d.db.Select(&users, sqlStatement, params...); err != nil {
|
||||
if err := d.reader.Select(&users, sqlStatement, params...); err != nil {
|
||||
return nil, errors.Wrap(err, "list users")
|
||||
}
|
||||
|
||||
@ -160,7 +160,7 @@ func (d *Datastore) saveUser(tx *sqlx.Tx, user *fleet.User) error {
|
||||
global_role = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := d.db.Exec(sqlStatement,
|
||||
result, err := tx.Exec(sqlStatement,
|
||||
user.Password,
|
||||
user.Salt,
|
||||
user.Name,
|
||||
@ -221,7 +221,7 @@ func (d *Datastore) loadTeamsForUsers(users []*fleet.User) error {
|
||||
fleet.UserTeam
|
||||
UserID uint `db:"user_id"`
|
||||
}
|
||||
if err := d.db.Select(&rows, sql, args...); err != nil {
|
||||
if err := d.reader.Select(&rows, sql, args...); err != nil {
|
||||
return errors.Wrap(err, "get loadTeamsForUsers")
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user