mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Add support for MySQL read replicas (#1786)
Adds configuration options to use a read-only MySQL replica, and uses it instead of the primary for reads.
This commit is contained in:
parent
2116217b6e
commit
3755a58070
1
changes/issue-1695-support-mysql-read-replica
Normal file
1
changes/issue-1695-support-mysql-read-replica
Normal file
@ -0,0 +1 @@
|
|||||||
|
* Add support for MySQL read replica.
|
@ -149,7 +149,11 @@ the way that the Fleet server works.
|
|||||||
var carveStore fleet.CarveStore
|
var carveStore fleet.CarveStore
|
||||||
mailService := mail.NewService()
|
mailService := mail.NewService()
|
||||||
|
|
||||||
ds, err = mysql.New(config.Mysql, clock.C, mysql.Logger(logger))
|
var replicaOpt mysql.DBOption
|
||||||
|
if config.MysqlReadReplica.Address != "" {
|
||||||
|
replicaOpt = mysql.Replica(&config.MysqlReadReplica)
|
||||||
|
}
|
||||||
|
ds, err = mysql.New(config.Mysql, clock.C, mysql.Logger(logger), replicaOpt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
initFatal(err, "initializing datastore")
|
initFatal(err, "initializing datastore")
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,7 @@ Your Fleet server's two main purposes are:
|
|||||||
- To serve as your [osquery TLS server](https://osquery.readthedocs.io/en/stable/deployment/remote/)
|
- To serve as your [osquery TLS server](https://osquery.readthedocs.io/en/stable/deployment/remote/)
|
||||||
- To serve the Fleet web UI, which allows you to manage osquery configuration, query hosts, etc.
|
- To serve the Fleet web UI, which allows you to manage osquery configuration, query hosts, etc.
|
||||||
|
|
||||||
The Fleet server allows you persist configuration, manage users, etc. Thus, it needs a database. Fleet uses MySQL and requires you to supply configurations to connect to a MySQL server. Fleet also uses Redis to perform some more high-speed data access action throughout the lifecycle of the application (for example, distributed query result ingestion). Thus, Fleet also requires that you supply Redis connection configurations.
|
The Fleet server allows you persist configuration, manage users, etc. Thus, it needs a database. Fleet uses MySQL and requires you to supply configurations to connect to a MySQL server. It is also possible to configure connection to a MySQL replica in addition to the primary, to be used for reading only. Fleet also uses Redis to perform some more high-speed data access action throughout the lifecycle of the application (for example, distributed query result ingestion). Thus, Fleet also requires that you supply Redis connection configurations.
|
||||||
|
|
||||||
> Fleet does not support Redis Cluster or Redis Sentinel. Fleet can scale to hundreds of thousands of devices with a single Redis instance.
|
> Fleet does not support Redis Cluster or Redis Sentinel. Fleet can scale to hundreds of thousands of devices with a single Redis instance.
|
||||||
|
|
||||||
@ -118,12 +118,21 @@ mysql:
|
|||||||
address: 127.0.0.1:3306
|
address: 127.0.0.1:3306
|
||||||
```
|
```
|
||||||
|
|
||||||
|
And `mysql_read_replica_address` would be:
|
||||||
|
|
||||||
|
```
|
||||||
|
mysql_read_replica:
|
||||||
|
address: 127.0.0.1:3307
|
||||||
|
```
|
||||||
|
|
||||||
Basically, just capitalize the option and prepend `FLEET_` to it in order to get the environment variable. The conversion works the same the opposite way.
|
Basically, just capitalize the option and prepend `FLEET_` to it in order to get the environment variable. The conversion works the same the opposite way.
|
||||||
|
|
||||||
All duration-based settings accept valid time units of `s`, `m`, `h`.
|
All duration-based settings accept valid time units of `s`, `m`, `h`.
|
||||||
|
|
||||||
##### MySQL
|
##### MySQL
|
||||||
|
|
||||||
|
This section describes the configuration options for the primary - if you also want to setup a read replica, the options are the same, except that the yaml section is `mysql_read_replica`, and the flags have the `mysql_read_replica_` prefix instead of `mysql_` (the corresponding environment variables follow the same transformation). Note that there is no default value for `mysql_read_replica_address`, it must be set explicitly for fleet to use a read replica.
|
||||||
|
|
||||||
###### `mysql_address`
|
###### `mysql_address`
|
||||||
|
|
||||||
The address of the MySQL server which Fleet should connect to. Include the hostname and port.
|
The address of the MySQL server which Fleet should connect to. Include the hostname and port.
|
||||||
|
@ -177,6 +177,7 @@ type VulnerabilitiesConfig struct {
|
|||||||
// updated to set and retrieve the configurations as appropriate.
|
// updated to set and retrieve the configurations as appropriate.
|
||||||
type FleetConfig struct {
|
type FleetConfig struct {
|
||||||
Mysql MysqlConfig
|
Mysql MysqlConfig
|
||||||
|
MysqlReadReplica MysqlConfig `yaml:"mysql_read_replica"`
|
||||||
Redis RedisConfig
|
Redis RedisConfig
|
||||||
Server ServerConfig
|
Server ServerConfig
|
||||||
Auth AuthConfig
|
Auth AuthConfig
|
||||||
@ -197,32 +198,36 @@ type FleetConfig struct {
|
|||||||
// addConfigs adds the configuration keys and default values that will be
|
// addConfigs adds the configuration keys and default values that will be
|
||||||
// filled into the FleetConfig struct
|
// filled into the FleetConfig struct
|
||||||
func (man Manager) addConfigs() {
|
func (man Manager) addConfigs() {
|
||||||
|
addMysqlConfig := func(prefix, defaultAddr, usageSuffix string) {
|
||||||
|
man.addConfigString(prefix+".protocol", "tcp",
|
||||||
|
"MySQL server communication protocol (tcp,unix,...)"+usageSuffix)
|
||||||
|
man.addConfigString(prefix+".address", defaultAddr,
|
||||||
|
"MySQL server address (host:port)"+usageSuffix)
|
||||||
|
man.addConfigString(prefix+".username", "fleet",
|
||||||
|
"MySQL server username"+usageSuffix)
|
||||||
|
man.addConfigString(prefix+".password", "",
|
||||||
|
"MySQL server password (prefer env variable for security)"+usageSuffix)
|
||||||
|
man.addConfigString(prefix+".password_path", "",
|
||||||
|
"Path to file containg MySQL server password"+usageSuffix)
|
||||||
|
man.addConfigString(prefix+".database", "fleet",
|
||||||
|
"MySQL database name"+usageSuffix)
|
||||||
|
man.addConfigString(prefix+".tls_cert", "",
|
||||||
|
"MySQL TLS client certificate path"+usageSuffix)
|
||||||
|
man.addConfigString(prefix+".tls_key", "",
|
||||||
|
"MySQL TLS client key path"+usageSuffix)
|
||||||
|
man.addConfigString(prefix+".tls_ca", "",
|
||||||
|
"MySQL TLS server CA"+usageSuffix)
|
||||||
|
man.addConfigString(prefix+".tls_server_name", "",
|
||||||
|
"MySQL TLS server name"+usageSuffix)
|
||||||
|
man.addConfigString(prefix+".tls_config", "",
|
||||||
|
"MySQL TLS config value"+usageSuffix+" Use skip-verify, true, false or custom key.")
|
||||||
|
man.addConfigInt(prefix+".max_open_conns", 50, "MySQL maximum open connection handles"+usageSuffix)
|
||||||
|
man.addConfigInt(prefix+".max_idle_conns", 50, "MySQL maximum idle connection handles"+usageSuffix)
|
||||||
|
man.addConfigInt(prefix+".conn_max_lifetime", 0, "MySQL maximum amount of time a connection may be reused"+usageSuffix)
|
||||||
|
}
|
||||||
// MySQL
|
// MySQL
|
||||||
man.addConfigString("mysql.protocol", "tcp",
|
addMysqlConfig("mysql", "localhost:3306", ".")
|
||||||
"MySQL server communication protocol (tcp,unix,...)")
|
addMysqlConfig("mysql_read_replica", "", " for the read replica.")
|
||||||
man.addConfigString("mysql.address", "localhost:3306",
|
|
||||||
"MySQL server address (host:port)")
|
|
||||||
man.addConfigString("mysql.username", "fleet",
|
|
||||||
"MySQL server username")
|
|
||||||
man.addConfigString("mysql.password", "",
|
|
||||||
"MySQL server password (prefer env variable for security)")
|
|
||||||
man.addConfigString("mysql.password_path", "",
|
|
||||||
"Path to file containg MySQL server password")
|
|
||||||
man.addConfigString("mysql.database", "fleet",
|
|
||||||
"MySQL database name")
|
|
||||||
man.addConfigString("mysql.tls_cert", "",
|
|
||||||
"MySQL TLS client certificate path")
|
|
||||||
man.addConfigString("mysql.tls_key", "",
|
|
||||||
"MySQL TLS client key path")
|
|
||||||
man.addConfigString("mysql.tls_ca", "",
|
|
||||||
"MySQL TLS server CA")
|
|
||||||
man.addConfigString("mysql.tls_server_name", "",
|
|
||||||
"MySQL TLS server name")
|
|
||||||
man.addConfigString("mysql.tls_config", "",
|
|
||||||
"MySQL TLS config value. Use skip-verify, true, false or custom key.")
|
|
||||||
man.addConfigInt("mysql.max_open_conns", 50, "MySQL maximum open connection handles.")
|
|
||||||
man.addConfigInt("mysql.max_idle_conns", 50, "MySQL maximum idle connection handles.")
|
|
||||||
man.addConfigInt("mysql.conn_max_lifetime", 0, "MySQL maximum amount of time a connection may be reused.")
|
|
||||||
|
|
||||||
// Redis
|
// Redis
|
||||||
man.addConfigString("redis.address", "localhost:6379",
|
man.addConfigString("redis.address", "localhost:6379",
|
||||||
@ -382,23 +387,28 @@ func (man Manager) addConfigs() {
|
|||||||
func (man Manager) LoadConfig() FleetConfig {
|
func (man Manager) LoadConfig() FleetConfig {
|
||||||
man.loadConfigFile()
|
man.loadConfigFile()
|
||||||
|
|
||||||
|
loadMysqlConfig := func(prefix string) MysqlConfig {
|
||||||
|
return MysqlConfig{
|
||||||
|
Protocol: man.getConfigString(prefix + ".protocol"),
|
||||||
|
Address: man.getConfigString(prefix + ".address"),
|
||||||
|
Username: man.getConfigString(prefix + ".username"),
|
||||||
|
Password: man.getConfigString(prefix + ".password"),
|
||||||
|
PasswordPath: man.getConfigString(prefix + ".password_path"),
|
||||||
|
Database: man.getConfigString(prefix + ".database"),
|
||||||
|
TLSCert: man.getConfigString(prefix + ".tls_cert"),
|
||||||
|
TLSKey: man.getConfigString(prefix + ".tls_key"),
|
||||||
|
TLSCA: man.getConfigString(prefix + ".tls_ca"),
|
||||||
|
TLSServerName: man.getConfigString(prefix + ".tls_server_name"),
|
||||||
|
TLSConfig: man.getConfigString(prefix + ".tls_config"),
|
||||||
|
MaxOpenConns: man.getConfigInt(prefix + ".max_open_conns"),
|
||||||
|
MaxIdleConns: man.getConfigInt(prefix + ".max_idle_conns"),
|
||||||
|
ConnMaxLifetime: man.getConfigInt(prefix + ".conn_max_lifetime"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return FleetConfig{
|
return FleetConfig{
|
||||||
Mysql: MysqlConfig{
|
Mysql: loadMysqlConfig("mysql"),
|
||||||
Protocol: man.getConfigString("mysql.protocol"),
|
MysqlReadReplica: loadMysqlConfig("mysql_read_replica"),
|
||||||
Address: man.getConfigString("mysql.address"),
|
|
||||||
Username: man.getConfigString("mysql.username"),
|
|
||||||
Password: man.getConfigString("mysql.password"),
|
|
||||||
PasswordPath: man.getConfigString("mysql.password_path"),
|
|
||||||
Database: man.getConfigString("mysql.database"),
|
|
||||||
TLSCert: man.getConfigString("mysql.tls_cert"),
|
|
||||||
TLSKey: man.getConfigString("mysql.tls_key"),
|
|
||||||
TLSCA: man.getConfigString("mysql.tls_ca"),
|
|
||||||
TLSServerName: man.getConfigString("mysql.tls_server_name"),
|
|
||||||
TLSConfig: man.getConfigString("mysql.tls_config"),
|
|
||||||
MaxOpenConns: man.getConfigInt("mysql.max_open_conns"),
|
|
||||||
MaxIdleConns: man.getConfigInt("mysql.max_idle_conns"),
|
|
||||||
ConnMaxLifetime: man.getConfigInt("mysql.conn_max_lifetime"),
|
|
||||||
},
|
|
||||||
Redis: RedisConfig{
|
Redis: RedisConfig{
|
||||||
Address: man.getConfigString("redis.address"),
|
Address: man.getConfigString("redis.address"),
|
||||||
Password: man.getConfigString("redis.password"),
|
Password: man.getConfigString("redis.password"),
|
||||||
|
@ -58,6 +58,7 @@ func TestConfigRoundtrip(t *testing.T) {
|
|||||||
// Marshal the generated config
|
// Marshal the generated config
|
||||||
buf, err := yaml.Marshal(original)
|
buf, err := yaml.Marshal(original)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
t.Log(string(buf))
|
||||||
|
|
||||||
// Manually load the serialized config
|
// Manually load the serialized config
|
||||||
man.viper.SetConfigType("yaml")
|
man.viper.SetConfigType("yaml")
|
||||||
|
@ -3,6 +3,7 @@ package mysql
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
@ -13,7 +14,7 @@ func (d *Datastore) NewActivity(user *fleet.User, activityType string, details *
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "marshaling activity details")
|
return errors.Wrap(err, "marshaling activity details")
|
||||||
}
|
}
|
||||||
_, err = d.db.Exec(
|
_, err = d.writer.Exec(
|
||||||
`INSERT INTO activities (user_id, user_name, activity_type, details) VALUES(?,?,?,?)`,
|
`INSERT INTO activities (user_id, user_name, activity_type, details) VALUES(?,?,?,?)`,
|
||||||
user.ID,
|
user.ID,
|
||||||
user.Name,
|
user.Name,
|
||||||
@ -34,7 +35,7 @@ func (d *Datastore) ListActivities(opt fleet.ListOptions) ([]*fleet.Activity, er
|
|||||||
WHERE true`
|
WHERE true`
|
||||||
query = appendListOptionsToSQL(query, opt)
|
query = appendListOptionsToSQL(query, opt)
|
||||||
|
|
||||||
err := d.db.Select(&activities, query)
|
err := d.reader.Select(&activities, query)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, notFound("Activity")
|
return nil, notFound("Activity")
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
@ -25,7 +25,7 @@ func (d *Datastore) NewAppConfig(info *fleet.AppConfig) (*fleet.AppConfig, error
|
|||||||
func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
|
func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
|
||||||
info := &fleet.AppConfig{}
|
info := &fleet.AppConfig{}
|
||||||
var bytes []byte
|
var bytes []byte
|
||||||
err := d.db.Get(&bytes, `SELECT json_value FROM app_config_json LIMIT 1`)
|
err := d.reader.Get(&bytes, `SELECT json_value FROM app_config_json LIMIT 1`)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return nil, errors.Wrap(err, "selecting app config")
|
return nil, errors.Wrap(err, "selecting app config")
|
||||||
}
|
}
|
||||||
@ -43,7 +43,7 @@ func (d *Datastore) AppConfig() (*fleet.AppConfig, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) isEventSchedulerEnabled() (bool, error) {
|
func (d *Datastore) isEventSchedulerEnabled() (bool, error) {
|
||||||
rows, err := d.db.Query("SELECT @@event_scheduler")
|
rows, err := d.writer.Query("SELECT @@event_scheduler")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -136,7 +136,7 @@ func (d *Datastore) SaveAppConfig(info *fleet.AppConfig) error {
|
|||||||
|
|
||||||
func (d *Datastore) VerifyEnrollSecret(secret string) (*fleet.EnrollSecret, error) {
|
func (d *Datastore) VerifyEnrollSecret(secret string) (*fleet.EnrollSecret, error) {
|
||||||
var s fleet.EnrollSecret
|
var s fleet.EnrollSecret
|
||||||
err := d.db.Get(&s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
|
err := d.reader.Get(&s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("no matching secret found")
|
return nil, errors.New("no matching secret found")
|
||||||
}
|
}
|
||||||
@ -184,7 +184,7 @@ func (d *Datastore) GetEnrollSecrets(teamID *uint) ([]*fleet.EnrollSecret, error
|
|||||||
args = append(args, teamID)
|
args = append(args, teamID)
|
||||||
}
|
}
|
||||||
var secrets []*fleet.EnrollSecret
|
var secrets []*fleet.EnrollSecret
|
||||||
if err := d.db.Select(&secrets, sql, args...); err != nil {
|
if err := d.reader.Select(&secrets, sql, args...); err != nil {
|
||||||
return nil, errors.Wrap(err, "get secrets")
|
return nil, errors.Wrap(err, "get secrets")
|
||||||
}
|
}
|
||||||
return secrets, nil
|
return secrets, nil
|
||||||
|
@ -259,7 +259,7 @@ func TestAppConfigDefaults(t *testing.T) {
|
|||||||
defer ds.Close()
|
defer ds.Close()
|
||||||
|
|
||||||
insertAppConfigQuery := `INSERT INTO app_config_json(json_value) VALUES(?) ON DUPLICATE KEY UPDATE json_value = VALUES(json_value)`
|
insertAppConfigQuery := `INSERT INTO app_config_json(json_value) VALUES(?) ON DUPLICATE KEY UPDATE json_value = VALUES(json_value)`
|
||||||
_, err := ds.db.Exec(insertAppConfigQuery, `{}`)
|
_, err := ds.writer.Exec(insertAppConfigQuery, `{}`)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ac, err := ds.AppConfig()
|
ac, err := ds.AppConfig()
|
||||||
@ -269,7 +269,7 @@ func TestAppConfigDefaults(t *testing.T) {
|
|||||||
require.True(t, ac.HostSettings.EnableHostUsers)
|
require.True(t, ac.HostSettings.EnableHostUsers)
|
||||||
require.False(t, ac.HostSettings.EnableSoftwareInventory)
|
require.False(t, ac.HostSettings.EnableSoftwareInventory)
|
||||||
|
|
||||||
_, err = ds.db.Exec(
|
_, err = ds.writer.Exec(
|
||||||
insertAppConfigQuery,
|
insertAppConfigQuery,
|
||||||
`{"webhook_settings": {"interval": "12h"}, "host_settings": {"enable_host_users": false}}`,
|
`{"webhook_settings": {"interval": "12h"}, "host_settings": {"enable_host_users": false}}`,
|
||||||
)
|
)
|
||||||
|
@ -17,7 +17,7 @@ func (d *Datastore) NewDistributedQueryCampaign(camp *fleet.DistributedQueryCamp
|
|||||||
)
|
)
|
||||||
VALUES(?,?,?)
|
VALUES(?,?,?)
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID)
|
result, err := d.writer.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "inserting distributed query campaign")
|
return nil, errors.Wrap(err, "inserting distributed query campaign")
|
||||||
}
|
}
|
||||||
@ -32,7 +32,7 @@ func (d *Datastore) DistributedQueryCampaign(id uint) (*fleet.DistributedQueryCa
|
|||||||
SELECT * FROM distributed_query_campaigns WHERE id = ?
|
SELECT * FROM distributed_query_campaigns WHERE id = ?
|
||||||
`
|
`
|
||||||
campaign := &fleet.DistributedQueryCampaign{}
|
campaign := &fleet.DistributedQueryCampaign{}
|
||||||
if err := d.db.Get(campaign, sql, id); err != nil {
|
if err := d.reader.Get(campaign, sql, id); err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting distributed query campaign")
|
return nil, errors.Wrap(err, "selecting distributed query campaign")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ func (d *Datastore) SaveDistributedQueryCampaign(camp *fleet.DistributedQueryCam
|
|||||||
user_id = ?
|
user_id = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID, camp.ID)
|
result, err := d.writer.Exec(sqlStatement, camp.QueryID, camp.Status, camp.UserID, camp.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "updating distributed query campaign")
|
return errors.Wrap(err, "updating distributed query campaign")
|
||||||
}
|
}
|
||||||
@ -68,7 +68,7 @@ func (d *Datastore) DistributedQueryCampaignTargetIDs(id uint) (*fleet.HostTarge
|
|||||||
`
|
`
|
||||||
targets := []fleet.DistributedQueryCampaignTarget{}
|
targets := []fleet.DistributedQueryCampaignTarget{}
|
||||||
|
|
||||||
if err := d.db.Select(&targets, sqlStatement, id); err != nil {
|
if err := d.reader.Select(&targets, sqlStatement, id); err != nil {
|
||||||
return nil, errors.Wrap(err, "select distributed campaign target")
|
return nil, errors.Wrap(err, "select distributed campaign target")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,7 +100,7 @@ func (d *Datastore) NewDistributedQueryCampaignTarget(target *fleet.DistributedQ
|
|||||||
)
|
)
|
||||||
VALUES (?,?,?)
|
VALUES (?,?,?)
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(sqlStatement, target.Type, target.DistributedQueryCampaignID, target.TargetID)
|
result, err := d.writer.Exec(sqlStatement, target.Type, target.DistributedQueryCampaignID, target.TargetID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "insert distributed campaign target")
|
return nil, errors.Wrap(err, "insert distributed campaign target")
|
||||||
}
|
}
|
||||||
@ -118,7 +118,7 @@ func (d *Datastore) CleanupDistributedQueryCampaigns(now time.Time) (expired uin
|
|||||||
WHERE (status = ? AND created_at < ?)
|
WHERE (status = ? AND created_at < ?)
|
||||||
OR (status = ? AND created_at < ?)
|
OR (status = ? AND created_at < ?)
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(sqlStatement, fleet.QueryComplete,
|
result, err := d.writer.Exec(sqlStatement, fleet.QueryComplete,
|
||||||
fleet.QueryWaiting, now.Add(-1*time.Minute),
|
fleet.QueryWaiting, now.Add(-1*time.Minute),
|
||||||
fleet.QueryRunning, now.Add(-24*time.Hour))
|
fleet.QueryRunning, now.Add(-24*time.Hour))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -33,7 +33,7 @@ func (d *Datastore) NewCarve(metadata *fleet.CarveMetadata) (*fleet.CarveMetadat
|
|||||||
?
|
?
|
||||||
)`
|
)`
|
||||||
|
|
||||||
result, err := d.db.Exec(
|
result, err := d.writer.Exec(
|
||||||
stmt,
|
stmt,
|
||||||
metadata.HostId,
|
metadata.HostId,
|
||||||
metadata.CreatedAt.Format(mySQLTimestampFormat),
|
metadata.CreatedAt.Format(mySQLTimestampFormat),
|
||||||
@ -64,7 +64,7 @@ func (d *Datastore) UpdateCarve(metadata *fleet.CarveMetadata) error {
|
|||||||
expired = ?
|
expired = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
_, err := d.db.Exec(
|
_, err := d.writer.Exec(
|
||||||
stmt,
|
stmt,
|
||||||
metadata.MaxBlock,
|
metadata.MaxBlock,
|
||||||
metadata.Expired,
|
metadata.Expired,
|
||||||
@ -163,7 +163,7 @@ func (d *Datastore) Carve(carveId int64) (*fleet.CarveMetadata, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
var metadata fleet.CarveMetadata
|
var metadata fleet.CarveMetadata
|
||||||
if err := d.db.Get(&metadata, stmt, carveId); err != nil {
|
if err := d.reader.Get(&metadata, stmt, carveId); err != nil {
|
||||||
return nil, errors.Wrap(err, "get carve by ID")
|
return nil, errors.Wrap(err, "get carve by ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,7 +179,7 @@ func (d *Datastore) CarveBySessionId(sessionId string) (*fleet.CarveMetadata, er
|
|||||||
)
|
)
|
||||||
|
|
||||||
var metadata fleet.CarveMetadata
|
var metadata fleet.CarveMetadata
|
||||||
if err := d.db.Get(&metadata, stmt, sessionId); err != nil {
|
if err := d.reader.Get(&metadata, stmt, sessionId); err != nil {
|
||||||
return nil, errors.Wrap(err, "get carve by session ID")
|
return nil, errors.Wrap(err, "get carve by session ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,7 +195,7 @@ func (d *Datastore) CarveByName(name string) (*fleet.CarveMetadata, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
var metadata fleet.CarveMetadata
|
var metadata fleet.CarveMetadata
|
||||||
if err := d.db.Get(&metadata, stmt, name); err != nil {
|
if err := d.reader.Get(&metadata, stmt, name); err != nil {
|
||||||
return nil, errors.Wrap(err, "get carve by name")
|
return nil, errors.Wrap(err, "get carve by name")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,7 +213,7 @@ func (d *Datastore) ListCarves(opt fleet.CarveListOptions) ([]*fleet.CarveMetada
|
|||||||
}
|
}
|
||||||
stmt = appendListOptionsToSQL(stmt, opt.ListOptions)
|
stmt = appendListOptionsToSQL(stmt, opt.ListOptions)
|
||||||
carves := []*fleet.CarveMetadata{}
|
carves := []*fleet.CarveMetadata{}
|
||||||
if err := d.db.Select(&carves, stmt); err != nil && err != sql.ErrNoRows {
|
if err := d.reader.Select(&carves, stmt); err != nil && err != sql.ErrNoRows {
|
||||||
return nil, errors.Wrap(err, "list carves")
|
return nil, errors.Wrap(err, "list carves")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -231,13 +231,14 @@ func (d *Datastore) NewBlock(metadata *fleet.CarveMetadata, blockId int64, data
|
|||||||
?,
|
?,
|
||||||
?
|
?
|
||||||
)`
|
)`
|
||||||
if _, err := d.db.Exec(stmt, metadata.ID, blockId, data); err != nil {
|
if _, err := d.writer.Exec(stmt, metadata.ID, blockId, data); err != nil {
|
||||||
return errors.Wrap(err, "insert carve block")
|
return errors.Wrap(err, "insert carve block")
|
||||||
}
|
}
|
||||||
|
|
||||||
if metadata.MaxBlock < blockId {
|
if metadata.MaxBlock < blockId {
|
||||||
// Update max_block
|
// Update max_block
|
||||||
metadata.MaxBlock = blockId
|
metadata.MaxBlock = blockId
|
||||||
|
// TODO: looks like this should all be done in an atomic transaction?
|
||||||
if err := d.UpdateCarve(metadata); err != nil {
|
if err := d.UpdateCarve(metadata); err != nil {
|
||||||
return errors.Wrap(err, "insert carve block")
|
return errors.Wrap(err, "insert carve block")
|
||||||
}
|
}
|
||||||
@ -253,7 +254,7 @@ func (d *Datastore) GetBlock(metadata *fleet.CarveMetadata, blockId int64) ([]by
|
|||||||
WHERE metadata_id = ? AND block_id = ?
|
WHERE metadata_id = ? AND block_id = ?
|
||||||
`
|
`
|
||||||
var data []byte
|
var data []byte
|
||||||
if err := d.db.Get(&data, stmt, metadata.ID, blockId); err != nil {
|
if err := d.reader.Get(&data, stmt, metadata.ID, blockId); err != nil {
|
||||||
return nil, errors.Wrap(err, "select data")
|
return nil, errors.Wrap(err, "select data")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
import "github.com/go-kit/kit/log"
|
import (
|
||||||
|
"github.com/fleetdm/fleet/v4/server/config"
|
||||||
|
"github.com/go-kit/kit/log"
|
||||||
|
)
|
||||||
|
|
||||||
const defaultMaxAttempts int = 15
|
const defaultMaxAttempts int = 15
|
||||||
|
|
||||||
@ -11,6 +14,7 @@ type dbOptions struct {
|
|||||||
// maxAttempts configures the number of retries to connect to the DB
|
// maxAttempts configures the number of retries to connect to the DB
|
||||||
maxAttempts int
|
maxAttempts int
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
replicaConfig *config.MysqlConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logger adds a logger to the datastore
|
// Logger adds a logger to the datastore
|
||||||
@ -21,6 +25,14 @@ func Logger(l log.Logger) DBOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replica sets the configuration of the read replica for the datastore.
|
||||||
|
func Replica(conf *config.MysqlConfig) DBOption {
|
||||||
|
return func(o *dbOptions) error {
|
||||||
|
o.replicaConfig = conf
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LimitAttempts sets a the number of attempts
|
// LimitAttempts sets a the number of attempts
|
||||||
// to try establishing a connection to the database backend
|
// to try establishing a connection to the database backend
|
||||||
// the default value is 15 attempts
|
// the default value is 15 attempts
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
// returning a notFound error if appropriate.
|
// returning a notFound error if appropriate.
|
||||||
func (d *Datastore) deleteEntity(dbTable string, id uint) error {
|
func (d *Datastore) deleteEntity(dbTable string, id uint) error {
|
||||||
deleteStmt := fmt.Sprintf(`DELETE FROM %s WHERE id = ?`, dbTable)
|
deleteStmt := fmt.Sprintf(`DELETE FROM %s WHERE id = ?`, dbTable)
|
||||||
result, err := d.db.Exec(deleteStmt, id)
|
result, err := d.writer.Exec(deleteStmt, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(err, "delete %s", dbTable)
|
return errors.Wrapf(err, "delete %s", dbTable)
|
||||||
}
|
}
|
||||||
@ -26,7 +26,7 @@ func (d *Datastore) deleteEntity(dbTable string, id uint) error {
|
|||||||
// table, returning a notFound error if appropriate.
|
// table, returning a notFound error if appropriate.
|
||||||
func (d *Datastore) deleteEntityByName(dbTable string, name string) error {
|
func (d *Datastore) deleteEntityByName(dbTable string, name string) error {
|
||||||
deleteStmt := fmt.Sprintf("DELETE FROM %s WHERE name = ?", dbTable)
|
deleteStmt := fmt.Sprintf("DELETE FROM %s WHERE name = ?", dbTable)
|
||||||
result, err := d.db.Exec(deleteStmt, name)
|
result, err := d.writer.Exec(deleteStmt, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isMySQLForeignKey(err) {
|
if isMySQLForeignKey(err) {
|
||||||
return foreignKey(dbTable, name)
|
return foreignKey(dbTable, name)
|
||||||
@ -50,7 +50,7 @@ func (d *Datastore) deleteEntities(dbTable string, ids []uint) (uint, error) {
|
|||||||
return 0, errors.Wrapf(err, "building delete entities query %s", dbTable)
|
return 0, errors.Wrapf(err, "building delete entities query %s", dbTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := d.db.Exec(query, args...)
|
result, err := d.writer.Exec(query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrapf(err, "executing delete entities query %s", dbTable)
|
return 0, errors.Wrapf(err, "executing delete entities query %s", dbTable)
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ func (ds *Datastore) PendingEmailChange(uid uint, newEmail, token string) error
|
|||||||
new_email
|
new_email
|
||||||
) VALUES( ?, ?, ? )
|
) VALUES( ?, ?, ? )
|
||||||
`
|
`
|
||||||
_, err := ds.db.Exec(sqlStatement, uid, token, newEmail)
|
_, err := ds.writer.Exec(sqlStatement, uid, token, newEmail)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "inserting email change record")
|
return errors.Wrap(err, "inserting email change record")
|
||||||
}
|
}
|
||||||
@ -32,7 +32,7 @@ func (ds *Datastore) ConfirmPendingEmailChange(id uint, token string) (newEmail
|
|||||||
Token string
|
Token string
|
||||||
NewEmail string `db:"new_email"`
|
NewEmail string `db:"new_email"`
|
||||||
}{}
|
}{}
|
||||||
err = ds.db.Get(&changeRecord, "SELECT * FROM email_changes WHERE token = ? AND user_id = ?", token, id)
|
err = ds.writer.Get(&changeRecord, "SELECT * FROM email_changes WHERE token = ? AND user_id = ?", token, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", notFound("email change with token")
|
return "", notFound("email change with token")
|
||||||
|
@ -33,7 +33,7 @@ func (d *Datastore) NewHost(host *fleet.Host) (*fleet.Host, error) {
|
|||||||
)
|
)
|
||||||
VALUES( ?,?,?,?,?,?,?,?,?,?,?,?,? )
|
VALUES( ?,?,?,?,?,?,?,?,?,?,?,?,? )
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(
|
result, err := d.writer.Exec(
|
||||||
sqlStatement,
|
sqlStatement,
|
||||||
host.OsqueryHostID,
|
host.OsqueryHostID,
|
||||||
host.DetailUpdatedAt,
|
host.DetailUpdatedAt,
|
||||||
@ -95,7 +95,7 @@ func (d *Datastore) SaveHost(host *fleet.Host) error {
|
|||||||
percent_disk_space_available = ?
|
percent_disk_space_available = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
_, err := d.db.Exec(sqlStatement,
|
_, err := d.writer.Exec(sqlStatement,
|
||||||
host.DetailUpdatedAt,
|
host.DetailUpdatedAt,
|
||||||
host.LabelUpdatedAt,
|
host.LabelUpdatedAt,
|
||||||
host.NodeKey,
|
host.NodeKey,
|
||||||
@ -248,7 +248,7 @@ FROM scheduled_query_stats sqs
|
|||||||
WHERE host_id = ? AND p.pack_type IS NULL
|
WHERE host_id = ? AND p.pack_type IS NULL
|
||||||
`
|
`
|
||||||
var stats []fleet.ScheduledQueryStats
|
var stats []fleet.ScheduledQueryStats
|
||||||
if err := d.db.Select(&stats, sql, host.ID); err != nil {
|
if err := d.reader.Select(&stats, sql, host.ID); err != nil {
|
||||||
return errors.Wrap(err, "load pack stats")
|
return errors.Wrap(err, "load pack stats")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -268,9 +268,9 @@ WHERE host_id = ? AND p.pack_type IS NULL
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) loadHostUsers(host *fleet.Host) error {
|
func (d *Datastore) loadHostUsers(db dbReader, host *fleet.Host) error {
|
||||||
sql := `SELECT id, username, groupname, uid, user_type FROM host_users WHERE host_id = ? and removed_at IS NULL`
|
sql := `SELECT id, username, groupname, uid, user_type FROM host_users WHERE host_id = ? and removed_at IS NULL`
|
||||||
if err := d.db.Select(&host.Users, sql, host.ID); err != nil {
|
if err := db.Select(&host.Users, sql, host.ID); err != nil {
|
||||||
return errors.Wrap(err, "load pack stats")
|
return errors.Wrap(err, "load pack stats")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -292,14 +292,14 @@ func (d *Datastore) Host(id uint) (*fleet.Host, error) {
|
|||||||
LIMIT 1
|
LIMIT 1
|
||||||
`
|
`
|
||||||
host := &fleet.Host{}
|
host := &fleet.Host{}
|
||||||
err := d.db.Get(host, sqlStatement, id)
|
err := d.reader.Get(host, sqlStatement, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "get host by id")
|
return nil, errors.Wrap(err, "get host by id")
|
||||||
}
|
}
|
||||||
if err := d.loadHostPackStats(host); err != nil {
|
if err := d.loadHostPackStats(host); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := d.loadHostUsers(host); err != nil {
|
if err := d.loadHostUsers(d.reader, host); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -308,7 +308,9 @@ func (d *Datastore) Host(id uint) (*fleet.Host, error) {
|
|||||||
|
|
||||||
func (d *Datastore) amountEnrolledHosts() (int, error) {
|
func (d *Datastore) amountEnrolledHosts() (int, error) {
|
||||||
var amount int
|
var amount int
|
||||||
err := d.db.Get(&amount, `SELECT count(*) FROM hosts`)
|
// uses the writer as it is called from ShouldSendStatistics, which is a
|
||||||
|
// write method.
|
||||||
|
err := d.writer.Get(&amount, `SELECT count(*) FROM hosts`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -363,7 +365,7 @@ func (d *Datastore) ListHosts(filter fleet.TeamFilter, opt fleet.HostListOptions
|
|||||||
sql = appendListOptionsToSQL(sql, opt.ListOptions)
|
sql = appendListOptionsToSQL(sql, opt.ListOptions)
|
||||||
|
|
||||||
hosts := []*fleet.Host{}
|
hosts := []*fleet.Host{}
|
||||||
if err := d.db.Select(&hosts, sql, params...); err != nil {
|
if err := d.reader.Select(&hosts, sql, params...); err != nil {
|
||||||
return nil, errors.Wrap(err, "list hosts")
|
return nil, errors.Wrap(err, "list hosts")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -413,7 +415,7 @@ func (d *Datastore) CleanupIncomingHosts(now time.Time) error {
|
|||||||
WHERE hostname = '' AND osquery_version = ''
|
WHERE hostname = '' AND osquery_version = ''
|
||||||
AND created_at < (? - INTERVAL 5 MINUTE)
|
AND created_at < (? - INTERVAL 5 MINUTE)
|
||||||
`
|
`
|
||||||
if _, err := d.db.Exec(sqlStatement, now); err != nil {
|
if _, err := d.writer.Exec(sqlStatement, now); err != nil {
|
||||||
return errors.Wrap(err, "cleanup incoming hosts")
|
return errors.Wrap(err, "cleanup incoming hosts")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -442,7 +444,7 @@ func (d *Datastore) GenerateHostStatusStatistics(filter fleet.TeamFilter, now ti
|
|||||||
Online uint `db:"online"`
|
Online uint `db:"online"`
|
||||||
New uint `db:"new"`
|
New uint `db:"new"`
|
||||||
}{}
|
}{}
|
||||||
err := d.db.Get(&counts, sqlStatement, now, now, now, now, now)
|
err := d.reader.Get(&counts, sqlStatement, now, now, now, now, now)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
e = errors.Wrap(err, "generating host statistics")
|
e = errors.Wrap(err, "generating host statistics")
|
||||||
return
|
return
|
||||||
@ -582,7 +584,7 @@ func (d *Datastore) AuthenticateHost(nodeKey string) (*fleet.Host, error) {
|
|||||||
`
|
`
|
||||||
|
|
||||||
host := &fleet.Host{}
|
host := &fleet.Host{}
|
||||||
if err := d.db.Get(host, sqlStatement, nodeKey); err != nil {
|
if err := d.reader.Get(host, sqlStatement, nodeKey); err != nil {
|
||||||
switch err {
|
switch err {
|
||||||
case sql.ErrNoRows:
|
case sql.ErrNoRows:
|
||||||
return nil, notFound("Host")
|
return nil, notFound("Host")
|
||||||
@ -601,7 +603,7 @@ func (d *Datastore) MarkHostSeen(host *fleet.Host, t time.Time) error {
|
|||||||
WHERE node_key=?
|
WHERE node_key=?
|
||||||
`
|
`
|
||||||
|
|
||||||
_, err := d.db.Exec(sqlStatement, t, host.NodeKey)
|
_, err := d.writer.Exec(sqlStatement, t, host.NodeKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "marking host seen")
|
return errors.Wrap(err, "marking host seen")
|
||||||
}
|
}
|
||||||
@ -659,11 +661,11 @@ func (d *Datastore) searchHostsWithOmits(filter fleet.TeamFilter, query string,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "searching hosts")
|
return nil, errors.Wrap(err, "searching hosts")
|
||||||
}
|
}
|
||||||
sql = d.db.Rebind(sql)
|
sql = d.reader.Rebind(sql)
|
||||||
|
|
||||||
hosts := []*fleet.Host{}
|
hosts := []*fleet.Host{}
|
||||||
|
|
||||||
err = d.db.Select(&hosts, sql, args...)
|
err = d.reader.Select(&hosts, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "searching hosts rebound")
|
return nil, errors.Wrap(err, "searching hosts rebound")
|
||||||
}
|
}
|
||||||
@ -695,8 +697,8 @@ func (d *Datastore) searchHostsDefault(filter fleet.TeamFilter, omit ...uint) ([
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "searching default hosts")
|
return nil, errors.Wrap(err, "searching default hosts")
|
||||||
}
|
}
|
||||||
sql = d.db.Rebind(sql)
|
sql = d.reader.Rebind(sql)
|
||||||
err = d.db.Select(&hosts, sql, args...)
|
err = d.reader.Select(&hosts, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "searching default hosts rebound")
|
return nil, errors.Wrap(err, "searching default hosts rebound")
|
||||||
}
|
}
|
||||||
@ -730,7 +732,7 @@ func (d *Datastore) SearchHosts(filter fleet.TeamFilter, query string, omit ...u
|
|||||||
)
|
)
|
||||||
|
|
||||||
hosts := []*fleet.Host{}
|
hosts := []*fleet.Host{}
|
||||||
if err := d.db.Select(&hosts, sql, hostQuery, ipQuery); err != nil {
|
if err := d.reader.Select(&hosts, sql, hostQuery, ipQuery); err != nil {
|
||||||
return nil, errors.Wrap(err, "searching hosts")
|
return nil, errors.Wrap(err, "searching hosts")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -755,7 +757,7 @@ func (d *Datastore) HostIDsByName(filter fleet.TeamFilter, hostnames []string) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
var hostIDs []uint
|
var hostIDs []uint
|
||||||
if err := d.db.Select(&hostIDs, sql, args...); err != nil {
|
if err := d.reader.Select(&hostIDs, sql, args...); err != nil {
|
||||||
return nil, errors.Wrap(err, "get host IDs")
|
return nil, errors.Wrap(err, "get host IDs")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -770,7 +772,7 @@ func (d *Datastore) HostByIdentifier(identifier string) (*fleet.Host, error) {
|
|||||||
LIMIT 1
|
LIMIT 1
|
||||||
`
|
`
|
||||||
host := &fleet.Host{}
|
host := &fleet.Host{}
|
||||||
err := d.db.Get(host, sql, identifier)
|
err := d.reader.Get(host, sql, identifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "get host by identifier")
|
return nil, errors.Wrap(err, "get host by identifier")
|
||||||
}
|
}
|
||||||
@ -796,7 +798,7 @@ func (d *Datastore) AddHostsToTeam(teamID *uint, hostIDs []uint) error {
|
|||||||
return errors.Wrap(err, "sqlx.In AddHostsToTeam")
|
return errors.Wrap(err, "sqlx.In AddHostsToTeam")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := d.db.Exec(sql, args...); err != nil {
|
if _, err := d.writer.Exec(sql, args...); err != nil {
|
||||||
return errors.Wrap(err, "exec AddHostsToTeam")
|
return errors.Wrap(err, "exec AddHostsToTeam")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -809,7 +811,7 @@ func (d *Datastore) SaveHostAdditional(host *fleet.Host) error {
|
|||||||
VALUES (?, ?)
|
VALUES (?, ?)
|
||||||
ON DUPLICATE KEY UPDATE additional = VALUES(additional)
|
ON DUPLICATE KEY UPDATE additional = VALUES(additional)
|
||||||
`
|
`
|
||||||
if _, err := d.db.Exec(sql, host.ID, host.Additional); err != nil {
|
if _, err := d.writer.Exec(sql, host.ID, host.Additional); err != nil {
|
||||||
return errors.Wrap(err, "insert additional")
|
return errors.Wrap(err, "insert additional")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -818,7 +820,7 @@ func (d *Datastore) SaveHostAdditional(host *fleet.Host) error {
|
|||||||
|
|
||||||
func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
||||||
if len(host.Users) == 0 {
|
if len(host.Users) == 0 {
|
||||||
if _, err := d.db.Exec(
|
if _, err := d.writer.Exec(
|
||||||
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE host_id = ?`,
|
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE host_id = ?`,
|
||||||
host.ID,
|
host.ID,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
@ -829,7 +831,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
currentHost := &fleet.Host{ID: host.ID}
|
currentHost := &fleet.Host{ID: host.ID}
|
||||||
if err := d.loadHostUsers(currentHost); err != nil {
|
if err := d.loadHostUsers(d.writer, currentHost); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -837,7 +839,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
|||||||
for _, u := range host.Users {
|
for _, u := range host.Users {
|
||||||
incomingUsers[u.Uid] = true
|
incomingUsers[u.Uid] = true
|
||||||
|
|
||||||
if _, err := d.db.Exec(
|
if _, err := d.writer.Exec(
|
||||||
`INSERT IGNORE INTO host_users (host_id, uid, username, user_type, groupname) VALUES (?, ?, ?, ?, ?)`,
|
`INSERT IGNORE INTO host_users (host_id, uid, username, user_type, groupname) VALUES (?, ?, ?, ?, ?)`,
|
||||||
host.ID, u.Uid, u.Username, u.Type, u.GroupName,
|
host.ID, u.Uid, u.Username, u.Type, u.GroupName,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
@ -860,7 +862,7 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
|||||||
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE id IN (%s)`,
|
`UPDATE host_users SET removed_at = CURRENT_TIMESTAMP WHERE id IN (%s)`,
|
||||||
removedValues,
|
removedValues,
|
||||||
)
|
)
|
||||||
if _, err := d.db.Exec(removedSql, removedArgs...); err != nil {
|
if _, err := d.writer.Exec(removedSql, removedArgs...); err != nil {
|
||||||
return errors.Wrap(err, "mark users as removed")
|
return errors.Wrap(err, "mark users as removed")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -869,15 +871,15 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
|
|||||||
|
|
||||||
func (d *Datastore) TotalAndUnseenHostsSince(daysCount int) (int, int, error) {
|
func (d *Datastore) TotalAndUnseenHostsSince(daysCount int) (int, int, error) {
|
||||||
var totalCount, unseenCount int
|
var totalCount, unseenCount int
|
||||||
err := d.db.QueryRow("SELECT count(*) FROM hosts").Scan(&totalCount)
|
err := d.reader.Get(&totalCount, "SELECT count(*) FROM hosts")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, errors.Wrap(err, "getting total host count")
|
return 0, 0, errors.Wrap(err, "getting total host count")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = d.db.QueryRow(
|
err = d.reader.Get(&unseenCount,
|
||||||
"SELECT count(*) FROM hosts WHERE DATEDIFF(CURRENT_DATE, seen_time) >= ?",
|
"SELECT count(*) FROM hosts WHERE DATEDIFF(CURRENT_DATE, seen_time) >= ?",
|
||||||
daysCount,
|
daysCount,
|
||||||
).Scan(&unseenCount)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, errors.Wrap(err, "getting unseen host count")
|
return 0, 0, errors.Wrap(err, "getting unseen host count")
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ func (d *Datastore) NewInvite(i *fleet.Invite) (*fleet.Invite, error) {
|
|||||||
VALUES ( ?, ?, ?, ?, ?, ?, ?)
|
VALUES ( ?, ?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
|
|
||||||
result, err := d.db.Exec(sqlStmt, i.InvitedBy, i.Email,
|
result, err := d.writer.Exec(sqlStmt, i.InvitedBy, i.Email,
|
||||||
i.Name, i.Position, i.Token, i.SSOEnabled, i.GlobalRole)
|
i.Name, i.Position, i.Token, i.SSOEnabled, i.GlobalRole)
|
||||||
if err != nil && isDuplicate(err) {
|
if err != nil && isDuplicate(err) {
|
||||||
return nil, alreadyExists("Invite", i.Email)
|
return nil, alreadyExists("Invite", i.Email)
|
||||||
@ -45,10 +45,11 @@ func (d *Datastore) NewInvite(i *fleet.Invite) (*fleet.Invite, error) {
|
|||||||
for _, userTeam := range i.Teams {
|
for _, userTeam := range i.Teams {
|
||||||
args = append(args, i.ID, userTeam.Team.ID, userTeam.Role)
|
args = append(args, i.ID, userTeam.Team.ID, userTeam.Role)
|
||||||
}
|
}
|
||||||
|
// TODO: seems like this should be in a transaction?
|
||||||
sql := "INSERT INTO invite_teams (invite_id, team_id, role) VALUES " +
|
sql := "INSERT INTO invite_teams (invite_id, team_id, role) VALUES " +
|
||||||
strings.Repeat(valueStr, len(i.Teams))
|
strings.Repeat(valueStr, len(i.Teams))
|
||||||
sql = strings.TrimSuffix(sql, ",")
|
sql = strings.TrimSuffix(sql, ",")
|
||||||
if _, err := d.db.Exec(sql, args...); err != nil {
|
if _, err := d.writer.Exec(sql, args...); err != nil {
|
||||||
return nil, errors.Wrap(err, "insert teams")
|
return nil, errors.Wrap(err, "insert teams")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,7 +64,7 @@ func (d *Datastore) ListInvites(opt fleet.ListOptions) ([]*fleet.Invite, error)
|
|||||||
query, params := searchLike(query, nil, opt.MatchQuery, inviteSearchColumns...)
|
query, params := searchLike(query, nil, opt.MatchQuery, inviteSearchColumns...)
|
||||||
query = appendListOptionsToSQL(query, opt)
|
query = appendListOptionsToSQL(query, opt)
|
||||||
|
|
||||||
err := d.db.Select(&invites, query, params...)
|
err := d.reader.Select(&invites, query, params...)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, notFound("Invite")
|
return nil, notFound("Invite")
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@ -80,7 +81,7 @@ func (d *Datastore) ListInvites(opt fleet.ListOptions) ([]*fleet.Invite, error)
|
|||||||
// Invite returns Invite identified by id.
|
// Invite returns Invite identified by id.
|
||||||
func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
|
func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
|
||||||
var invite fleet.Invite
|
var invite fleet.Invite
|
||||||
err := d.db.Get(&invite, "SELECT * FROM invites WHERE id = ?", id)
|
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE id = ?", id)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, notFound("Invite").WithID(id)
|
return nil, notFound("Invite").WithID(id)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@ -97,7 +98,7 @@ func (d *Datastore) Invite(id uint) (*fleet.Invite, error) {
|
|||||||
// InviteByEmail finds an Invite with a particular email, if one exists.
|
// InviteByEmail finds an Invite with a particular email, if one exists.
|
||||||
func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
|
func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
|
||||||
var invite fleet.Invite
|
var invite fleet.Invite
|
||||||
err := d.db.Get(&invite, "SELECT * FROM invites WHERE email = ?", email)
|
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE email = ?", email)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, notFound("Invite").
|
return nil, notFound("Invite").
|
||||||
WithMessage(fmt.Sprintf("with email %s", email))
|
WithMessage(fmt.Sprintf("with email %s", email))
|
||||||
@ -115,7 +116,7 @@ func (d *Datastore) InviteByEmail(email string) (*fleet.Invite, error) {
|
|||||||
// InviteByToken finds an Invite with a particular token, if one exists.
|
// InviteByToken finds an Invite with a particular token, if one exists.
|
||||||
func (d *Datastore) InviteByToken(token string) (*fleet.Invite, error) {
|
func (d *Datastore) InviteByToken(token string) (*fleet.Invite, error) {
|
||||||
var invite fleet.Invite
|
var invite fleet.Invite
|
||||||
err := d.db.Get(&invite, "SELECT * FROM invites WHERE token = ?", token)
|
err := d.reader.Get(&invite, "SELECT * FROM invites WHERE token = ?", token)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, notFound("Invite").
|
return nil, notFound("Invite").
|
||||||
WithMessage(fmt.Sprintf("with token %s", token))
|
WithMessage(fmt.Sprintf("with token %s", token))
|
||||||
@ -163,7 +164,7 @@ func (d *Datastore) loadTeamsForInvites(invites []*fleet.Invite) error {
|
|||||||
fleet.UserTeam
|
fleet.UserTeam
|
||||||
InviteID uint `db:"invite_id"`
|
InviteID uint `db:"invite_id"`
|
||||||
}
|
}
|
||||||
if err := d.db.Select(&rows, sql, args...); err != nil {
|
if err := d.reader.Select(&rows, sql, args...); err != nil {
|
||||||
return errors.Wrap(err, "get loadTeamsForInvites")
|
return errors.Wrap(err, "get loadTeamsForInvites")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,7 +115,7 @@ func (d *Datastore) GetLabelSpecs() ([]*fleet.LabelSpec, error) {
|
|||||||
var specs []*fleet.LabelSpec
|
var specs []*fleet.LabelSpec
|
||||||
// Get basic specs
|
// Get basic specs
|
||||||
query := "SELECT id, name, description, query, platform, label_type, label_membership_type FROM labels"
|
query := "SELECT id, name, description, query, platform, label_type, label_membership_type FROM labels"
|
||||||
if err := d.db.Select(&specs, query); err != nil {
|
if err := d.reader.Select(&specs, query); err != nil {
|
||||||
return nil, errors.Wrap(err, "get labels")
|
return nil, errors.Wrap(err, "get labels")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,7 +138,7 @@ SELECT name, description, query, platform, label_type, label_membership_type
|
|||||||
FROM labels
|
FROM labels
|
||||||
WHERE name = ?
|
WHERE name = ?
|
||||||
`
|
`
|
||||||
if err := d.db.Select(&specs, query, name); err != nil {
|
if err := d.reader.Select(&specs, query, name); err != nil {
|
||||||
return nil, errors.Wrap(err, "get label")
|
return nil, errors.Wrap(err, "get label")
|
||||||
}
|
}
|
||||||
if len(specs) == 0 {
|
if len(specs) == 0 {
|
||||||
@ -171,7 +171,7 @@ func (d *Datastore) getLabelHostnames(label *fleet.LabelSpec) error {
|
|||||||
WHERE label_id = (SELECT id FROM labels WHERE name = ?)
|
WHERE label_id = (SELECT id FROM labels WHERE name = ?)
|
||||||
)
|
)
|
||||||
`
|
`
|
||||||
err := d.db.Select(&label.Hosts, sql, label.Name)
|
err := d.reader.Select(&label.Hosts, sql, label.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "get hostnames for label")
|
return errors.Wrap(err, "get hostnames for label")
|
||||||
}
|
}
|
||||||
@ -190,7 +190,7 @@ func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fl
|
|||||||
label_membership_type
|
label_membership_type
|
||||||
) VALUES ( ?, ?, ?, ?, ?, ?)
|
) VALUES ( ?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(
|
result, err := d.writer.Exec(
|
||||||
query,
|
query,
|
||||||
label.Name,
|
label.Name,
|
||||||
label.Description,
|
label.Description,
|
||||||
@ -216,7 +216,7 @@ func (d *Datastore) SaveLabel(label *fleet.Label) (*fleet.Label, error) {
|
|||||||
description = ?
|
description = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
_, err := d.db.Exec(query, label.Name, label.Description, label.ID)
|
_, err := d.writer.Exec(query, label.Name, label.Description, label.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "saving label")
|
return nil, errors.Wrap(err, "saving label")
|
||||||
}
|
}
|
||||||
@ -236,7 +236,7 @@ func (d *Datastore) Label(lid uint) (*fleet.Label, error) {
|
|||||||
`
|
`
|
||||||
label := &fleet.Label{}
|
label := &fleet.Label{}
|
||||||
|
|
||||||
if err := d.db.Get(label, sql, lid); err != nil {
|
if err := d.reader.Get(label, sql, lid); err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting label")
|
return nil, errors.Wrap(err, "selecting label")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -255,7 +255,7 @@ func (d *Datastore) ListLabels(filter fleet.TeamFilter, opt fleet.ListOptions) (
|
|||||||
query = appendListOptionsToSQL(query, opt)
|
query = appendListOptionsToSQL(query, opt)
|
||||||
labels := []*fleet.Label{}
|
labels := []*fleet.Label{}
|
||||||
|
|
||||||
if err := d.db.Select(&labels, query); err != nil {
|
if err := d.reader.Select(&labels, query); err != nil {
|
||||||
// it's ok if no labels exist
|
// it's ok if no labels exist
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return labels, nil
|
return labels, nil
|
||||||
@ -288,7 +288,7 @@ func (d *Datastore) LabelQueriesForHost(host *fleet.Host, cutoff time.Time) (map
|
|||||||
WHERE platform = ? OR platform = ''
|
WHERE platform = ? OR platform = ''
|
||||||
AND label_membership_type = ?
|
AND label_membership_type = ?
|
||||||
`
|
`
|
||||||
rows, err = d.db.Query(sql, platform, fleet.LabelMembershipTypeDynamic)
|
rows, err = d.reader.Query(sql, platform, fleet.LabelMembershipTypeDynamic)
|
||||||
} else {
|
} else {
|
||||||
// Retrieve all labels (with matching platform) iff there is a label
|
// Retrieve all labels (with matching platform) iff there is a label
|
||||||
// that has been created since this host last reported label query
|
// that has been created since this host last reported label query
|
||||||
@ -300,7 +300,7 @@ func (d *Datastore) LabelQueriesForHost(host *fleet.Host, cutoff time.Time) (map
|
|||||||
AND (platform = ? OR platform = '')
|
AND (platform = ? OR platform = '')
|
||||||
AND label_membership_type = ?
|
AND label_membership_type = ?
|
||||||
`
|
`
|
||||||
rows, err = d.db.Query(
|
rows, err = d.reader.Query(
|
||||||
sql,
|
sql,
|
||||||
platform,
|
platform,
|
||||||
host.LabelUpdatedAt,
|
host.LabelUpdatedAt,
|
||||||
@ -369,7 +369,7 @@ func (d *Datastore) RecordLabelQueryExecutions(host *fleet.Host, results map[uin
|
|||||||
updated_at = VALUES(updated_at)
|
updated_at = VALUES(updated_at)
|
||||||
`
|
`
|
||||||
|
|
||||||
_, err := d.db.Exec(sql, vals...)
|
_, err := d.writer.Exec(sql, vals...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(err, "insert label query executions (%v)", vals)
|
return errors.Wrapf(err, "insert label query executions (%v)", vals)
|
||||||
}
|
}
|
||||||
@ -384,8 +384,8 @@ func (d *Datastore) RecordLabelQueryExecutions(host *fleet.Host, results map[uin
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "IN for DELETE FROM label_membership")
|
return errors.Wrap(err, "IN for DELETE FROM label_membership")
|
||||||
}
|
}
|
||||||
query = d.db.Rebind(query)
|
query = d.writer.Rebind(query)
|
||||||
_, err = d.db.Exec(query, args...)
|
_, err = d.writer.Exec(query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "delete label query executions")
|
return errors.Wrap(err, "delete label query executions")
|
||||||
}
|
}
|
||||||
@ -403,7 +403,7 @@ func (d *Datastore) ListLabelsForHost(hid uint) ([]*fleet.Label, error) {
|
|||||||
`
|
`
|
||||||
|
|
||||||
labels := []*fleet.Label{}
|
labels := []*fleet.Label{}
|
||||||
err := d.db.Select(&labels, sqlStatement, hid)
|
err := d.reader.Select(&labels, sqlStatement, hid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting host labels")
|
return nil, errors.Wrap(err, "selecting host labels")
|
||||||
}
|
}
|
||||||
@ -432,7 +432,7 @@ func (d *Datastore) ListHostsInLabel(filter fleet.TeamFilter, lid uint, opt flee
|
|||||||
|
|
||||||
sql = appendListOptionsToSQL(sql, opt.ListOptions)
|
sql = appendListOptionsToSQL(sql, opt.ListOptions)
|
||||||
hosts := []*fleet.Host{}
|
hosts := []*fleet.Host{}
|
||||||
err := d.db.Select(&hosts, sql, params...)
|
err := d.reader.Select(&hosts, sql, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting label query executions")
|
return nil, errors.Wrap(err, "selecting label query executions")
|
||||||
}
|
}
|
||||||
@ -458,9 +458,9 @@ func (d *Datastore) ListUniqueHostsInLabels(filter fleet.TeamFilter, labels []ui
|
|||||||
return nil, errors.Wrap(err, "building query listing unique hosts in labels")
|
return nil, errors.Wrap(err, "building query listing unique hosts in labels")
|
||||||
}
|
}
|
||||||
|
|
||||||
query = d.db.Rebind(query)
|
query = d.reader.Rebind(query)
|
||||||
hosts := []*fleet.Host{}
|
hosts := []*fleet.Host{}
|
||||||
err = d.db.Select(&hosts, query, args...)
|
err = d.reader.Select(&hosts, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "listing unique hosts in labels")
|
return nil, errors.Wrap(err, "listing unique hosts in labels")
|
||||||
}
|
}
|
||||||
@ -492,10 +492,10 @@ func (d *Datastore) searchLabelsWithOmits(filter fleet.TeamFilter, query string,
|
|||||||
return nil, errors.Wrap(err, "building query for labels with omits")
|
return nil, errors.Wrap(err, "building query for labels with omits")
|
||||||
}
|
}
|
||||||
|
|
||||||
sql = d.db.Rebind(sql)
|
sql = d.reader.Rebind(sql)
|
||||||
|
|
||||||
matches := []*fleet.Label{}
|
matches := []*fleet.Label{}
|
||||||
err = d.db.Select(&matches, sql, args...)
|
err = d.reader.Select(&matches, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting labels with omits")
|
return nil, errors.Wrap(err, "selecting labels with omits")
|
||||||
}
|
}
|
||||||
@ -527,7 +527,7 @@ func (d *Datastore) addAllHostsLabelToList(filter fleet.TeamFilter, labels []*fl
|
|||||||
)
|
)
|
||||||
|
|
||||||
var allHosts fleet.Label
|
var allHosts fleet.Label
|
||||||
if err := d.db.Get(&allHosts, sql, fleet.LabelTypeBuiltIn); err != nil {
|
if err := d.reader.Get(&allHosts, sql, fleet.LabelTypeBuiltIn); err != nil {
|
||||||
return nil, errors.Wrap(err, "get all hosts label")
|
return nil, errors.Wrap(err, "get all hosts label")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -575,8 +575,8 @@ func (d *Datastore) searchLabelsDefault(filter fleet.TeamFilter, omit ...uint) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "searching default labels")
|
return nil, errors.Wrap(err, "searching default labels")
|
||||||
}
|
}
|
||||||
sql = d.db.Rebind(sql)
|
sql = d.reader.Rebind(sql)
|
||||||
if err := d.db.Select(&labels, sql, args...); err != nil {
|
if err := d.reader.Select(&labels, sql, args...); err != nil {
|
||||||
return nil, errors.Wrap(err, "searching default labels rebound")
|
return nil, errors.Wrap(err, "searching default labels rebound")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -617,7 +617,7 @@ func (d *Datastore) SearchLabels(filter fleet.TeamFilter, query string, omit ...
|
|||||||
)
|
)
|
||||||
|
|
||||||
matches := []*fleet.Label{}
|
matches := []*fleet.Label{}
|
||||||
if err := d.db.Select(&matches, sql, transformedQuery); err != nil {
|
if err := d.reader.Select(&matches, sql, transformedQuery); err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting labels for search")
|
return nil, errors.Wrap(err, "selecting labels for search")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -645,7 +645,7 @@ func (d *Datastore) LabelIDsByName(labels []string) ([]uint, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var labelIDs []uint
|
var labelIDs []uint
|
||||||
if err := d.db.Select(&labelIDs, sql, args...); err != nil {
|
if err := d.reader.Select(&labelIDs, sql, args...); err != nil {
|
||||||
return nil, errors.Wrap(err, "get label IDs")
|
return nil, errors.Wrap(err, "get label IDs")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,27 +29,27 @@ func (d *Datastore) Lock(name string, owner string, expiration time.Duration) (b
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) createLock(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
func (d *Datastore) createLock(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
||||||
return d.db.Exec(
|
return d.writer.Exec(
|
||||||
`INSERT IGNORE INTO locks (name, owner, expires_at) VALUES (?, ?, ?)`,
|
`INSERT IGNORE INTO locks (name, owner, expires_at) VALUES (?, ?, ?)`,
|
||||||
name, owner, time.Now().Add(expiration),
|
name, owner, time.Now().Add(expiration),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) extendLockIfAlreadyAcquired(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
func (d *Datastore) extendLockIfAlreadyAcquired(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
||||||
return d.db.Exec(
|
return d.writer.Exec(
|
||||||
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE name = ? and owner = ?`,
|
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE name = ? and owner = ?`,
|
||||||
name, owner, time.Now().Add(expiration), name, owner,
|
name, owner, time.Now().Add(expiration), name, owner,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) overwriteLockIfExpired(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
func (d *Datastore) overwriteLockIfExpired(name string, owner string, expiration time.Duration) (sql.Result, error) {
|
||||||
return d.db.Exec(
|
return d.writer.Exec(
|
||||||
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE expires_at < CURRENT_TIMESTAMP and name = ?`,
|
`UPDATE locks SET name = ?, owner = ?, expires_at = ? WHERE expires_at < CURRENT_TIMESTAMP and name = ?`,
|
||||||
name, owner, time.Now().Add(expiration), name,
|
name, owner, time.Now().Add(expiration), name,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) Unlock(name string, owner string) error {
|
func (d *Datastore) Unlock(name string, owner string) error {
|
||||||
_, err := d.db.Exec(`DELETE FROM locks WHERE name = ? and owner = ?`, name, owner)
|
_, err := d.writer.Exec(`DELETE FROM locks WHERE name = ? and owner = ?`, name, owner)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -85,14 +85,14 @@ func Test20210819131107_AddCascadeToHostSoftware(t *testing.T) {
|
|||||||
defer ds.Close()
|
defer ds.Close()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
version, err := tables.MigrationClient.GetDBVersion(ds.db.DB)
|
version, err := tables.MigrationClient.GetDBVersion(ds.writer.DB)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// break right before the the constraint migration
|
// break right before the the constraint migration
|
||||||
if version == 20210818182258 {
|
if version == 20210818182258 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
require.NoError(t, tables.MigrationClient.UpByOne(ds.db.DB, ""))
|
require.NoError(t, tables.MigrationClient.UpByOne(ds.writer.DB, ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
host1 := test.NewHost(t, ds, "host1", "", "host1key", "host1uuid", time.Now())
|
host1 := test.NewHost(t, ds, "host1", "", "host1key", "host1uuid", time.Now())
|
||||||
@ -122,7 +122,7 @@ func Test20210819131107_AddCascadeToHostSoftware(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, ds.DeleteHost(host1.ID))
|
require.NoError(t, ds.DeleteHost(host1.ID))
|
||||||
|
|
||||||
require.NoError(t, tables.MigrationClient.UpByOne(ds.db.DB, ""))
|
require.NoError(t, tables.MigrationClient.UpByOne(ds.writer.DB, ""))
|
||||||
|
|
||||||
// Make sure we don't delete more than we need
|
// Make sure we don't delete more than we need
|
||||||
hostCheck, err := ds.Host(host2.ID)
|
hostCheck, err := ds.Host(host2.ID)
|
||||||
|
@ -37,13 +37,28 @@ var (
|
|||||||
columnCharsRegexp = regexp.MustCompile(`[^\w-]`)
|
columnCharsRegexp = regexp.MustCompile(`[^\w-]`)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// dbReader is an interface that defines the methods required for reads.
|
||||||
|
type dbReader interface {
|
||||||
|
sqlx.Queryer
|
||||||
|
|
||||||
|
Close() error
|
||||||
|
Rebind(string) string
|
||||||
|
Select(interface{}, string, ...interface{}) error
|
||||||
|
Get(interface{}, string, ...interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
// Datastore is an implementation of fleet.Datastore interface backed by
|
// Datastore is an implementation of fleet.Datastore interface backed by
|
||||||
// MySQL
|
// MySQL
|
||||||
type Datastore struct {
|
type Datastore struct {
|
||||||
db *sqlx.DB
|
reader dbReader // so it cannot be used to perform writes
|
||||||
|
writer *sqlx.DB
|
||||||
|
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
clock clock.Clock
|
clock clock.Clock
|
||||||
config config.MysqlConfig
|
config config.MysqlConfig
|
||||||
|
|
||||||
|
// nil if no read replica
|
||||||
|
readReplicaConfig *config.MysqlConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type txFn func(*sqlx.Tx) error
|
type txFn func(*sqlx.Tx) error
|
||||||
@ -67,7 +82,7 @@ func retryableError(err error) bool {
|
|||||||
// withRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff
|
// withRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff
|
||||||
func (d *Datastore) withRetryTxx(fn txFn) (err error) {
|
func (d *Datastore) withRetryTxx(fn txFn) (err error) {
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
tx, err := d.db.Beginx()
|
tx, err := d.writer.Beginx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "create transaction")
|
return errors.Wrap(err, "create transaction")
|
||||||
}
|
}
|
||||||
@ -116,7 +131,7 @@ func (d *Datastore) withRetryTxx(fn txFn) (err error) {
|
|||||||
|
|
||||||
// withTx provides a common way to commit/rollback a txFn
|
// withTx provides a common way to commit/rollback a txFn
|
||||||
func (d *Datastore) withTx(fn txFn) (err error) {
|
func (d *Datastore) withTx(fn txFn) (err error) {
|
||||||
tx, err := d.db.Beginx()
|
tx, err := d.writer.Beginx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "create transaction")
|
return errors.Wrap(err, "create transaction")
|
||||||
}
|
}
|
||||||
@ -153,51 +168,64 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, setOpt := range opts {
|
for _, setOpt := range opts {
|
||||||
|
if setOpt != nil {
|
||||||
setOpt(options)
|
setOpt(options)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PasswordPath != "" && config.Password != "" {
|
|
||||||
return nil, errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check to see if the flag is populated
|
if err := checkConfig(&config); err != nil {
|
||||||
// Check if file exists on disk
|
return nil, err
|
||||||
// If file exists read contents
|
}
|
||||||
if config.PasswordPath != "" {
|
if options.replicaConfig != nil {
|
||||||
fileContents, err := ioutil.ReadFile(config.PasswordPath)
|
if err := checkConfig(options.replicaConfig); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "replica")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dbWriter, err := newDB(&config, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
config.Password = strings.TrimSpace(string(fileContents))
|
dbReader := dbWriter
|
||||||
}
|
if options.replicaConfig != nil {
|
||||||
|
dbReader, err = newDB(options.replicaConfig, options)
|
||||||
if config.TLSCA != "" {
|
|
||||||
config.TLSConfig = "custom"
|
|
||||||
err := registerTLS(config)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "register TLS config for mysql")
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dsn := generateMysqlConnectionString(config)
|
ds := &Datastore{
|
||||||
|
writer: dbWriter,
|
||||||
|
reader: dbReader,
|
||||||
|
logger: options.logger,
|
||||||
|
clock: c,
|
||||||
|
config: config,
|
||||||
|
readReplicaConfig: options.replicaConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
return ds, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDB(conf *config.MysqlConfig, opts *dbOptions) (*sqlx.DB, error) {
|
||||||
|
dsn := generateMysqlConnectionString(*conf)
|
||||||
db, err := sqlx.Open("mysql", dsn)
|
db, err := sqlx.Open("mysql", dsn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
db.SetMaxIdleConns(config.MaxIdleConns)
|
db.SetMaxIdleConns(conf.MaxIdleConns)
|
||||||
db.SetMaxOpenConns(config.MaxOpenConns)
|
db.SetMaxOpenConns(conf.MaxOpenConns)
|
||||||
db.SetConnMaxLifetime(time.Second * time.Duration(config.ConnMaxLifetime))
|
db.SetConnMaxLifetime(time.Second * time.Duration(conf.ConnMaxLifetime))
|
||||||
|
|
||||||
var dbError error
|
var dbError error
|
||||||
for attempt := 0; attempt < options.maxAttempts; attempt++ {
|
for attempt := 0; attempt < opts.maxAttempts; attempt++ {
|
||||||
dbError = db.Ping()
|
dbError = db.Ping()
|
||||||
if dbError == nil {
|
if dbError == nil {
|
||||||
// we're connected!
|
// we're connected!
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
interval := time.Duration(attempt) * time.Second
|
interval := time.Duration(attempt) * time.Second
|
||||||
options.logger.Log("mysql", fmt.Sprintf(
|
opts.logger.Log("mysql", fmt.Sprintf(
|
||||||
"could not connect to db: %v, sleeping %v", dbError, interval))
|
"could not connect to db: %v, sleeping %v", dbError, interval))
|
||||||
time.Sleep(interval)
|
time.Sleep(interval)
|
||||||
}
|
}
|
||||||
@ -205,20 +233,37 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
|
|||||||
if dbError != nil {
|
if dbError != nil {
|
||||||
return nil, dbError
|
return nil, dbError
|
||||||
}
|
}
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
ds := &Datastore{
|
func checkConfig(conf *config.MysqlConfig) error {
|
||||||
db: db,
|
if conf.PasswordPath != "" && conf.Password != "" {
|
||||||
logger: options.logger,
|
return errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
|
||||||
clock: c,
|
|
||||||
config: config,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ds, nil
|
// Check to see if the flag is populated
|
||||||
|
// Check if file exists on disk
|
||||||
|
// If file exists read contents
|
||||||
|
if conf.PasswordPath != "" {
|
||||||
|
fileContents, err := ioutil.ReadFile(conf.PasswordPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conf.Password = strings.TrimSpace(string(fileContents))
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.TLSCA != "" {
|
||||||
|
conf.TLSConfig = "custom"
|
||||||
|
err := registerTLS(*conf)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "register TLS config for mysql")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) Begin() (fleet.Transaction, error) {
|
func (d *Datastore) Begin() (fleet.Transaction, error) {
|
||||||
return d.db.Beginx()
|
return d.writer.Beginx()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) Name() string {
|
func (d *Datastore) Name() string {
|
||||||
@ -226,11 +271,11 @@ func (d *Datastore) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) MigrateTables() error {
|
func (d *Datastore) MigrateTables() error {
|
||||||
return tables.MigrationClient.Up(d.db.DB, "")
|
return tables.MigrationClient.Up(d.writer.DB, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) MigrateData() error {
|
func (d *Datastore) MigrateData() error {
|
||||||
return data.MigrationClient.Up(d.db.DB, "")
|
return data.MigrationClient.Up(d.writer.DB, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
|
func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
|
||||||
@ -243,7 +288,7 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
|
|||||||
return 0, errors.Wrap(err, "missing tables migrations")
|
return 0, errors.Wrap(err, "missing tables migrations")
|
||||||
}
|
}
|
||||||
|
|
||||||
currentTablesVersion, err := tables.MigrationClient.GetDBVersion(d.db.DB)
|
currentTablesVersion, err := tables.MigrationClient.GetDBVersion(d.writer.DB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "cannot get table migration status")
|
return 0, errors.Wrap(err, "cannot get table migration status")
|
||||||
}
|
}
|
||||||
@ -253,7 +298,7 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
|
|||||||
return 0, errors.Wrap(err, "missing data migrations")
|
return 0, errors.Wrap(err, "missing data migrations")
|
||||||
}
|
}
|
||||||
|
|
||||||
currentDataVersion, err := data.MigrationClient.GetDBVersion(d.db.DB)
|
currentDataVersion, err := data.MigrationClient.GetDBVersion(d.writer.DB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "cannot get data migration status")
|
return 0, errors.Wrap(err, "cannot get data migration status")
|
||||||
}
|
}
|
||||||
@ -283,11 +328,11 @@ func (d *Datastore) Drop() error {
|
|||||||
WHERE TABLE_SCHEMA = ?;
|
WHERE TABLE_SCHEMA = ?;
|
||||||
`
|
`
|
||||||
|
|
||||||
if err := d.db.Select(&tables, sql, d.config.Database); err != nil {
|
if err := d.writer.Select(&tables, sql, d.config.Database); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := d.db.Begin()
|
tx, err := d.writer.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -312,13 +357,28 @@ func (d *Datastore) Drop() error {
|
|||||||
|
|
||||||
// HealthCheck returns an error if the MySQL backend is not healthy.
|
// HealthCheck returns an error if the MySQL backend is not healthy.
|
||||||
func (d *Datastore) HealthCheck() error {
|
func (d *Datastore) HealthCheck() error {
|
||||||
_, err := d.db.Exec("select 1")
|
if _, err := d.writer.Exec("select 1"); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
if d.readReplicaConfig != nil {
|
||||||
|
var dst int
|
||||||
|
if err := d.reader.Get(&dst, "select 1"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close frees resources associated with underlying mysql connection
|
// Close frees resources associated with underlying mysql connection
|
||||||
func (d *Datastore) Close() error {
|
func (d *Datastore) Close() error {
|
||||||
return d.db.Close()
|
err := d.writer.Close()
|
||||||
|
if d.readReplicaConfig != nil {
|
||||||
|
errRead := d.reader.Close()
|
||||||
|
if err == nil {
|
||||||
|
err = errRead
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeColumn(col string) string {
|
func sanitizeColumn(col string) string {
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
"github.com/VividCortex/mysqlerr"
|
"github.com/VividCortex/mysqlerr"
|
||||||
@ -15,6 +17,51 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestDatastoreReplica(t *testing.T) {
|
||||||
|
// a bit unfortunate to create temp databases just for this - could be mixed
|
||||||
|
// with other tests when/if we move to subtests to minimize the number of
|
||||||
|
// databases created for tests (see #1805).
|
||||||
|
|
||||||
|
t.Run("noreplica", func(t *testing.T) {
|
||||||
|
ds := CreateMySQLDSWithOptions(t, nil)
|
||||||
|
defer ds.Close()
|
||||||
|
require.Equal(t, ds.reader, ds.writer)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("replica", func(t *testing.T) {
|
||||||
|
opts := &DatastoreTestOptions{Replica: true}
|
||||||
|
ds := CreateMySQLDSWithOptions(t, opts)
|
||||||
|
defer ds.Close()
|
||||||
|
require.NotEqual(t, ds.reader, ds.writer)
|
||||||
|
|
||||||
|
// create a new host
|
||||||
|
host, err := ds.NewHost(&fleet.Host{
|
||||||
|
DetailUpdatedAt: time.Now(),
|
||||||
|
LabelUpdatedAt: time.Now(),
|
||||||
|
SeenTime: time.Now(),
|
||||||
|
NodeKey: "1",
|
||||||
|
UUID: "1",
|
||||||
|
Hostname: "foo.local",
|
||||||
|
PrimaryIP: "192.168.1.1",
|
||||||
|
PrimaryMac: "30-65-EC-6F-C4-58",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, host)
|
||||||
|
|
||||||
|
// trying to read it fails, not replicated yet
|
||||||
|
_, err = ds.Host(host.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, errors.Is(err, sql.ErrNoRows))
|
||||||
|
|
||||||
|
opts.RunReplication()
|
||||||
|
|
||||||
|
// now it can read it
|
||||||
|
host2, err := ds.Host(host.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, host.ID, host2.ID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestSanitizeColumn(t *testing.T) {
|
func TestSanitizeColumn(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -121,8 +168,10 @@ func TestSearchLike(t *testing.T) {
|
|||||||
func mockDatastore(t *testing.T) (sqlmock.Sqlmock, *Datastore) {
|
func mockDatastore(t *testing.T) (sqlmock.Sqlmock, *Datastore) {
|
||||||
db, mock, err := sqlmock.New()
|
db, mock, err := sqlmock.New()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
dbmock := sqlx.NewDb(db, "sqlmock")
|
||||||
ds := &Datastore{
|
ds := &Datastore{
|
||||||
db: sqlx.NewDb(db, "sqlmock"),
|
writer: dbmock,
|
||||||
|
reader: dbmock,
|
||||||
logger: log.NewNopLogger(),
|
logger: log.NewNopLogger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,7 +205,7 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
|
|||||||
WHERE name = ?
|
WHERE name = ?
|
||||||
`
|
`
|
||||||
var pack fleet.Pack
|
var pack fleet.Pack
|
||||||
err := d.db.Get(&pack, sqlStatement, name)
|
err := d.reader.Get(&pack, sqlStatement, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, false, nil
|
return nil, false, nil
|
||||||
@ -213,7 +213,7 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
|
|||||||
return nil, false, errors.Wrap(err, "fetch pack by name")
|
return nil, false, errors.Wrap(err, "fetch pack by name")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.loadPackTargets(&pack); err != nil {
|
if err := d.loadPackTargets(d.reader, &pack); err != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -313,10 +313,10 @@ func (d *Datastore) replacePackTargets(tx *sqlx.Tx, pack *fleet.Pack) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) loadPackTargets(pack *fleet.Pack) error {
|
func (d *Datastore) loadPackTargets(db dbReader, pack *fleet.Pack) error {
|
||||||
var targets []fleet.PackTarget
|
var targets []fleet.PackTarget
|
||||||
sql := `SELECT * FROM pack_targets WHERE pack_id = ?`
|
sql := `SELECT * FROM pack_targets WHERE pack_id = ?`
|
||||||
if err := d.db.Select(&targets, sql, pack.ID); err != nil {
|
if err := db.Select(&targets, sql, pack.ID); err != nil {
|
||||||
return errors.Wrap(err, "select pack targets")
|
return errors.Wrap(err, "select pack targets")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -371,14 +371,14 @@ func (d *Datastore) DeletePack(name string) error {
|
|||||||
func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
|
func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
|
||||||
query := `SELECT * FROM packs WHERE id = ?`
|
query := `SELECT * FROM packs WHERE id = ?`
|
||||||
pack := &fleet.Pack{}
|
pack := &fleet.Pack{}
|
||||||
err := d.db.Get(pack, query, pid)
|
err := d.reader.Get(pack, query, pid)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, notFound("Pack").WithID(pid)
|
return nil, notFound("Pack").WithID(pid)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, errors.Wrap(err, "get pack")
|
return nil, errors.Wrap(err, "get pack")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.loadPackTargets(pack); err != nil {
|
if err := d.loadPackTargets(d.reader, pack); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -388,14 +388,15 @@ func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
|
|||||||
// EnsureGlobalPack gets or inserts a pack with type global
|
// EnsureGlobalPack gets or inserts a pack with type global
|
||||||
func (d *Datastore) EnsureGlobalPack() (*fleet.Pack, error) {
|
func (d *Datastore) EnsureGlobalPack() (*fleet.Pack, error) {
|
||||||
pack := &fleet.Pack{}
|
pack := &fleet.Pack{}
|
||||||
err := d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = 'global'`)
|
// read from primary as we will create the pack if it doesn't exist
|
||||||
|
err := d.writer.Get(pack, `SELECT * FROM packs WHERE pack_type = 'global'`)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return d.insertNewGlobalPack()
|
return d.insertNewGlobalPack()
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, errors.Wrap(err, "get pack")
|
return nil, errors.Wrap(err, "get pack")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.loadPackTargets(pack); err != nil {
|
if err := d.loadPackTargets(d.writer, pack); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -439,14 +440,15 @@ func (d *Datastore) EnsureTeamPack(teamID uint) (*fleet.Pack, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
teamType := fmt.Sprintf("team-%d", teamID)
|
teamType := fmt.Sprintf("team-%d", teamID)
|
||||||
err = d.db.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType)
|
// read from primary as we will create the team pack if it doesn't exist
|
||||||
|
err = d.writer.Get(pack, `SELECT * FROM packs WHERE pack_type = ?`, teamType)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return d.insertNewTeamPack(t)
|
return d.insertNewTeamPack(t)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, errors.Wrap(err, "get pack")
|
return nil, errors.Wrap(err, "get pack")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.loadPackTargets(pack); err != nil {
|
if err := d.loadPackTargets(d.writer, pack); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -499,13 +501,13 @@ func (d *Datastore) ListPacks(opt fleet.PackListOptions) ([]*fleet.Pack, error)
|
|||||||
query = `SELECT * FROM packs`
|
query = `SELECT * FROM packs`
|
||||||
}
|
}
|
||||||
var packs []*fleet.Pack
|
var packs []*fleet.Pack
|
||||||
err := d.db.Select(&packs, appendListOptionsToSQL(query, opt.ListOptions))
|
err := d.reader.Select(&packs, appendListOptionsToSQL(query, opt.ListOptions))
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return nil, errors.Wrap(err, "listing packs")
|
return nil, errors.Wrap(err, "listing packs")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, pack := range packs {
|
for _, pack := range packs {
|
||||||
if err := d.loadPackTargets(pack); err != nil {
|
if err := d.loadPackTargets(d.reader, pack); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -541,7 +543,7 @@ func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) {
|
|||||||
`
|
`
|
||||||
|
|
||||||
packs := []*fleet.Pack{}
|
packs := []*fleet.Pack{}
|
||||||
if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
|
if err := d.reader.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
|
||||||
return nil, errors.Wrap(err, "listing hosts in pack")
|
return nil, errors.Wrap(err, "listing hosts in pack")
|
||||||
}
|
}
|
||||||
return packs, nil
|
return packs, nil
|
||||||
|
@ -524,7 +524,7 @@ func TestTeamScheduleNamesMigrateToNewFormat(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// insert team pack by hand with the old naming scheme
|
// insert team pack by hand with the old naming scheme
|
||||||
_, err = ds.db.Exec(
|
_, err = ds.writer.Exec(
|
||||||
"INSERT INTO packs(name, description, platform, disabled, pack_type) VALUES (?, ?, ?, ?, ?)",
|
"INSERT INTO packs(name, description, platform, disabled, pack_type) VALUES (?, ?, ?, ?, ?)",
|
||||||
teamSchedulePackType(team1), "desc", "windows", false, teamSchedulePackType(team1),
|
teamSchedulePackType(team1), "desc", "windows", false, teamSchedulePackType(team1),
|
||||||
)
|
)
|
||||||
|
@ -11,7 +11,7 @@ func (d *Datastore) NewPasswordResetRequest(req *fleet.PasswordResetRequest) (*f
|
|||||||
( user_id, token, expires_at)
|
( user_id, token, expires_at)
|
||||||
VALUES (?,?, NOW())
|
VALUES (?,?, NOW())
|
||||||
`
|
`
|
||||||
response, err := d.db.Exec(sqlStatement, req.UserID, req.Token)
|
response, err := d.writer.Exec(sqlStatement, req.UserID, req.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "inserting password reset requests")
|
return nil, errors.Wrap(err, "inserting password reset requests")
|
||||||
}
|
}
|
||||||
@ -30,7 +30,7 @@ func (d *Datastore) SavePasswordResetRequest(req *fleet.PasswordResetRequest) er
|
|||||||
token = ?
|
token = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID)
|
result, err := d.writer.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "updating password reset requests")
|
return errors.Wrap(err, "updating password reset requests")
|
||||||
}
|
}
|
||||||
@ -58,7 +58,7 @@ func (d *Datastore) DeletePasswordResetRequestsForUser(userID uint) error {
|
|||||||
sqlStatement := `
|
sqlStatement := `
|
||||||
DELETE FROM password_reset_requests WHERE user_id = ?
|
DELETE FROM password_reset_requests WHERE user_id = ?
|
||||||
`
|
`
|
||||||
_, err := d.db.Exec(sqlStatement, userID)
|
_, err := d.writer.Exec(sqlStatement, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "deleting password reset request by user")
|
return errors.Wrap(err, "deleting password reset request by user")
|
||||||
}
|
}
|
||||||
@ -72,7 +72,7 @@ func (d *Datastore) FindPassswordResetByID(id uint) (*fleet.PasswordResetRequest
|
|||||||
WHERE id = ? LIMIT 1
|
WHERE id = ? LIMIT 1
|
||||||
`
|
`
|
||||||
passwordResetRequest := &fleet.PasswordResetRequest{}
|
passwordResetRequest := &fleet.PasswordResetRequest{}
|
||||||
err := d.db.Get(&passwordResetRequest, sqlStatement, id)
|
err := d.reader.Get(&passwordResetRequest, sqlStatement, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting password reset by id")
|
return nil, errors.Wrap(err, "selecting password reset by id")
|
||||||
}
|
}
|
||||||
@ -87,7 +87,7 @@ func (d *Datastore) FindPassswordResetsByUserID(id uint) ([]*fleet.PasswordReset
|
|||||||
`
|
`
|
||||||
|
|
||||||
passwordResetRequests := []*fleet.PasswordResetRequest{}
|
passwordResetRequests := []*fleet.PasswordResetRequest{}
|
||||||
err := d.db.Select(&passwordResetRequests, sqlStatement, id)
|
err := d.reader.Select(&passwordResetRequests, sqlStatement, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "finding password resets by user id")
|
return nil, errors.Wrap(err, "finding password resets by user id")
|
||||||
}
|
}
|
||||||
@ -102,7 +102,7 @@ func (d *Datastore) FindPassswordResetByToken(token string) (*fleet.PasswordRese
|
|||||||
WHERE token = ? LIMIT 1
|
WHERE token = ? LIMIT 1
|
||||||
`
|
`
|
||||||
passwordResetRequest := &fleet.PasswordResetRequest{}
|
passwordResetRequest := &fleet.PasswordResetRequest{}
|
||||||
err := d.db.Get(passwordResetRequest, sqlStatement, token)
|
err := d.reader.Get(passwordResetRequest, sqlStatement, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting password reset requests")
|
return nil, errors.Wrap(err, "selecting password reset requests")
|
||||||
}
|
}
|
||||||
@ -118,7 +118,7 @@ func (d *Datastore) FindPassswordResetByTokenAndUserID(token string, id uint) (*
|
|||||||
LIMIT 1
|
LIMIT 1
|
||||||
`
|
`
|
||||||
passwordResetRequest := &fleet.PasswordResetRequest{}
|
passwordResetRequest := &fleet.PasswordResetRequest{}
|
||||||
err := d.db.Get(passwordResetRequest, sqlStatement, id, token)
|
err := d.reader.Get(passwordResetRequest, sqlStatement, id, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting password reset by token and user id")
|
return nil, errors.Wrap(err, "selecting password reset by token and user id")
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
|
func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
|
||||||
res, err := ds.db.Exec(`INSERT INTO policies (query_id) VALUES (?)`, queryID)
|
res, err := ds.writer.Exec(`INSERT INTO policies (query_id) VALUES (?)`, queryID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "inserting new policy")
|
return nil, errors.Wrap(err, "inserting new policy")
|
||||||
}
|
}
|
||||||
@ -26,7 +26,7 @@ func (ds *Datastore) NewGlobalPolicy(queryID uint) (*fleet.Policy, error) {
|
|||||||
|
|
||||||
func (ds *Datastore) Policy(id uint) (*fleet.Policy, error) {
|
func (ds *Datastore) Policy(id uint) (*fleet.Policy, error) {
|
||||||
var policy fleet.Policy
|
var policy fleet.Policy
|
||||||
err := ds.db.Get(
|
err := ds.reader.Get(
|
||||||
&policy,
|
&policy,
|
||||||
`SELECT
|
`SELECT
|
||||||
p.*,
|
p.*,
|
||||||
@ -66,7 +66,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(host *fleet.Host, results map[u
|
|||||||
strings.Join(bindvars, ","),
|
strings.Join(bindvars, ","),
|
||||||
)
|
)
|
||||||
|
|
||||||
_, err := ds.db.Exec(query, vals...)
|
_, err := ds.writer.Exec(query, vals...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(err, "insert policy_membership (%v)", vals)
|
return errors.Wrapf(err, "insert policy_membership (%v)", vals)
|
||||||
}
|
}
|
||||||
@ -76,7 +76,7 @@ func (ds *Datastore) RecordPolicyQueryExecutions(host *fleet.Host, results map[u
|
|||||||
|
|
||||||
func (ds *Datastore) ListGlobalPolicies() ([]*fleet.Policy, error) {
|
func (ds *Datastore) ListGlobalPolicies() ([]*fleet.Policy, error) {
|
||||||
var policies []*fleet.Policy
|
var policies []*fleet.Policy
|
||||||
err := ds.db.Select(
|
err := ds.reader.Select(
|
||||||
&policies,
|
&policies,
|
||||||
`SELECT
|
`SELECT
|
||||||
p.*,
|
p.*,
|
||||||
@ -97,8 +97,8 @@ func (ds *Datastore) DeleteGlobalPolicies(ids []uint) ([]uint, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "IN for DELETE FROM policies")
|
return nil, errors.Wrap(err, "IN for DELETE FROM policies")
|
||||||
}
|
}
|
||||||
stmt = ds.db.Rebind(stmt)
|
stmt = ds.writer.Rebind(stmt)
|
||||||
if _, err := ds.db.Exec(stmt, args...); err != nil {
|
if _, err := ds.writer.Exec(stmt, args...); err != nil {
|
||||||
return nil, errors.Wrap(err, "delete policies")
|
return nil, errors.Wrap(err, "delete policies")
|
||||||
}
|
}
|
||||||
return ids, nil
|
return ids, nil
|
||||||
@ -109,7 +109,7 @@ func (ds *Datastore) PolicyQueriesForHost(_ *fleet.Host) (map[string]string, err
|
|||||||
Id string `db:"id"`
|
Id string `db:"id"`
|
||||||
Query string `db:"query"`
|
Query string `db:"query"`
|
||||||
}
|
}
|
||||||
err := ds.db.Select(&rows, `SELECT p.id, q.query FROM policies p JOIN queries q ON (p.query_id=q.id)`)
|
err := ds.reader.Select(&rows, `SELECT p.id, q.query FROM policies p JOIN queries q ON (p.query_id=q.id)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting policies for host")
|
return nil, errors.Wrap(err, "selecting policies for host")
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (d *Datastore) ApplyQueries(authorID uint, queries []*fleet.Query) (err error) {
|
func (d *Datastore) ApplyQueries(authorID uint, queries []*fleet.Query) (err error) {
|
||||||
tx, err := d.db.Begin()
|
tx, err := d.writer.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "begin ApplyQueries transaction")
|
return errors.Wrap(err, "begin ApplyQueries transaction")
|
||||||
}
|
}
|
||||||
@ -72,7 +72,7 @@ func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet.
|
|||||||
WHERE name = ?
|
WHERE name = ?
|
||||||
`
|
`
|
||||||
var query fleet.Query
|
var query fleet.Query
|
||||||
err := d.db.Get(&query, sqlStatement, name)
|
err := d.reader.Get(&query, sqlStatement, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, notFound("Query").WithName(name)
|
return nil, notFound("Query").WithName(name)
|
||||||
@ -99,7 +99,7 @@ func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fl
|
|||||||
observer_can_run
|
observer_can_run
|
||||||
) VALUES ( ?, ?, ?, ?, ?, ? )
|
) VALUES ( ?, ?, ?, ?, ?, ? )
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
|
result, err := d.writer.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
|
||||||
|
|
||||||
if err != nil && isDuplicate(err) {
|
if err != nil && isDuplicate(err) {
|
||||||
return nil, alreadyExists("Query", 0)
|
return nil, alreadyExists("Query", 0)
|
||||||
@ -120,7 +120,7 @@ func (d *Datastore) SaveQuery(q *fleet.Query) error {
|
|||||||
SET name = ?, description = ?, query = ?, author_id = ?, saved = ?, observer_can_run = ?
|
SET name = ?, description = ?, query = ?, author_id = ?, saved = ?, observer_can_run = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(sql, q.Name, q.Description, q.Query, q.AuthorID, q.Saved, q.ObserverCanRun, q.ID)
|
result, err := d.writer.Exec(sql, q.Name, q.Description, q.Query, q.AuthorID, q.Saved, q.ObserverCanRun, q.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "updating query")
|
return errors.Wrap(err, "updating query")
|
||||||
}
|
}
|
||||||
@ -156,7 +156,7 @@ func (d *Datastore) Query(id uint) (*fleet.Query, error) {
|
|||||||
WHERE q.id = ?
|
WHERE q.id = ?
|
||||||
`
|
`
|
||||||
query := &fleet.Query{}
|
query := &fleet.Query{}
|
||||||
if err := d.db.Get(query, sql, id); err != nil {
|
if err := d.reader.Get(query, sql, id); err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting query")
|
return nil, errors.Wrap(err, "selecting query")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -180,7 +180,7 @@ func (d *Datastore) ListQueries(opt fleet.ListOptions) ([]*fleet.Query, error) {
|
|||||||
sql = appendListOptionsToSQL(sql, opt)
|
sql = appendListOptionsToSQL(sql, opt)
|
||||||
results := []*fleet.Query{}
|
results := []*fleet.Query{}
|
||||||
|
|
||||||
if err := d.db.Select(&results, sql); err != nil {
|
if err := d.reader.Select(&results, sql); err != nil {
|
||||||
return nil, errors.Wrap(err, "listing queries")
|
return nil, errors.Wrap(err, "listing queries")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -225,7 +225,7 @@ func (d *Datastore) loadPacksForQueries(queries []*fleet.Query) error {
|
|||||||
fleet.Pack
|
fleet.Pack
|
||||||
}{}
|
}{}
|
||||||
|
|
||||||
err = d.db.Select(&rows, query, args...)
|
err = d.reader.Select(&rows, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "selecting load packs for queries")
|
return errors.Wrap(err, "selecting load packs for queries")
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ func (d *Datastore) ListScheduledQueriesInPack(id uint, opts fleet.ListOptions)
|
|||||||
query = appendListOptionsToSQL(query, opts)
|
query = appendListOptionsToSQL(query, opts)
|
||||||
results := []*fleet.ScheduledQuery{}
|
results := []*fleet.ScheduledQuery{}
|
||||||
|
|
||||||
if err := d.db.Select(&results, query, id); err != nil {
|
if err := d.reader.Select(&results, query, id); err != nil {
|
||||||
return nil, errors.Wrap(err, "listing scheduled queries")
|
return nil, errors.Wrap(err, "listing scheduled queries")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,8 +46,8 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) insertScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
|
func (d *Datastore) insertScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
|
||||||
selectFunc := d.db.Select
|
selectFunc := d.writer.Select
|
||||||
execFunc := d.db.Exec
|
execFunc := d.writer.Exec
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
selectFunc = tx.Select
|
selectFunc = tx.Select
|
||||||
execFunc = tx.Exec
|
execFunc = tx.Exec
|
||||||
@ -108,7 +108,7 @@ func (d *Datastore) SaveScheduledQuery(sq *fleet.ScheduledQuery) (*fleet.Schedul
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) saveScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
|
func (d *Datastore) saveScheduledQuery(tx *sqlx.Tx, sq *fleet.ScheduledQuery) (*fleet.ScheduledQuery, error) {
|
||||||
updateFunc := d.db.Exec
|
updateFunc := d.writer.Exec
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
updateFunc = tx.Exec
|
updateFunc = tx.Exec
|
||||||
}
|
}
|
||||||
@ -160,7 +160,7 @@ func (d *Datastore) ScheduledQuery(id uint) (*fleet.ScheduledQuery, error) {
|
|||||||
WHERE sq.id = ?
|
WHERE sq.id = ?
|
||||||
`
|
`
|
||||||
sq := &fleet.ScheduledQuery{}
|
sq := &fleet.ScheduledQuery{}
|
||||||
if err := d.db.Get(sq, query, id); err != nil {
|
if err := d.reader.Get(sq, query, id); err != nil {
|
||||||
return nil, errors.Wrap(err, "select scheduled query")
|
return nil, errors.Wrap(err, "select scheduled query")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,11 +168,11 @@ func (d *Datastore) ScheduledQuery(id uint) (*fleet.ScheduledQuery, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) CleanupOrphanScheduledQueryStats() error {
|
func (d *Datastore) CleanupOrphanScheduledQueryStats() error {
|
||||||
_, err := d.db.Exec(`DELETE FROM scheduled_query_stats where scheduled_query_id not in (select id from scheduled_queries where id=scheduled_query_id)`)
|
_, err := d.writer.Exec(`DELETE FROM scheduled_query_stats where scheduled_query_id not in (select id from scheduled_queries where id=scheduled_query_id)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by scheduled_query")
|
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by scheduled_query")
|
||||||
}
|
}
|
||||||
_, err = d.db.Exec(`DELETE FROM scheduled_query_stats where host_id not in (select id from hosts where id=host_id)`)
|
_, err = d.writer.Exec(`DELETE FROM scheduled_query_stats where host_id not in (select id from hosts where id=host_id)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by host")
|
return errors.Wrap(err, "cleaning orphan scheduled_query_stats by host")
|
||||||
}
|
}
|
||||||
|
@ -209,7 +209,7 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
|
|||||||
test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "1")
|
test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "1")
|
||||||
sq1 := test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "2")
|
sq1 := test.NewScheduledQuery(t, ds, p1.ID, q1.ID, 60, false, false, "2")
|
||||||
|
|
||||||
_, err := ds.db.Exec(`INSERT INTO scheduled_query_stats (
|
_, err := ds.writer.Exec(`INSERT INTO scheduled_query_stats (
|
||||||
host_id, scheduled_query_id, average_memory, denylisted,
|
host_id, scheduled_query_id, average_memory, denylisted,
|
||||||
executions, schedule_interval, output_size, system_time,
|
executions, schedule_interval, output_size, system_time,
|
||||||
user_time, wall_time
|
user_time, wall_time
|
||||||
@ -224,12 +224,12 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
|
|||||||
require.Len(t, h1.PackStats, 1)
|
require.Len(t, h1.PackStats, 1)
|
||||||
|
|
||||||
// now we insert a bogus stat
|
// now we insert a bogus stat
|
||||||
_, err = ds.db.Exec(`INSERT INTO scheduled_query_stats (
|
_, err = ds.writer.Exec(`INSERT INTO scheduled_query_stats (
|
||||||
host_id, scheduled_query_id, average_memory, denylisted, executions
|
host_id, scheduled_query_id, average_memory, denylisted, executions
|
||||||
) VALUES (?, 999, 32, false, 2);`, h1.ID)
|
) VALUES (?, 999, 32, false, 2);`, h1.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// and also for an unknown host
|
// and also for an unknown host
|
||||||
_, err = ds.db.Exec(`INSERT INTO scheduled_query_stats (
|
_, err = ds.writer.Exec(`INSERT INTO scheduled_query_stats (
|
||||||
host_id, scheduled_query_id, average_memory, denylisted, executions
|
host_id, scheduled_query_id, average_memory, denylisted, executions
|
||||||
) VALUES (888, 999, 32, true, 4);`)
|
) VALUES (888, 999, 32, true, 4);`)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -241,14 +241,14 @@ func TestCleanupOrphanScheduledQueryStats(t *testing.T) {
|
|||||||
|
|
||||||
// but there are definitely there
|
// but there are definitely there
|
||||||
var count int
|
var count int
|
||||||
err = ds.db.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
err = ds.writer.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 3, count)
|
assert.Equal(t, 3, count)
|
||||||
|
|
||||||
// now we clean it up
|
// now we clean it up
|
||||||
require.NoError(t, ds.CleanupOrphanScheduledQueryStats())
|
require.NoError(t, ds.CleanupOrphanScheduledQueryStats())
|
||||||
|
|
||||||
err = ds.db.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
err = ds.writer.Get(&count, `SELECT count(*) FROM scheduled_query_stats`)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, count)
|
assert.Equal(t, 1, count)
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ func (d *Datastore) SessionByKey(key string) (*fleet.Session, error) {
|
|||||||
WHERE ` + "`key`" + ` = ? LIMIT 1
|
WHERE ` + "`key`" + ` = ? LIMIT 1
|
||||||
`
|
`
|
||||||
session := &fleet.Session{}
|
session := &fleet.Session{}
|
||||||
err := d.db.Get(session, sqlStatement, key)
|
err := d.reader.Get(session, sqlStatement, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting sessions")
|
return nil, errors.Wrap(err, "selecting sessions")
|
||||||
}
|
}
|
||||||
@ -26,7 +26,7 @@ func (d *Datastore) SessionByID(id uint) (*fleet.Session, error) {
|
|||||||
LIMIT 1
|
LIMIT 1
|
||||||
`
|
`
|
||||||
session := &fleet.Session{}
|
session := &fleet.Session{}
|
||||||
err := d.db.Get(session, sqlStatement, id)
|
err := d.reader.Get(session, sqlStatement, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting session by id")
|
return nil, errors.Wrap(err, "selecting session by id")
|
||||||
}
|
}
|
||||||
@ -40,7 +40,7 @@ func (d *Datastore) ListSessionsForUser(id uint) ([]*fleet.Session, error) {
|
|||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
`
|
`
|
||||||
sessions := []*fleet.Session{}
|
sessions := []*fleet.Session{}
|
||||||
err := d.db.Select(&sessions, sqlStatement, id)
|
err := d.reader.Select(&sessions, sqlStatement, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "selecting sessions for user")
|
return nil, errors.Wrap(err, "selecting sessions for user")
|
||||||
}
|
}
|
||||||
@ -57,7 +57,7 @@ func (d *Datastore) NewSession(session *fleet.Session) (*fleet.Session, error) {
|
|||||||
)
|
)
|
||||||
VALUES(?,?)
|
VALUES(?,?)
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(sqlStatement, session.UserID, session.Key)
|
result, err := d.writer.Exec(sqlStatement, session.UserID, session.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "inserting session")
|
return nil, errors.Wrap(err, "inserting session")
|
||||||
}
|
}
|
||||||
@ -80,7 +80,7 @@ func (d *Datastore) DestroyAllSessionsForUser(id uint) error {
|
|||||||
sqlStatement := `
|
sqlStatement := `
|
||||||
DELETE FROM sessions WHERE user_id = ?
|
DELETE FROM sessions WHERE user_id = ?
|
||||||
`
|
`
|
||||||
_, err := d.db.Exec(sqlStatement, id)
|
_, err := d.writer.Exec(sqlStatement, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "deleting sessions for user")
|
return errors.Wrap(err, "deleting sessions for user")
|
||||||
}
|
}
|
||||||
@ -94,7 +94,7 @@ func (d *Datastore) MarkSessionAccessed(session *fleet.Session) error {
|
|||||||
accessed_at = ?
|
accessed_at = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
results, err := d.db.Exec(sqlStatement, d.clock.Now(), session.ID)
|
results, err := d.writer.Exec(sqlStatement, d.clock.Now(), session.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "updating mark session as accessed")
|
return errors.Wrap(err, "updating mark session as accessed")
|
||||||
}
|
}
|
||||||
|
@ -206,7 +206,7 @@ func (d *Datastore) insertNewInstalledHostSoftware(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Software, error) {
|
func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Software, error) {
|
||||||
selectFunc := d.db.Select
|
selectFunc := d.reader.Select
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
selectFunc = tx.Select
|
selectFunc = tx.Select
|
||||||
}
|
}
|
||||||
@ -231,7 +231,7 @@ func (d *Datastore) hostSoftwareFromHostID(tx *sqlx.Tx, id uint) ([]fleet.Softwa
|
|||||||
WHERE s.id IN
|
WHERE s.id IN
|
||||||
(SELECT software_id FROM host_software WHERE host_id = ?)
|
(SELECT software_id FROM host_software WHERE host_id = ?)
|
||||||
`
|
`
|
||||||
queryFunc := d.db.Queryx
|
queryFunc := d.reader.Queryx
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
queryFunc = tx.Queryx
|
queryFunc = tx.Queryx
|
||||||
}
|
}
|
||||||
@ -305,7 +305,7 @@ func (d *Datastore) AllSoftwareWithoutCPEIterator() (fleet.SoftwareIterator, err
|
|||||||
sql := `SELECT s.* FROM software s LEFT JOIN software_cpe sc on (s.id=sc.software_id) WHERE sc.id is null`
|
sql := `SELECT s.* FROM software s LEFT JOIN software_cpe sc on (s.id=sc.software_id) WHERE sc.id is null`
|
||||||
// The rows.Close call is done by the caller once iteration using the
|
// The rows.Close call is done by the caller once iteration using the
|
||||||
// returned fleet.SoftwareIterator is done.
|
// returned fleet.SoftwareIterator is done.
|
||||||
rows, err := d.db.Queryx(sql) //nolint:sqlclosecheck
|
rows, err := d.reader.Queryx(sql) //nolint:sqlclosecheck
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "load host software")
|
return nil, errors.Wrap(err, "load host software")
|
||||||
}
|
}
|
||||||
@ -314,7 +314,7 @@ func (d *Datastore) AllSoftwareWithoutCPEIterator() (fleet.SoftwareIterator, err
|
|||||||
|
|
||||||
func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error {
|
func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error {
|
||||||
sql := `INSERT INTO software_cpe (software_id, cpe) VALUES (?, ?)`
|
sql := `INSERT INTO software_cpe (software_id, cpe) VALUES (?, ?)`
|
||||||
if _, err := d.db.Exec(sql, software.ID, cpe); err != nil {
|
if _, err := d.writer.Exec(sql, software.ID, cpe); err != nil {
|
||||||
return errors.Wrap(err, "insert software cpe")
|
return errors.Wrap(err, "insert software cpe")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -323,7 +323,7 @@ func (d *Datastore) AddCPEForSoftware(software fleet.Software, cpe string) error
|
|||||||
func (d *Datastore) AllCPEs() ([]string, error) {
|
func (d *Datastore) AllCPEs() ([]string, error) {
|
||||||
sql := `SELECT cpe FROM software_cpe`
|
sql := `SELECT cpe FROM software_cpe`
|
||||||
var cpes []string
|
var cpes []string
|
||||||
err := d.db.Select(&cpes, sql)
|
err := d.reader.Select(&cpes, sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "loads cpes")
|
return nil, errors.Wrap(err, "loads cpes")
|
||||||
}
|
}
|
||||||
@ -337,7 +337,7 @@ func (d *Datastore) InsertCVEForCPE(cve string, cpes []string) error {
|
|||||||
for _, cpe := range cpes {
|
for _, cpe := range cpes {
|
||||||
args = append(args, cpe, cve)
|
args = append(args, cpe, cve)
|
||||||
}
|
}
|
||||||
_, err := d.db.Exec(sql, args...)
|
_, err := d.writer.Exec(sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "insert software cve")
|
return errors.Wrap(err, "insert software cve")
|
||||||
}
|
}
|
||||||
|
@ -207,7 +207,7 @@ func TestHostSoftwareDuplicates(t *testing.T) {
|
|||||||
})
|
})
|
||||||
incoming[soft2Key] = true
|
incoming[soft2Key] = true
|
||||||
|
|
||||||
tx, err := ds.db.Beginx()
|
tx, err := ds.writer.Beginx()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
|
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
|
||||||
require.NoError(t, tx.Commit())
|
require.NoError(t, tx.Commit())
|
||||||
@ -220,7 +220,7 @@ func TestHostSoftwareDuplicates(t *testing.T) {
|
|||||||
})
|
})
|
||||||
incoming[soft3Key] = true
|
incoming[soft3Key] = true
|
||||||
|
|
||||||
tx, err = ds.db.Beginx()
|
tx, err = ds.writer.Beginx()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
|
require.NoError(t, ds.insertNewInstalledHostSoftware(tx, host1.ID, make(map[string]uint), incoming))
|
||||||
require.NoError(t, tx.Commit())
|
require.NoError(t, tx.Commit())
|
||||||
|
@ -21,14 +21,14 @@ func (d *Datastore) ShouldSendStatistics(frequency time.Duration) (fleet.Statist
|
|||||||
}
|
}
|
||||||
|
|
||||||
dest := statistics{}
|
dest := statistics{}
|
||||||
err = d.db.Get(&dest, `SELECT created_at, updated_at, anonymous_identifier FROM statistics LIMIT 1`)
|
err = d.writer.Get(&dest, `SELECT created_at, updated_at, anonymous_identifier FROM statistics LIMIT 1`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
anonIdentifier, err := server.GenerateRandomText(64)
|
anonIdentifier, err := server.GenerateRandomText(64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fleet.StatisticsPayload{}, false, err
|
return fleet.StatisticsPayload{}, false, err
|
||||||
}
|
}
|
||||||
_, err = d.db.Exec(`INSERT INTO statistics(anonymous_identifier) VALUES (?)`, anonIdentifier)
|
_, err = d.writer.Exec(`INSERT INTO statistics(anonymous_identifier) VALUES (?)`, anonIdentifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fleet.StatisticsPayload{}, false, err
|
return fleet.StatisticsPayload{}, false, err
|
||||||
}
|
}
|
||||||
@ -55,6 +55,6 @@ func (d *Datastore) ShouldSendStatistics(frequency time.Duration) (fleet.Statist
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) RecordStatisticsSent() error {
|
func (d *Datastore) RecordStatisticsSent() error {
|
||||||
_, err := d.db.Exec(`UPDATE statistics SET updated_at = CURRENT_TIMESTAMP LIMIT 1`)
|
_, err := d.writer.Exec(`UPDATE statistics SET updated_at = CURRENT_TIMESTAMP LIMIT 1`)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ func (d *Datastore) CountHostsInTargets(filter fleet.TeamFilter, targets fleet.H
|
|||||||
}
|
}
|
||||||
|
|
||||||
res := fleet.TargetMetrics{}
|
res := fleet.TargetMetrics{}
|
||||||
err = d.db.Get(&res, query, args...)
|
err = d.reader.Get(&res, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fleet.TargetMetrics{}, errors.Wrap(err, "sqlx.Get CountHostsInTargets")
|
return fleet.TargetMetrics{}, errors.Wrap(err, "sqlx.Get CountHostsInTargets")
|
||||||
}
|
}
|
||||||
@ -98,7 +98,7 @@ func (d *Datastore) HostIDsInTargets(filter fleet.TeamFilter, targets fleet.Host
|
|||||||
}
|
}
|
||||||
|
|
||||||
var res []uint
|
var res []uint
|
||||||
err = d.db.Select(&res, query, args...)
|
err = d.reader.Select(&res, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "sqlx.Get HostIDsInTargets")
|
return nil, errors.Wrap(err, "sqlx.Get HostIDsInTargets")
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ func (d *Datastore) NewTeam(team *fleet.Team) (*fleet.Team, error) {
|
|||||||
description
|
description
|
||||||
) VALUES ( ?, ?, ? )
|
) VALUES ( ?, ?, ? )
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(
|
result, err := d.writer.Exec(
|
||||||
query,
|
query,
|
||||||
team.Name,
|
team.Name,
|
||||||
team.AgentOptions,
|
team.AgentOptions,
|
||||||
@ -47,7 +47,7 @@ func (d *Datastore) Team(tid uint) (*fleet.Team, error) {
|
|||||||
`
|
`
|
||||||
team := &fleet.Team{}
|
team := &fleet.Team{}
|
||||||
|
|
||||||
if err := d.db.Get(team, sql, tid); err != nil {
|
if err := d.reader.Get(team, sql, tid); err != nil {
|
||||||
return nil, errors.Wrap(err, "select team")
|
return nil, errors.Wrap(err, "select team")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ func (d *Datastore) TeamByName(name string) (*fleet.Team, error) {
|
|||||||
`
|
`
|
||||||
team := &fleet.Team{}
|
team := &fleet.Team{}
|
||||||
|
|
||||||
if err := d.db.Get(team, sql, name); err != nil {
|
if err := d.reader.Get(team, sql, name); err != nil {
|
||||||
return nil, errors.Wrap(err, "select team")
|
return nil, errors.Wrap(err, "select team")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,7 +106,7 @@ func (d *Datastore) loadUsersForTeam(team *fleet.Team) error {
|
|||||||
WHERE ut.team_id = ?
|
WHERE ut.team_id = ?
|
||||||
`
|
`
|
||||||
rows := []fleet.TeamUser{}
|
rows := []fleet.TeamUser{}
|
||||||
if err := d.db.Select(&rows, sql, team.ID); err != nil {
|
if err := d.reader.Select(&rows, sql, team.ID); err != nil {
|
||||||
return errors.Wrap(err, "load users for team")
|
return errors.Wrap(err, "load users for team")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -157,7 +157,7 @@ func (d *Datastore) SaveTeam(team *fleet.Team) (*fleet.Team, error) {
|
|||||||
description = ?
|
description = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
_, err := d.db.Exec(query, team.Name, team.AgentOptions, team.Description, team.ID)
|
_, err := d.writer.Exec(query, team.Name, team.AgentOptions, team.Description, team.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "saving team")
|
return nil, errors.Wrap(err, "saving team")
|
||||||
}
|
}
|
||||||
@ -174,7 +174,7 @@ func (d *Datastore) SaveTeam(team *fleet.Team) (*fleet.Team, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Datastore) updateTeamSchedule(team *fleet.Team) error {
|
func (d *Datastore) updateTeamSchedule(team *fleet.Team) error {
|
||||||
_, err := d.db.Exec(
|
_, err := d.writer.Exec(
|
||||||
`UPDATE packs SET name = ? WHERE pack_type = ?`, teamScheduleName(team), teamSchedulePackType(team),
|
`UPDATE packs SET name = ? WHERE pack_type = ?`, teamScheduleName(team), teamSchedulePackType(team),
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
@ -195,7 +195,7 @@ func (d *Datastore) ListTeams(filter fleet.TeamFilter, opt fleet.ListOptions) ([
|
|||||||
query, params := searchLike(query, nil, opt.MatchQuery, teamSearchColumns...)
|
query, params := searchLike(query, nil, opt.MatchQuery, teamSearchColumns...)
|
||||||
query = appendListOptionsToSQL(query, opt)
|
query = appendListOptionsToSQL(query, opt)
|
||||||
teams := []*fleet.Team{}
|
teams := []*fleet.Team{}
|
||||||
if err := d.db.Select(&teams, query, params...); err != nil {
|
if err := d.reader.Select(&teams, query, params...); err != nil {
|
||||||
return nil, errors.Wrap(err, "list teams")
|
return nil, errors.Wrap(err, "list teams")
|
||||||
}
|
}
|
||||||
if err := d.loadSecretsForTeams(teams); err != nil {
|
if err := d.loadSecretsForTeams(teams); err != nil {
|
||||||
@ -229,7 +229,7 @@ func (d *Datastore) SearchTeams(filter fleet.TeamFilter, matchQuery string, omit
|
|||||||
sql, params := searchLike(sql, nil, matchQuery, teamSearchColumns...)
|
sql, params := searchLike(sql, nil, matchQuery, teamSearchColumns...)
|
||||||
sql += "\nLIMIT 5"
|
sql += "\nLIMIT 5"
|
||||||
teams := []*fleet.Team{}
|
teams := []*fleet.Team{}
|
||||||
if err := d.db.Select(&teams, sql, params...); err != nil {
|
if err := d.reader.Select(&teams, sql, params...); err != nil {
|
||||||
return nil, errors.Wrap(err, "search teams")
|
return nil, errors.Wrap(err, "search teams")
|
||||||
}
|
}
|
||||||
if err := d.loadSecretsForTeams(teams); err != nil {
|
if err := d.loadSecretsForTeams(teams); err != nil {
|
||||||
@ -244,7 +244,7 @@ func (d *Datastore) TeamEnrollSecrets(teamID uint) ([]*fleet.EnrollSecret, error
|
|||||||
WHERE team_id = ?
|
WHERE team_id = ?
|
||||||
`
|
`
|
||||||
var secrets []*fleet.EnrollSecret
|
var secrets []*fleet.EnrollSecret
|
||||||
if err := d.db.Select(&secrets, sql, teamID); err != nil {
|
if err := d.reader.Select(&secrets, sql, teamID); err != nil {
|
||||||
return nil, errors.Wrap(err, "get secrets")
|
return nil, errors.Wrap(err, "get secrets")
|
||||||
}
|
}
|
||||||
return secrets, nil
|
return secrets, nil
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
@ -9,10 +10,12 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/WatchBeam/clock"
|
"github.com/WatchBeam/clock"
|
||||||
"github.com/fleetdm/fleet/v4/server/config"
|
"github.com/fleetdm/fleet/v4/server/config"
|
||||||
"github.com/go-kit/kit/log"
|
"github.com/go-kit/kit/log"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,9 +23,10 @@ const (
|
|||||||
testUsername = "root"
|
testUsername = "root"
|
||||||
testPassword = "toor"
|
testPassword = "toor"
|
||||||
testAddress = "localhost:3307"
|
testAddress = "localhost:3307"
|
||||||
|
testReplicaDatabaseSuffix = "_replica"
|
||||||
)
|
)
|
||||||
|
|
||||||
func connectMySQL(t *testing.T, testName string) *Datastore {
|
func connectMySQL(t *testing.T, testName string, opts *DatastoreTestOptions) *Datastore {
|
||||||
config := config.MysqlConfig{
|
config := config.MysqlConfig{
|
||||||
Username: testUsername,
|
Username: testUsername,
|
||||||
Password: testPassword,
|
Password: testPassword,
|
||||||
@ -31,15 +35,134 @@ func connectMySQL(t *testing.T, testName string) *Datastore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create datastore client
|
// Create datastore client
|
||||||
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1))
|
var replicaOpt DBOption
|
||||||
|
if opts.Replica {
|
||||||
|
replicaConf := config
|
||||||
|
replicaConf.Database += testReplicaDatabaseSuffix
|
||||||
|
replicaOpt = Replica(&replicaConf)
|
||||||
|
}
|
||||||
|
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1), replicaOpt)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
if opts.Replica {
|
||||||
|
setupReadReplica(t, testName, ds, opts)
|
||||||
|
}
|
||||||
|
|
||||||
return ds
|
return ds
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setupReadReplica(t *testing.T, testName string, ds *Datastore, opts *DatastoreTestOptions) {
|
||||||
|
// create the context that will cancel the replication goroutine on test exit
|
||||||
|
var cancel func()
|
||||||
|
ctx := context.Background()
|
||||||
|
if dl, ok := t.Deadline(); ok {
|
||||||
|
ctx, cancel = context.WithDeadline(ctx, dl)
|
||||||
|
} else {
|
||||||
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
|
}
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
// start the replication goroutine that runs when signalled through a
|
||||||
|
// channel, the replication runs in lock-step - the test is in control of
|
||||||
|
// when the replication happens, by calling opts.RunReplication(), and when
|
||||||
|
// that call returns, the replication is guaranteed to be done. This supports
|
||||||
|
// simulating all kinds of replica lag.
|
||||||
|
ch := make(chan chan struct{})
|
||||||
|
go func() {
|
||||||
|
// if it exits because of a panic/failed replication, cancel the context
|
||||||
|
// immediately so that RunReplication is unblocked too.
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
primary := ds.writer
|
||||||
|
replica := ds.reader.(*sqlx.DB)
|
||||||
|
replicaDB := testName + testReplicaDatabaseSuffix
|
||||||
|
last := time.Now().Add(-time.Minute)
|
||||||
|
|
||||||
|
// drop all foreign keys in the replica, as that causes issues even with
|
||||||
|
// FOREIGN_KEY_CHECKS=0
|
||||||
|
var fks []struct {
|
||||||
|
TableName string `db:"table_name"`
|
||||||
|
ConstraintName string `db:"constraint_name"`
|
||||||
|
}
|
||||||
|
err := primary.SelectContext(ctx, &fks, `
|
||||||
|
SELECT
|
||||||
|
table_name, constraint_name
|
||||||
|
FROM
|
||||||
|
information_schema.key_column_usage
|
||||||
|
WHERE
|
||||||
|
table_schema = ? AND
|
||||||
|
referenced_table_name IS NOT NULL`, testName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, fk := range fks {
|
||||||
|
stmt := fmt.Sprintf(`ALTER TABLE %s.%s DROP FOREIGN KEY %s`, replicaDB, fk.TableName, fk.ConstraintName)
|
||||||
|
_, err := replica.ExecContext(ctx, stmt)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case out := <-ch:
|
||||||
|
// identify tables with changes since the last call
|
||||||
|
var tables []string
|
||||||
|
err := primary.SelectContext(ctx, &tables, `
|
||||||
|
SELECT
|
||||||
|
table_name
|
||||||
|
FROM
|
||||||
|
information_schema.tables
|
||||||
|
WHERE
|
||||||
|
table_schema = ? AND
|
||||||
|
table_type = 'BASE TABLE' AND
|
||||||
|
update_time >= ?`, testName, last)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = primary.GetContext(ctx, &last, `
|
||||||
|
SELECT
|
||||||
|
MAX(update_time)
|
||||||
|
FROM
|
||||||
|
information_schema.tables
|
||||||
|
WHERE
|
||||||
|
table_schema = ? AND
|
||||||
|
table_type = 'BASE TABLE'`, testName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// replicate by dropping the existing table and re-creating it from
|
||||||
|
// the primary.
|
||||||
|
for _, tbl := range tables {
|
||||||
|
stmt := fmt.Sprintf(`DROP TABLE IF EXISTS %s.%s`, replicaDB, tbl)
|
||||||
|
t.Log(stmt)
|
||||||
|
_, err = replica.ExecContext(ctx, stmt)
|
||||||
|
require.NoError(t, err)
|
||||||
|
stmt = fmt.Sprintf(`CREATE TABLE %s.%s SELECT * FROM %s.%s`, replicaDB, tbl, testName, tbl)
|
||||||
|
t.Log(stmt)
|
||||||
|
_, err = replica.ExecContext(ctx, stmt)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out <- struct{}{}
|
||||||
|
t.Logf("replication step executed, next will consider updates since %s", last)
|
||||||
|
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// set RunReplication to a function that triggers the replication and waits
|
||||||
|
// for it to complete.
|
||||||
|
opts.RunReplication = func() {
|
||||||
|
done := make(chan struct{})
|
||||||
|
ch <- done
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// initializeDatabase loads the dumped schema into a newly created database in
|
// initializeDatabase loads the dumped schema into a newly created database in
|
||||||
// MySQL. This is much faster than running the full set of migrations on each
|
// MySQL. This is much faster than running the full set of migrations on each
|
||||||
// test.
|
// test.
|
||||||
func initializeDatabase(t *testing.T, testName string) *Datastore {
|
func initializeDatabase(t *testing.T, testName string, opts *DatastoreTestOptions) *Datastore {
|
||||||
_, filename, _, _ := runtime.Caller(0)
|
_, filename, _, _ := runtime.Caller(0)
|
||||||
base := path.Dir(filename)
|
base := path.Dir(filename)
|
||||||
schema, err := ioutil.ReadFile(path.Join(base, "schema.sql"))
|
schema, err := ioutil.ReadFile(path.Join(base, "schema.sql"))
|
||||||
@ -47,6 +170,14 @@ func initializeDatabase(t *testing.T, testName string) *Datastore {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
t.FailNow()
|
t.FailNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execute the schema for the test db, and once more for the replica db if
|
||||||
|
// that option is set.
|
||||||
|
dbs := []string{testName}
|
||||||
|
if opts.Replica {
|
||||||
|
dbs = append(dbs, testName+testReplicaDatabaseSuffix)
|
||||||
|
}
|
||||||
|
for _, dbName := range dbs {
|
||||||
// Load schema from dumpfile
|
// Load schema from dumpfile
|
||||||
if out, err := exec.Command(
|
if out, err := exec.Command(
|
||||||
"docker-compose", "exec", "-T", "mysql_test",
|
"docker-compose", "exec", "-T", "mysql_test",
|
||||||
@ -56,24 +187,43 @@ func initializeDatabase(t *testing.T, testName string) *Datastore {
|
|||||||
"-e",
|
"-e",
|
||||||
fmt.Sprintf(
|
fmt.Sprintf(
|
||||||
"DROP DATABASE IF EXISTS %s; CREATE DATABASE %s; USE %s; SET FOREIGN_KEY_CHECKS=0; %s;",
|
"DROP DATABASE IF EXISTS %s; CREATE DATABASE %s; USE %s; SET FOREIGN_KEY_CHECKS=0; %s;",
|
||||||
testName, testName, testName, schema,
|
dbName, dbName, dbName, schema,
|
||||||
),
|
),
|
||||||
).CombinedOutput(); err != nil {
|
).CombinedOutput(); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
t.Error(string(out))
|
t.Error(string(out))
|
||||||
t.FailNow()
|
t.FailNow()
|
||||||
}
|
}
|
||||||
return connectMySQL(t, testName)
|
}
|
||||||
|
return connectMySQL(t, testName, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateMySQLDS(t *testing.T) *Datastore {
|
// DatastoreTestOptions configures how the test datastore is created
|
||||||
|
// by CreateMySQLDSWithOptions.
|
||||||
|
type DatastoreTestOptions struct {
|
||||||
|
// Replica indicates that a read replica test database should be created.
|
||||||
|
Replica bool
|
||||||
|
|
||||||
|
// RunReplication is the function to call to execute the replication of all
|
||||||
|
// missing changes from the primary to the replica. The function is created
|
||||||
|
// and set automatically by CreateMySQLDSWithOptions. The test is in full
|
||||||
|
// control of when the replication is executed.
|
||||||
|
RunReplication func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datastore {
|
||||||
if _, ok := os.LookupEnv("MYSQL_TEST"); !ok {
|
if _, ok := os.LookupEnv("MYSQL_TEST"); !ok {
|
||||||
t.Skip("MySQL tests are disabled")
|
t.Skip("MySQL tests are disabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
pc, _, _, ok := runtime.Caller(1)
|
if opts == nil {
|
||||||
|
// so it is never nil in internal helper functions
|
||||||
|
opts = new(DatastoreTestOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
pc, _, _, ok := runtime.Caller(2)
|
||||||
details := runtime.FuncForPC(pc)
|
details := runtime.FuncForPC(pc)
|
||||||
if !ok || details == nil {
|
if !ok || details == nil {
|
||||||
t.FailNow()
|
t.FailNow()
|
||||||
@ -83,5 +233,13 @@ func CreateMySQLDS(t *testing.T) *Datastore {
|
|||||||
strings.TrimPrefix(details.Name(), "github.com/fleetdm/fleet/v4/"), "/", "_",
|
strings.TrimPrefix(details.Name(), "github.com/fleetdm/fleet/v4/"), "/", "_",
|
||||||
)
|
)
|
||||||
cleanName = strings.ReplaceAll(cleanName, ".", "_")
|
cleanName = strings.ReplaceAll(cleanName, ".", "_")
|
||||||
return initializeDatabase(t, cleanName)
|
return initializeDatabase(t, cleanName, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datastore {
|
||||||
|
return createMySQLDSWithOptions(t, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateMySQLDS(t *testing.T) *Datastore {
|
||||||
|
return createMySQLDSWithOptions(t, nil)
|
||||||
}
|
}
|
||||||
|
@ -72,7 +72,7 @@ func (d *Datastore) findUser(searchCol string, searchVal interface{}) (*fleet.Us
|
|||||||
|
|
||||||
user := &fleet.User{}
|
user := &fleet.User{}
|
||||||
|
|
||||||
err := d.db.Get(user, sqlStatement, searchVal)
|
err := d.reader.Get(user, sqlStatement, searchVal)
|
||||||
if err != nil && err == sql.ErrNoRows {
|
if err != nil && err == sql.ErrNoRows {
|
||||||
return nil, notFound("User").
|
return nil, notFound("User").
|
||||||
WithMessage(fmt.Sprintf("with %s=%v", searchCol, searchVal))
|
WithMessage(fmt.Sprintf("with %s=%v", searchCol, searchVal))
|
||||||
@ -104,7 +104,7 @@ func (d *Datastore) ListUsers(opt fleet.UserListOptions) ([]*fleet.User, error)
|
|||||||
sqlStatement = appendListOptionsToSQL(sqlStatement, opt.ListOptions)
|
sqlStatement = appendListOptionsToSQL(sqlStatement, opt.ListOptions)
|
||||||
users := []*fleet.User{}
|
users := []*fleet.User{}
|
||||||
|
|
||||||
if err := d.db.Select(&users, sqlStatement, params...); err != nil {
|
if err := d.reader.Select(&users, sqlStatement, params...); err != nil {
|
||||||
return nil, errors.Wrap(err, "list users")
|
return nil, errors.Wrap(err, "list users")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -160,7 +160,7 @@ func (d *Datastore) saveUser(tx *sqlx.Tx, user *fleet.User) error {
|
|||||||
global_role = ?
|
global_role = ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
result, err := d.db.Exec(sqlStatement,
|
result, err := tx.Exec(sqlStatement,
|
||||||
user.Password,
|
user.Password,
|
||||||
user.Salt,
|
user.Salt,
|
||||||
user.Name,
|
user.Name,
|
||||||
@ -221,7 +221,7 @@ func (d *Datastore) loadTeamsForUsers(users []*fleet.User) error {
|
|||||||
fleet.UserTeam
|
fleet.UserTeam
|
||||||
UserID uint `db:"user_id"`
|
UserID uint `db:"user_id"`
|
||||||
}
|
}
|
||||||
if err := d.db.Select(&rows, sql, args...); err != nil {
|
if err := d.reader.Select(&rows, sql, args...); err != nil {
|
||||||
return errors.Wrap(err, "get loadTeamsForUsers")
|
return errors.Wrap(err, "get loadTeamsForUsers")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user