Support per-task configuration for async host processing (#5700)

Martin Angers 2022-05-16 09:44:50 -04:00 committed by GitHub
parent 6027534a9b
commit e6b90ca8b9
15 changed files with 606 additions and 248 deletions

View File

@ -0,0 +1 @@
* Support an extended syntax to configure asynchronous host processing tasks on a per-task basis.

View File

@ -25,7 +25,7 @@ import (
"github.com/fleetdm/fleet/v4/ee/server/licensing"
eeservice "github.com/fleetdm/fleet/v4/ee/server/service"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/config"
configpkg "github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/cached_mysql"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
@ -70,7 +70,7 @@ type initializer interface {
Initialize() error
}
func createServeCmd(configManager config.Manager) *cobra.Command {
func createServeCmd(configManager configpkg.Manager) *cobra.Command {
// Whether to enable the debug endpoints
debug := false
// Whether to enable developer options
@ -302,20 +302,7 @@ the way that the Fleet server works.
failingPolicySet := redis_policy_set.NewFailing(redisPool)
task := &async.Task{
Datastore: ds,
Pool: redisPool,
Clock: clock.C,
AsyncEnabled: config.Osquery.EnableAsyncHostProcessing,
LockTimeout: config.Osquery.AsyncHostCollectLockTimeout,
LogStatsInterval: config.Osquery.AsyncHostCollectLogStatsInterval,
InsertBatch: config.Osquery.AsyncHostInsertBatch,
DeleteBatch: config.Osquery.AsyncHostDeleteBatch,
UpdateBatch: config.Osquery.AsyncHostUpdateBatch,
RedisPopCount: config.Osquery.AsyncHostRedisPopCount,
RedisScanKeysCount: config.Osquery.AsyncHostRedisScanKeysCount,
CollectorInterval: config.Osquery.AsyncHostCollectInterval,
}
task := async.NewTask(ds, redisPool, clock.C, config.Osquery)
if config.Sentry.Dsn != "" {
v := version.Version()
@ -361,7 +348,8 @@ the way that the Fleet server works.
cancelBackground := runCrons(ds, task, kitlog.With(logger, "component", "crons"), config, license, failingPolicySet)
// Flush seen hosts every second
if !task.AsyncEnabled {
hostsAsyncCfg := config.Osquery.AsyncConfigForTask(configpkg.AsyncTaskHostLastSeen)
if !hostsAsyncCfg.Enabled {
go func() {
for range time.Tick(time.Duration(rand.Intn(10)+1) * time.Second) {
if err := task.FlushHostsLastSeen(context.Background(), clock.C.Now()); err != nil {
@ -642,7 +630,7 @@ func runCrons(
ds fleet.Datastore,
task *async.Task,
logger kitlog.Logger,
config config.FleetConfig,
config configpkg.FleetConfig,
license *fleet.LicenseInfo,
failingPoliciesSet fleet.FailingPolicySet,
) context.CancelFunc {
@ -654,7 +642,7 @@ func runCrons(
}
// StartCollectors starts a goroutine per collector, using ctx to cancel.
task.StartCollectors(ctx, config.Osquery.AsyncHostCollectMaxJitterPercent, kitlog.With(logger, "cron", "async_task"))
task.StartCollectors(ctx, kitlog.With(logger, "cron", "async_task"))
go cronDB(ctx, ds, kitlog.With(logger, "cron", "cleanups"), ourIdentifier, license)
go cronVulnerabilities(
@ -675,7 +663,7 @@ func getTLSConfig(profile string) *tls.Config {
}
switch profile {
case config.TLSProfileModern:
case configpkg.TLSProfileModern:
cfg.MinVersion = tls.VersionTLS13
cfg.CurvePreferences = append(cfg.CurvePreferences,
tls.X25519,
@ -692,7 +680,7 @@ func getTLSConfig(profile string) *tls.Config {
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
)
case config.TLSProfileIntermediate:
case configpkg.TLSProfileIntermediate:
cfg.MinVersion = tls.VersionTLS12
cfg.CurvePreferences = append(cfg.CurvePreferences,
tls.X25519,

View File

@ -1071,6 +1071,12 @@ to the amount of time it takes for Fleet to give the host the label queries.
Note that currently, if both the failing policies webhook *and* this `osquery.enable_async_host_processing` option are set, some failing policies webhooks could be missing (some transitions from succeeding to failing or vice-versa could happen without triggering a webhook request).
It can be set to a single boolean value ("true" or "false"), which controls all async host processing tasks, or it can be set for specific async tasks using a syntax similar to a URL query string or the parameters in a Data Source Name (DSN) string, e.g. "label_membership=true&policy_membership=true". When using the per-task syntax, omitted tasks get the default value. The supported async task names are:
* `label_membership` for updating the hosts' label query execution;
* `policy_membership` for updating the hosts' policy membership results;
* `host_last_seen` for updating the hosts' last seen timestamp.
- Default value: false
- Environment variable: `FLEET_OSQUERY_ENABLE_ASYNC_HOST_PROCESSING`
- Config file format:
@ -1084,6 +1090,8 @@ Note that currently, if both the failing policies webhook *and* this `osquery.en
Applies only when `osquery_enable_async_host_processing` is enabled. Sets the interval at which the host data will be collected into the database. Each Fleet instance will attempt to do the collection at this interval (with some optional jitter added, see `osquery_async_host_collect_max_jitter_percent`), with only one succeeding to get the exclusive lock.
It can be set to a single duration value (e.g. "30s"), which defines the interval for all async host processing tasks, or it can be set for specific async tasks using a syntax similar to a URL query string or the parameters in a Data Source Name (DSN) string, e.g. "label_membership=10s&policy_membership=1m". When using the per-task syntax, omitted tasks get the default value. See [osquery_enable_async_host_processing](#osquery_enable_async_host_processing) for the supported async task names.
- Default value: 30s
- Environment variable: `FLEET_OSQUERY_ASYNC_HOST_COLLECT_INTERVAL`
- Config file format:
@ -1110,6 +1118,8 @@ Applies only when `osquery_enable_async_host_processing` is enabled. A number in
Applies only when `osquery_enable_async_host_processing` is enabled. Timeout of the lock acquired by a Fleet instance to collect host data into the database. If the collection runs for too long or the instance crashes unexpectedly, the lock will be automatically released after this duration and another Fleet instance can proceed with the next collection.
It can be set to a single duration value (e.g. "1m"), which defines the lock timeout for all async host processing tasks, or it can be set for specific async tasks using a syntax similar to a URL query string or the parameters in a Data Source Name (DSN) string, e.g. "label_membership=2m&policy_membership=5m". When using the per-task syntax, omitted tasks get the default value. See [osquery_enable_async_host_processing](#osquery_enable_async_host_processing) for the supported async task names.
- Default value: 1m
- Environment variable: `FLEET_OSQUERY_ASYNC_HOST_COLLECT_LOCK_TIMEOUT`
- Config file format:

View File

@ -6,9 +6,11 @@ import (
"errors"
"fmt"
"io/ioutil"
"net/url"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
@ -116,10 +118,10 @@ type OsqueryConfig struct {
ResultLogFile string `yaml:"result_log_file"`
EnableLogRotation bool `yaml:"enable_log_rotation"`
MaxJitterPercent int `yaml:"max_jitter_percent"`
EnableAsyncHostProcessing bool `yaml:"enable_async_host_processing"`
AsyncHostCollectInterval time.Duration `yaml:"async_host_collect_interval"`
EnableAsyncHostProcessing string `yaml:"enable_async_host_processing"` // true/false or per-task
AsyncHostCollectInterval string `yaml:"async_host_collect_interval"` // duration or per-task
AsyncHostCollectMaxJitterPercent int `yaml:"async_host_collect_max_jitter_percent"`
AsyncHostCollectLockTimeout time.Duration `yaml:"async_host_collect_lock_timeout"`
AsyncHostCollectLockTimeout string `yaml:"async_host_collect_lock_timeout"` // duration or per-task
AsyncHostCollectLogStatsInterval time.Duration `yaml:"async_host_collect_log_stats_interval"`
AsyncHostInsertBatch int `yaml:"async_host_insert_batch"`
AsyncHostDeleteBatch int `yaml:"async_host_delete_batch"`
@ -129,6 +131,55 @@ type OsqueryConfig struct {
MinSoftwareLastOpenedAtDiff time.Duration `yaml:"min_software_last_opened_at_diff"`
}
// AsyncTaskName is the type of names that identify tasks supporting
// asynchronous execution.
type AsyncTaskName string
// List of names for supported async tasks.
const (
AsyncTaskLabelMembership AsyncTaskName = "label_membership"
AsyncTaskPolicyMembership AsyncTaskName = "policy_membership"
AsyncTaskHostLastSeen AsyncTaskName = "host_last_seen"
)
var knownAsyncTasks = map[AsyncTaskName]struct{}{
AsyncTaskLabelMembership: {},
AsyncTaskPolicyMembership: {},
AsyncTaskHostLastSeen: {},
}
// AsyncConfigForTask returns the applicable configuration for the specified
// async task.
func (o OsqueryConfig) AsyncConfigForTask(name AsyncTaskName) AsyncProcessingConfig {
strName := string(name)
return AsyncProcessingConfig{
Enabled: configForKeyOrBool("osquery.enable_async_host_processing", strName, o.EnableAsyncHostProcessing, false),
CollectInterval: configForKeyOrDuration("osquery.async_host_collect_interval", strName, o.AsyncHostCollectInterval, 30*time.Second),
CollectMaxJitterPercent: o.AsyncHostCollectMaxJitterPercent,
CollectLockTimeout: configForKeyOrDuration("osquery.async_host_collect_lock_timeout", strName, o.AsyncHostCollectLockTimeout, 1*time.Minute),
CollectLogStatsInterval: o.AsyncHostCollectLogStatsInterval,
InsertBatch: o.AsyncHostInsertBatch,
DeleteBatch: o.AsyncHostDeleteBatch,
UpdateBatch: o.AsyncHostUpdateBatch,
RedisPopCount: o.AsyncHostRedisPopCount,
RedisScanKeysCount: o.AsyncHostRedisScanKeysCount,
}
}
// AsyncProcessingConfig is the configuration for a specific async task.
type AsyncProcessingConfig struct {
Enabled bool
CollectInterval time.Duration
CollectMaxJitterPercent int
CollectLockTimeout time.Duration
CollectLogStatsInterval time.Duration
InsertBatch int
DeleteBatch int
UpdateBatch int
RedisPopCount int
RedisScanKeysCount int
}
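For illustration, a minimal sketch of how a per-task value resolves through `AsyncConfigForTask`, assuming this `config` package is importable at `github.com/fleetdm/fleet/v4/server/config` as elsewhere in this diff; values not set for a task fall back to the key's default:

```go
package main

import (
	"fmt"

	"github.com/fleetdm/fleet/v4/server/config"
)

func main() {
	cfg := config.OsqueryConfig{
		// per-task enablement using the query-string syntax
		EnableAsyncHostProcessing: "label_membership=true&policy_membership=false",
		// per-task interval for one task only; others keep the 30s default
		AsyncHostCollectInterval: "label_membership=10s",
		// single value: applies to every task
		AsyncHostCollectLockTimeout: "1m",
	}

	labels := cfg.AsyncConfigForTask(config.AsyncTaskLabelMembership)
	fmt.Println(labels.Enabled, labels.CollectInterval, labels.CollectLockTimeout) // true 10s 1m0s

	policies := cfg.AsyncConfigForTask(config.AsyncTaskPolicyMembership)
	fmt.Println(policies.Enabled, policies.CollectInterval) // false 30s
}
```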
// LoggingConfig defines configs related to logging
type LoggingConfig struct {
Debug bool
@ -446,14 +497,14 @@ func (man Manager) addConfigs() {
"(DEPRECATED: Use filesystem.enable_log_rotation) Enable automatic rotation for osquery log files")
man.addConfigInt("osquery.max_jitter_percent", 10,
"Maximum percentage of the interval to add as jitter")
man.addConfigBool("osquery.enable_async_host_processing", false,
"Enable asynchronous processing of host-reported query results")
man.addConfigDuration("osquery.async_host_collect_interval", 30*time.Second,
"Interval to collect asynchronous host-reported query results (i.e. 30s)")
man.addConfigString("osquery.enable_async_host_processing", "false",
"Enable asynchronous processing of host-reported query results (either 'true'/'false' or set per task, e.g. 'label_membership=true&policy_membership=true')")
man.addConfigString("osquery.async_host_collect_interval", (30 * time.Second).String(),
"Interval to collect asynchronous host-reported query results (e.g. '30s' or set per task 'label_membership=10s&policy_membership=1m')")
man.addConfigInt("osquery.async_host_collect_max_jitter_percent", 10,
"Maximum percentage of the interval to collect asynchronous host results")
man.addConfigDuration("osquery.async_host_collect_lock_timeout", 1*time.Minute,
"Timeout of the exclusive lock held during async host collection")
man.addConfigString("osquery.async_host_collect_lock_timeout", (1 * time.Minute).String(),
"Timeout of the exclusive lock held during async host collection (e.g. '30s' or set per task 'label_membership=10s&policy_membership=1m'")
man.addConfigDuration("osquery.async_host_collect_log_stats_interval", 1*time.Minute,
"Interval at which async host collection statistics are logged (0 disables logging of stats)")
man.addConfigInt("osquery.async_host_insert_batch", 2000,
@ -613,7 +664,7 @@ func (man Manager) LoadConfig() FleetConfig {
}
}
return FleetConfig{
cfg := FleetConfig{
Mysql: loadMysqlConfig("mysql"),
MysqlReadReplica: loadMysqlConfig("mysql_read_replica"),
Redis: RedisConfig{
@ -675,10 +726,10 @@ func (man Manager) LoadConfig() FleetConfig {
DetailUpdateInterval: man.getConfigDuration("osquery.detail_update_interval"),
EnableLogRotation: man.getConfigBool("osquery.enable_log_rotation"),
MaxJitterPercent: man.getConfigInt("osquery.max_jitter_percent"),
EnableAsyncHostProcessing: man.getConfigBool("osquery.enable_async_host_processing"),
AsyncHostCollectInterval: man.getConfigDuration("osquery.async_host_collect_interval"),
EnableAsyncHostProcessing: man.getConfigString("osquery.enable_async_host_processing"),
AsyncHostCollectInterval: man.getConfigString("osquery.async_host_collect_interval"),
AsyncHostCollectMaxJitterPercent: man.getConfigInt("osquery.async_host_collect_max_jitter_percent"),
AsyncHostCollectLockTimeout: man.getConfigDuration("osquery.async_host_collect_lock_timeout"),
AsyncHostCollectLockTimeout: man.getConfigString("osquery.async_host_collect_lock_timeout"),
AsyncHostCollectLogStatsInterval: man.getConfigDuration("osquery.async_host_collect_log_stats_interval"),
AsyncHostInsertBatch: man.getConfigInt("osquery.async_host_insert_batch"),
AsyncHostDeleteBatch: man.getConfigInt("osquery.async_host_delete_batch"),
@ -779,6 +830,13 @@ func (man Manager) LoadConfig() FleetConfig {
},
},
}
// ensure immediately that the async config is valid for all known tasks
for task := range knownAsyncTasks {
cfg.Osquery.AsyncConfigForTask(task)
}
return cfg
}
// IsSet determines whether a given config key has been explicitly set by any
@ -951,6 +1009,74 @@ func (man Manager) getConfigDuration(key string) time.Duration {
return durationVal
}
// configForKeyOrBool panics if the value is invalid; this mirrors how the
// getConfigT helpers (via Viper) report errors. The default value is only
// applied if there is no task-specific config (i.e. no "task=true" entry for
// that task). If the configuration key was not set at all, it automatically
// inherits the general default configured for that key (via
// man.addConfigBool).
func configForKeyOrBool(key, task, val string, def bool) bool {
parseVal := func(v string) bool {
if v == "" {
return false
}
b, err := strconv.ParseBool(v)
if err != nil {
panic("Unable to cast to bool for key " + key + ": " + err.Error())
}
return b
}
if !strings.Contains(val, "=") {
// simple case, val is a bool
return parseVal(val)
}
q, err := url.ParseQuery(val)
if err != nil {
panic("Invalid query format for key " + key + ": " + err.Error())
}
if v := q.Get(task); v != "" {
return parseVal(v)
}
return def
}
// configForKeyOrDuration panics if the value is invalid; this mirrors how the
// getConfigT helpers (via Viper) report errors. The default value is only
// applied if there is no task-specific config (i.e. no "task=10s" entry for
// that task). If the configuration key was not set at all, it automatically
// inherits the general default configured for that key (via
// man.addConfigDuration).
func configForKeyOrDuration(key, task, val string, def time.Duration) time.Duration {
parseVal := func(v string) time.Duration {
if v == "" {
return 0
}
d, err := time.ParseDuration(v)
if err != nil {
panic("Unable to cast to time.Duration for key " + key + ": " + err.Error())
}
return d
}
if !strings.Contains(val, "=") {
// simple case, val is a duration
return parseVal(val)
}
q, err := url.ParseQuery(val)
if err != nil {
panic("Invalid query format for key " + key + ": " + err.Error())
}
if v := q.Get(task); v != "" {
return parseVal(v)
}
return def
}
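As a condensed illustration of the resolution rule these helpers implement (error handling omitted; `durationForTask` below is a hypothetical stand-in for the unexported helper above):

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
	"time"
)

// durationForTask mirrors configForKeyOrDuration: a plain value applies to
// every task, a query-string value applies per task, and omitted tasks fall
// back to the default.
func durationForTask(task, val string, def time.Duration) time.Duration {
	if !strings.Contains(val, "=") {
		d, _ := time.ParseDuration(val) // plain value, applies to all tasks
		return d
	}
	q, _ := url.ParseQuery(val)
	if v := q.Get(task); v != "" {
		d, _ := time.ParseDuration(v)
		return d
	}
	return def
}

func main() {
	val := "label_membership=10s&policy_membership=1m"
	fmt.Println(durationForTask("label_membership", val, 30*time.Second)) // 10s
	fmt.Println(durationForTask("host_last_seen", val, 30*time.Second))   // 30s
}
```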
// loadConfigFile handles the loading of the config file.
func (man Manager) loadConfigFile() {
man.viper.SetConfigType("yaml")

View File

@ -48,6 +48,12 @@ func TestConfigRoundtrip(t *testing.T) {
// we have to explicitly set a value for this key as it will only
// accept intermediate or modern
key_v.SetString(TLSProfileModern)
case "EnableAsyncHostProcessing":
// supports a bool or per-task config
key_v.SetString("true")
case "AsyncHostCollectInterval", "AsyncHostCollectLockTimeout":
// supports a duration or per-task config
key_v.SetString("30s")
default:
key_v.SetString(v.Elem().Type().Field(conf_index).Name + "_" + conf_v.Type().Field(key_index).Name)
}
@ -76,6 +82,259 @@ func TestConfigRoundtrip(t *testing.T) {
assert.Equal(t, *original, man.LoadConfig())
}
func TestConfigOsqueryAsync(t *testing.T) {
cases := []struct {
desc string
yaml string
envVars []string
panics bool
wantLabelCfg AsyncProcessingConfig
}{
{
desc: "default",
wantLabelCfg: AsyncProcessingConfig{
Enabled: false,
CollectInterval: 30 * time.Second,
CollectMaxJitterPercent: 10,
CollectLockTimeout: 1 * time.Minute,
CollectLogStatsInterval: 1 * time.Minute,
InsertBatch: 2000,
DeleteBatch: 2000,
UpdateBatch: 1000,
RedisPopCount: 1000,
RedisScanKeysCount: 1000,
},
},
{
desc: "yaml set enabled true",
yaml: `
osquery:
enable_async_host_processing: true`,
wantLabelCfg: AsyncProcessingConfig{
Enabled: true,
CollectInterval: 30 * time.Second,
CollectMaxJitterPercent: 10,
CollectLockTimeout: 1 * time.Minute,
CollectLogStatsInterval: 1 * time.Minute,
InsertBatch: 2000,
DeleteBatch: 2000,
UpdateBatch: 1000,
RedisPopCount: 1000,
RedisScanKeysCount: 1000,
},
},
{
desc: "yaml set enabled yes",
yaml: `
osquery:
enable_async_host_processing: yes`,
wantLabelCfg: AsyncProcessingConfig{
Enabled: true,
CollectInterval: 30 * time.Second,
CollectMaxJitterPercent: 10,
CollectLockTimeout: 1 * time.Minute,
CollectLogStatsInterval: 1 * time.Minute,
InsertBatch: 2000,
DeleteBatch: 2000,
UpdateBatch: 1000,
RedisPopCount: 1000,
RedisScanKeysCount: 1000,
},
},
{
desc: "yaml set enabled on",
yaml: `
osquery:
enable_async_host_processing: on`,
wantLabelCfg: AsyncProcessingConfig{
Enabled: true,
CollectInterval: 30 * time.Second,
CollectMaxJitterPercent: 10,
CollectLockTimeout: 1 * time.Minute,
CollectLogStatsInterval: 1 * time.Minute,
InsertBatch: 2000,
DeleteBatch: 2000,
UpdateBatch: 1000,
RedisPopCount: 1000,
RedisScanKeysCount: 1000,
},
},
{
desc: "yaml set enabled invalid",
yaml: `
osquery:
enable_async_host_processing: nope`,
panics: true,
},
{
desc: "yaml set enabled per-task",
yaml: `
osquery:
enable_async_host_processing: label_membership=true&policy_membership=false`,
wantLabelCfg: AsyncProcessingConfig{
Enabled: true,
CollectInterval: 30 * time.Second,
CollectMaxJitterPercent: 10,
CollectLockTimeout: 1 * time.Minute,
CollectLogStatsInterval: 1 * time.Minute,
InsertBatch: 2000,
DeleteBatch: 2000,
UpdateBatch: 1000,
RedisPopCount: 1000,
RedisScanKeysCount: 1000,
},
},
{
desc: "yaml set invalid per-task",
yaml: `
osquery:
enable_async_host_processing: label_membership=nope&policy_membership=false`,
panics: true,
},
{
desc: "envvar set enabled",
envVars: []string{"FLEET_OSQUERY_ENABLE_ASYNC_HOST_PROCESSING=true"},
wantLabelCfg: AsyncProcessingConfig{
Enabled: true,
CollectInterval: 30 * time.Second,
CollectMaxJitterPercent: 10,
CollectLockTimeout: 1 * time.Minute,
CollectLogStatsInterval: 1 * time.Minute,
InsertBatch: 2000,
DeleteBatch: 2000,
UpdateBatch: 1000,
RedisPopCount: 1000,
RedisScanKeysCount: 1000,
},
},
{
desc: "envvar set enabled on",
envVars: []string{"FLEET_OSQUERY_ENABLE_ASYNC_HOST_PROCESSING=on"}, // on/off, yes/no is only valid in yaml
panics: true,
},
{
desc: "envvar set enabled per task",
envVars: []string{"FLEET_OSQUERY_ENABLE_ASYNC_HOST_PROCESSING=policy_membership=false&label_membership=true"},
wantLabelCfg: AsyncProcessingConfig{
Enabled: true,
CollectInterval: 30 * time.Second,
CollectMaxJitterPercent: 10,
CollectLockTimeout: 1 * time.Minute,
CollectLogStatsInterval: 1 * time.Minute,
InsertBatch: 2000,
DeleteBatch: 2000,
UpdateBatch: 1000,
RedisPopCount: 1000,
RedisScanKeysCount: 1000,
},
},
{
desc: "yaml collect interval lock timeout",
yaml: `
osquery:
enable_async_host_processing: true
async_host_collect_interval: 10s
async_host_collect_lock_timeout: 20s`,
wantLabelCfg: AsyncProcessingConfig{
Enabled: true,
CollectInterval: 10 * time.Second,
CollectMaxJitterPercent: 10,
CollectLockTimeout: 20 * time.Second,
CollectLogStatsInterval: 1 * time.Minute,
InsertBatch: 2000,
DeleteBatch: 2000,
UpdateBatch: 1000,
RedisPopCount: 1000,
RedisScanKeysCount: 1000,
},
},
{
desc: "yaml collect interval lock timeout per task",
yaml: `
osquery:
enable_async_host_processing: true
async_host_collect_interval: label_membership=10s
async_host_collect_lock_timeout: policy_membership=20s`,
wantLabelCfg: AsyncProcessingConfig{
Enabled: true,
CollectInterval: 10 * time.Second,
CollectMaxJitterPercent: 10,
CollectLockTimeout: 1 * time.Minute,
CollectLogStatsInterval: 1 * time.Minute,
InsertBatch: 2000,
DeleteBatch: 2000,
UpdateBatch: 1000,
RedisPopCount: 1000,
RedisScanKeysCount: 1000,
},
},
{
desc: "yaml env var override",
yaml: `
osquery:
enable_async_host_processing: false
async_host_collect_interval: label_membership=10s
async_host_collect_lock_timeout: policy_membership=20s
async_host_insert_batch: 10`,
envVars: []string{
"FLEET_OSQUERY_ENABLE_ASYNC_HOST_PROCESSING=policy_membership=false&label_membership=true",
"FLEET_OSQUERY_ASYNC_HOST_COLLECT_INTERVAL=policy_membership=30s&label_membership=50s",
"FLEET_OSQUERY_ASYNC_HOST_COLLECT_LOCK_TIMEOUT=40s",
},
wantLabelCfg: AsyncProcessingConfig{
Enabled: true,
CollectInterval: 50 * time.Second,
CollectMaxJitterPercent: 10,
CollectLockTimeout: 40 * time.Second,
CollectLogStatsInterval: 1 * time.Minute,
InsertBatch: 10,
DeleteBatch: 2000,
UpdateBatch: 1000,
RedisPopCount: 1000,
RedisScanKeysCount: 1000,
},
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
var cmd cobra.Command
// Leaving this flag unset means that no attempt will be made to load
// the config file
cmd.PersistentFlags().StringP("config", "c", "", "Path to a configuration file")
man := NewManager(&cmd)
// load the yaml config
man.viper.SetConfigType("yaml")
require.NoError(t, man.viper.ReadConfig(strings.NewReader(c.yaml)))
// TODO: tried to test command-line flags too by using cmd.SetArgs with
// test-case values, but that didn't seem to work; not sure how it can
// be done in our particular setup.
// set the environment variables
os.Clearenv()
for _, env := range c.envVars {
kv := strings.SplitN(env, "=", 2)
os.Setenv(kv[0], kv[1])
}
var loadedCfg FleetConfig
if c.panics {
require.Panics(t, func() {
loadedCfg = man.LoadConfig()
})
} else {
require.NotPanics(t, func() {
loadedCfg = man.LoadConfig()
})
got := loadedCfg.Osquery.AsyncConfigForTask(AsyncTaskLabelMembership)
require.Equal(t, c.wantLabelCfg, got)
}
})
}
}
func TestToTLSConfig(t *testing.T) {
dir := t.TempDir()
caFile, certFile, keyFile, garbageFile := filepath.Join(dir, "ca"),

View File

@ -6,6 +6,7 @@ import (
"time"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/getsentry/sentry-go"
@ -17,94 +18,75 @@ import (
const collectorLockKey = "locks:async_collector:{%s}"
type Task struct {
Datastore fleet.Datastore
Pool fleet.RedisPool
Clock clock.Clock
// AsyncEnabled indicates if async processing is enabled in the
// configuration. Note that Pool can be nil if this is false.
AsyncEnabled bool
LockTimeout time.Duration
LogStatsInterval time.Duration
InsertBatch int
DeleteBatch int
UpdateBatch int
RedisPopCount int
RedisScanKeysCount int
CollectorInterval time.Duration
datastore fleet.Datastore
pool fleet.RedisPool
clock clock.Clock
taskConfigs map[config.AsyncTaskName]config.AsyncProcessingConfig
seenHostSet seenHostSet
}
// NewTask configures and returns a Task.
func NewTask(ds fleet.Datastore, pool fleet.RedisPool, clck clock.Clock, conf config.OsqueryConfig) *Task {
taskCfgs := make(map[config.AsyncTaskName]config.AsyncProcessingConfig)
taskCfgs[config.AsyncTaskLabelMembership] = conf.AsyncConfigForTask(config.AsyncTaskLabelMembership)
taskCfgs[config.AsyncTaskPolicyMembership] = conf.AsyncConfigForTask(config.AsyncTaskPolicyMembership)
taskCfgs[config.AsyncTaskHostLastSeen] = conf.AsyncConfigForTask(config.AsyncTaskHostLastSeen)
return &Task{
datastore: ds,
pool: pool,
clock: clck,
taskConfigs: taskCfgs,
}
}
// StartCollectors runs the various collectors as distinct background
// goroutines, one per task that has async processing enabled. Each collector
// stops when ctx is done.
func (t *Task) StartCollectors(ctx context.Context, jitterPct int, logger kitlog.Logger) {
if !t.AsyncEnabled {
level.Debug(logger).Log("task", "async disabled, not starting collectors")
return
}
level.Debug(logger).Log("task", "async enabled, starting collectors", "interval", t.CollectorInterval, "jitter", jitterPct)
func (t *Task) StartCollectors(ctx context.Context, logger kitlog.Logger) {
collectorErrHandler := func(name string, err error) {
level.Error(logger).Log("err", fmt.Sprintf("%s collector", name), "details", err)
sentry.CaptureException(err)
}
labelColl := &collector{
name: "collect_labels",
pool: t.Pool,
ds: t.Datastore,
execInterval: t.CollectorInterval,
jitterPct: jitterPct,
lockTimeout: t.LockTimeout,
handler: t.collectLabelQueryExecutions,
errHandler: collectorErrHandler,
handlers := map[config.AsyncTaskName]collectorHandlerFunc{
config.AsyncTaskLabelMembership: t.collectLabelQueryExecutions,
config.AsyncTaskPolicyMembership: t.collectPolicyQueryExecutions,
config.AsyncTaskHostLastSeen: t.collectHostsLastSeen,
}
for task, cfg := range t.taskConfigs {
if !cfg.Enabled {
level.Debug(logger).Log("task", "async disabled, not starting collector", "name", task)
continue
}
policyColl := &collector{
name: "collect_policies",
pool: t.Pool,
ds: t.Datastore,
execInterval: t.CollectorInterval,
jitterPct: jitterPct,
lockTimeout: t.LockTimeout,
handler: t.collectPolicyQueryExecutions,
errHandler: collectorErrHandler,
}
lastSeenColl := &collector{
name: "collect_last_seen",
pool: t.Pool,
ds: t.Datastore,
execInterval: t.CollectorInterval,
jitterPct: jitterPct,
lockTimeout: t.LockTimeout,
handler: t.collectHostsLastSeen,
errHandler: collectorErrHandler,
}
colls := []*collector{labelColl, policyColl, lastSeenColl}
for _, coll := range colls {
cfg := cfg // shadow as local var to avoid capturing the iteration var
coll := &collector{
name: "collect_" + string(task),
pool: t.pool,
ds: t.datastore,
execInterval: cfg.CollectInterval,
jitterPct: cfg.CollectMaxJitterPercent,
lockTimeout: cfg.CollectLockTimeout,
handler: handlers[task],
errHandler: collectorErrHandler,
}
go coll.Start(ctx)
}
level.Debug(logger).Log("task", "async enabled, starting collectors", "name", task, "interval", cfg.CollectInterval, "jitter", cfg.CollectMaxJitterPercent)
// log stats at regular intervals
if t.LogStatsInterval > 0 {
go func() {
tick := time.Tick(t.LogStatsInterval)
for {
select {
case <-tick:
for _, coll := range colls {
if cfg.CollectLogStatsInterval > 0 {
go func() {
tick := time.Tick(cfg.CollectLogStatsInterval)
for {
select {
case <-tick:
stats := coll.ReadStats()
level.Debug(logger).Log("stats", fmt.Sprintf("%#v", stats), "name", coll.name)
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
}()
}()
}
}
}
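A minimal wiring sketch of the new constructor and collector startup, mirroring how `cmd/fleet/serve.go` uses them in this diff (the import paths and the surrounding setup are assumptions, not shown by the hunks above):

```go
package example

import (
	"context"

	"github.com/WatchBeam/clock"
	"github.com/fleetdm/fleet/v4/server/config"
	"github.com/fleetdm/fleet/v4/server/fleet"
	"github.com/fleetdm/fleet/v4/server/service/async"
	kitlog "github.com/go-kit/kit/log" // assumed go-kit logger path
)

// startAsync builds a Task from the osquery config and starts one collector
// goroutine per task that has async processing enabled; disabled tasks are
// simply skipped by StartCollectors.
func startAsync(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, cfg config.FleetConfig, logger kitlog.Logger) *async.Task {
	task := async.NewTask(ds, pool, clock.C, cfg.Osquery)
	task.StartCollectors(ctx, kitlog.With(logger, "cron", "async_task"))
	return task
}
```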

View File

@ -6,6 +6,7 @@ import (
"sync"
"time"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -20,7 +21,8 @@ const (
// RecordHostLastSeen records that the specified host ID was seen.
func (t *Task) RecordHostLastSeen(ctx context.Context, hostID uint) error {
if !t.AsyncEnabled {
cfg := t.taskConfigs[config.AsyncTaskHostLastSeen]
if !cfg.Enabled {
t.seenHostSet.addHostID(hostID)
return nil
}
@ -31,7 +33,7 @@ func (t *Task) RecordHostLastSeen(ctx context.Context, hostID uint) error {
// collected yet - 1 week or 10 * the collector interval, whichever is
// biggest.
ttl := hostSeenKeysMinTTL
if maxTTL := 10 * t.CollectorInterval; maxTTL > ttl {
if maxTTL := 10 * cfg.CollectInterval; maxTTL > ttl {
ttl = maxTTL
}
@ -44,9 +46,9 @@ func (t *Task) RecordHostLastSeen(ctx context.Context, hostID uint) error {
return redis.call('EXPIRE', KEYS[1], ARGV[2])
`)
conn := t.Pool.Get()
conn := t.pool.Get()
defer conn.Close()
if err := redis.BindConn(t.Pool, conn, hostSeenRecordedHostIDsKey); err != nil {
if err := redis.BindConn(t.pool, conn, hostSeenRecordedHostIDsKey); err != nil {
return ctxerr.Wrap(ctx, err, "bind redis connection")
}
@ -61,9 +63,10 @@ func (t *Task) RecordHostLastSeen(ctx context.Context, hostID uint) error {
// no-op if asynchronous host processing is enabled, because then it is the
// task collector that will process the writes to mysql.
func (t *Task) FlushHostsLastSeen(ctx context.Context, now time.Time) error {
if !t.AsyncEnabled {
cfg := t.taskConfigs[config.AsyncTaskHostLastSeen]
if !cfg.Enabled {
hostIDs := t.seenHostSet.getAndClearHostIDs()
return t.Datastore.MarkHostsSeen(ctx, hostIDs, now)
return t.datastore.MarkHostsSeen(ctx, hostIDs, now)
}
// no-op, flushing the hosts' last seen is done via the cron that runs the
@ -72,6 +75,8 @@ func (t *Task) FlushHostsLastSeen(ctx context.Context, now time.Time) error {
}
func (t *Task) collectHostsLastSeen(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, stats *collectorExecStats) error {
cfg := t.taskConfigs[config.AsyncTaskHostLastSeen]
hostIDs, err := t.loadSeenHostsIDs(ctx, pool)
if err != nil {
return err
@ -85,8 +90,8 @@ func (t *Task) collectHostsLastSeen(ctx context.Context, ds fleet.Datastore, poo
// globally sort the host IDs so they are sent ordered as batches to MarkHostsSeen
sort.Slice(hostIDs, func(i, j int) bool { return hostIDs[i] < hostIDs[j] })
ts := t.Clock.Now()
batch := make([]uint, t.InsertBatch)
ts := t.clock.Now()
batch := make([]uint, cfg.InsertBatch)
for {
n := copy(batch, hostIDs)
if n == 0 {
@ -110,11 +115,13 @@ func (t *Task) collectHostsLastSeen(ctx context.Context, ds fleet.Datastore, poo
}
func (t *Task) loadSeenHostsIDs(ctx context.Context, pool fleet.RedisPool) ([]uint, error) {
cfg := t.taskConfigs[config.AsyncTaskHostLastSeen]
// compute the TTL for the processing key just as we do for the storage key:
// if the collection fails before removing the working key, we don't want it
// to stick around forever.
ttl := hostSeenKeysMinTTL
if maxTTL := 10 * t.CollectorInterval; maxTTL > ttl {
if maxTTL := 10 * cfg.CollectInterval; maxTTL > ttl {
ttl = maxTTL
}
@ -141,7 +148,7 @@ func (t *Task) loadSeenHostsIDs(ctx context.Context, pool fleet.RedisPool) ([]ui
var ids []uint
cursor := 0
for {
res, err := redigo.Values(conn.Do("SSCAN", hostSeenProcessingHostIDsKey, cursor, "COUNT", t.RedisScanKeysCount))
res, err := redigo.Values(conn.Do("SSCAN", hostSeenProcessingHostIDsKey, cursor, "COUNT", cfg.RedisScanKeysCount))
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "scan seen host ids")
}
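For context, a small sketch of the synchronous path described above: an empty `OsqueryConfig` leaves async processing disabled, so the host ID is only buffered in memory and written on flush (import paths are assumptions):

```go
package example

import (
	"context"
	"time"

	"github.com/WatchBeam/clock"
	"github.com/fleetdm/fleet/v4/server/config"
	"github.com/fleetdm/fleet/v4/server/fleet"
	"github.com/fleetdm/fleet/v4/server/service/async"
)

// markSeenSync records a host as seen with async processing disabled: the ID
// is buffered in the in-memory seen-host set and persisted via MarkHostsSeen
// on the next flush.
func markSeenSync(ctx context.Context, ds fleet.Datastore, hostID uint) error {
	task := async.NewTask(ds, nil, clock.C, config.OsqueryConfig{})
	if err := task.RecordHostLastSeen(ctx, hostID); err != nil {
		return err
	}
	return task.FlushHostsLastSeen(ctx, time.Now())
}
```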

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -143,14 +144,13 @@ func testCollectHostsLastSeen(t *testing.T, ds *mysql.Datastore, pool fleet.Redi
// run the collection
var stats collectorExecStats
task := Task{
Clock: mockTime,
InsertBatch: batchSizes,
UpdateBatch: batchSizes,
DeleteBatch: batchSizes,
RedisPopCount: batchSizes,
RedisScanKeysCount: 10,
}
task := NewTask(nil, nil, mockTime, config.OsqueryConfig{
AsyncHostInsertBatch: batchSizes,
AsyncHostUpdateBatch: batchSizes,
AsyncHostDeleteBatch: batchSizes,
AsyncHostRedisPopCount: batchSizes,
AsyncHostRedisScanKeysCount: 10,
})
err := task.collectHostsLastSeen(ctx, ds, pool, &stats)
require.NoError(t, err)
require.Equal(t, wantStats, stats)
@ -177,12 +177,7 @@ func testRecordHostLastSeenSync(t *testing.T, ds *mock.Store, pool fleet.RedisPo
return nil
}
task := Task{
Datastore: ds,
Pool: pool,
Clock: clock.C,
AsyncEnabled: false,
}
task := NewTask(ds, pool, clock.C, config.OsqueryConfig{})
err := task.RecordHostLastSeen(ctx, 1)
require.NoError(t, err)
err = task.RecordHostLastSeen(ctx, 2)
@ -219,15 +214,11 @@ func testRecordHostLastSeenAsync(t *testing.T, ds *mock.Store, pool fleet.RedisP
return nil
}
task := Task{
Datastore: ds,
Pool: pool,
AsyncEnabled: true,
Clock: clock.C,
InsertBatch: 2,
RedisScanKeysCount: 10,
}
task := NewTask(ds, pool, clock.C, config.OsqueryConfig{
EnableAsyncHostProcessing: "true",
AsyncHostInsertBatch: 2,
AsyncHostRedisScanKeysCount: 10,
})
err := task.RecordHostLastSeen(ctx, 1)
require.NoError(t, err)

View File

@ -5,6 +5,7 @@ import (
"fmt"
"time"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -19,9 +20,10 @@ const (
)
func (t *Task) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
if !t.AsyncEnabled {
cfg := t.taskConfigs[config.AsyncTaskLabelMembership]
if !cfg.Enabled {
host.LabelUpdatedAt = ts
return t.Datastore.RecordLabelQueryExecutions(ctx, host, results, ts, deferred)
return t.datastore.RecordLabelQueryExecutions(ctx, host, results, ts, deferred)
}
keySet := fmt.Sprintf(labelMembershipHostKey, host.ID)
@ -37,7 +39,7 @@ func (t *Task) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host,
// the collector will have plenty of time to run (multiple times) to try to
// persist all the data in mysql.
ttl := labelMembershipKeysMinTTL
if maxTTL := 10 * t.CollectorInterval; maxTTL > ttl {
if maxTTL := 10 * cfg.CollectInterval; maxTTL > ttl {
ttl = maxTTL
}
@ -65,9 +67,9 @@ func (t *Task) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host,
args = args.Add(score, k)
}
conn := t.Pool.Get()
conn := t.pool.Get()
defer conn.Close()
if err := redis.BindConn(t.Pool, conn, keySet, keyTs); err != nil {
if err := redis.BindConn(t.pool, conn, keySet, keyTs); err != nil {
return ctxerr.Wrap(ctx, err, "bind redis connection")
}
@ -79,14 +81,16 @@ func (t *Task) RecordLabelQueryExecutions(ctx context.Context, host *fleet.Host,
// outside of the redis script because in Redis Cluster mode the key may not
// live on the same node as the host's keys. At the same time, purge any
// entry in the set that is older than now - TTL.
if _, err := storePurgeActiveHostID(t.Pool, labelMembershipActiveHostIDsKey, host.ID, ts, ts.Add(-ttl)); err != nil {
if _, err := storePurgeActiveHostID(t.pool, labelMembershipActiveHostIDsKey, host.ID, ts, ts.Add(-ttl)); err != nil {
return ctxerr.Wrap(ctx, err, "store active host id")
}
return nil
}
func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, stats *collectorExecStats) error {
hosts, err := loadActiveHostIDs(pool, labelMembershipActiveHostIDsKey, t.RedisScanKeysCount)
cfg := t.taskConfigs[config.AsyncTaskLabelMembership]
hosts, err := loadActiveHostIDs(pool, labelMembershipActiveHostIDsKey, cfg.RedisScanKeysCount)
if err != nil {
return ctxerr.Wrap(ctx, err, "load active host ids")
}
@ -100,7 +104,7 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
for {
stats.RedisCmds++
vals, err := redigo.Ints(conn.Do("ZPOPMIN", keySet, t.RedisPopCount))
vals, err := redigo.Ints(conn.Do("ZPOPMIN", keySet, cfg.RedisPopCount))
if err != nil {
return nil, nil, ctxerr.Wrap(ctx, err, "redis ZPOPMIN")
}
@ -122,7 +126,7 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
deletes = append(deletes, [2]uint{uint(labelID), hostID})
}
}
if items < t.RedisPopCount {
if items < cfg.RedisPopCount {
return inserts, deletes, nil
}
}
@ -157,8 +161,8 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
return ds.AsyncBatchUpdateLabelTimestamp(ctx, ids, ts)
}
insertBatch := make([][2]uint, 0, t.InsertBatch)
deleteBatch := make([][2]uint, 0, t.DeleteBatch)
insertBatch := make([][2]uint, 0, cfg.InsertBatch)
deleteBatch := make([][2]uint, 0, cfg.DeleteBatch)
for _, host := range hosts {
hid := host.HostID
ins, del, err := getKeyTuples(hid)
@ -168,13 +172,13 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
insertBatch = append(insertBatch, ins...)
deleteBatch = append(deleteBatch, del...)
if len(insertBatch) >= t.InsertBatch {
if len(insertBatch) >= cfg.InsertBatch {
if err := runInsertBatch(insertBatch); err != nil {
return err
}
insertBatch = insertBatch[:0]
}
if len(deleteBatch) >= t.DeleteBatch {
if len(deleteBatch) >= cfg.DeleteBatch {
if err := runDeleteBatch(deleteBatch); err != nil {
return err
}
@ -200,8 +204,8 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
hostIDs[i] = host.HostID
}
ts := t.Clock.Now()
updateBatch := make([]uint, t.UpdateBatch)
ts := t.clock.Now()
updateBatch := make([]uint, cfg.UpdateBatch)
for {
n := copy(updateBatch, hostIDs)
if n == 0 {
@ -226,8 +230,10 @@ func (t *Task) collectLabelQueryExecutions(ctx context.Context, ds fleet.Datasto
}
func (t *Task) GetHostLabelReportedAt(ctx context.Context, host *fleet.Host) time.Time {
if t.AsyncEnabled {
conn := redis.ConfigureDoer(t.Pool, t.Pool.Get())
cfg := t.taskConfigs[config.AsyncTaskLabelMembership]
if cfg.Enabled {
conn := redis.ConfigureDoer(t.pool, t.pool.Get())
defer conn.Close()
key := fmt.Sprintf(labelMembershipReportedKey, host.ID)

View File

@ -7,6 +7,7 @@ import (
"time"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -204,14 +205,13 @@ func testCollectLabelQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fle
// run the collection
var stats collectorExecStats
task := Task{
Clock: clock.C,
InsertBatch: batchSizes,
UpdateBatch: batchSizes,
DeleteBatch: batchSizes,
RedisPopCount: batchSizes,
RedisScanKeysCount: 10,
}
task := NewTask(nil, nil, clock.C, config.OsqueryConfig{
AsyncHostInsertBatch: batchSizes,
AsyncHostUpdateBatch: batchSizes,
AsyncHostDeleteBatch: batchSizes,
AsyncHostRedisPopCount: batchSizes,
AsyncHostRedisScanKeysCount: 10,
})
err := task.collectLabelQueryExecutions(ctx, ds, pool, &stats)
require.NoError(t, err)
// inserts, updates and deletes are a bit tricky to track automatically,
@ -252,14 +252,13 @@ func testCollectLabelQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fle
// update host 1, label 1, already existing
setupTest(t, map[int]map[int]bool{1: {1: true}})
var stats collectorExecStats
task := Task{
Clock: clock.C,
InsertBatch: batchSizes,
UpdateBatch: batchSizes,
DeleteBatch: batchSizes,
RedisPopCount: batchSizes,
RedisScanKeysCount: 10,
}
task := NewTask(nil, nil, clock.C, config.OsqueryConfig{
AsyncHostInsertBatch: batchSizes,
AsyncHostUpdateBatch: batchSizes,
AsyncHostDeleteBatch: batchSizes,
AsyncHostRedisPopCount: batchSizes,
AsyncHostRedisScanKeysCount: 10,
})
err := task.collectLabelQueryExecutions(ctx, ds, pool, &stats)
require.NoError(t, err)
@ -288,12 +287,7 @@ func testRecordLabelQueryExecutionsSync(t *testing.T, ds *mock.Store, pool fleet
results := map[uint]*bool{1: &yes, 2: &yes, 3: &no, 4: nil}
keySet, keyTs := fmt.Sprintf(labelMembershipHostKey, host.ID), fmt.Sprintf(labelMembershipReportedKey, host.ID)
task := Task{
Datastore: ds,
Pool: pool,
Clock: clock.C,
AsyncEnabled: false,
}
task := NewTask(ds, pool, clock.C, config.OsqueryConfig{})
labelReportedAt := task.GetHostLabelReportedAt(ctx, host)
require.True(t, labelReportedAt.Equal(lastYear))
@ -336,18 +330,14 @@ func testRecordLabelQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool flee
results := map[uint]*bool{1: &yes, 2: &yes, 3: &no, 4: nil}
keySet, keyTs := fmt.Sprintf(labelMembershipHostKey, host.ID), fmt.Sprintf(labelMembershipReportedKey, host.ID)
task := Task{
Datastore: ds,
Pool: pool,
Clock: clock.C,
AsyncEnabled: true,
InsertBatch: 3,
UpdateBatch: 3,
DeleteBatch: 3,
RedisPopCount: 3,
RedisScanKeysCount: 10,
}
task := NewTask(ds, pool, clock.C, config.OsqueryConfig{
EnableAsyncHostProcessing: "true",
AsyncHostInsertBatch: 3,
AsyncHostUpdateBatch: 3,
AsyncHostDeleteBatch: 3,
AsyncHostRedisPopCount: 3,
AsyncHostRedisScanKeysCount: 10,
})
labelReportedAt := task.GetHostLabelReportedAt(ctx, host)
require.True(t, labelReportedAt.Equal(lastYear))

View File

@ -7,6 +7,7 @@ import (
"strings"
"time"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -27,9 +28,10 @@ var (
)
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
if !t.AsyncEnabled {
cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
if !cfg.Enabled {
host.PolicyUpdatedAt = ts
return t.Datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred)
return t.datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred)
}
keyList := fmt.Sprintf(policyPassHostKey, host.ID)
@ -45,7 +47,7 @@ func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host
// the collector will have plenty of time to run (multiple times) to try to
// persist all the data in mysql.
ttl := policyPassKeysMinTTL
if maxTTL := 10 * t.CollectorInterval; maxTTL > ttl {
if maxTTL := 10 * cfg.CollectInterval; maxTTL > ttl {
ttl = maxTTL
}
@ -79,9 +81,9 @@ func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host
args = args.Add(fmt.Sprintf("%d=%d", k, pass))
}
conn := t.Pool.Get()
conn := t.pool.Get()
defer conn.Close()
if err := redis.BindConn(t.Pool, conn, keyList, keyTs); err != nil {
if err := redis.BindConn(t.pool, conn, keyList, keyTs); err != nil {
return ctxerr.Wrap(ctx, err, "bind redis connection")
}
@ -93,14 +95,16 @@ func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host
// outside of the redis script because in Redis Cluster mode the key may not
// live on the same node as the host's keys. At the same time, purge any
// entry in the set that is older than now - TTL.
if _, err := storePurgeActiveHostID(t.Pool, policyPassHostIDsKey, host.ID, ts, ts.Add(-ttl)); err != nil {
if _, err := storePurgeActiveHostID(t.pool, policyPassHostIDsKey, host.ID, ts, ts.Add(-ttl)); err != nil {
return ctxerr.Wrap(ctx, err, "store active host id")
}
return nil
}
func (t *Task) collectPolicyQueryExecutions(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, stats *collectorExecStats) error {
hosts, err := loadActiveHostIDs(pool, policyPassHostIDsKey, t.RedisScanKeysCount)
cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
hosts, err := loadActiveHostIDs(pool, policyPassHostIDsKey, cfg.RedisScanKeysCount)
if err != nil {
return ctxerr.Wrap(ctx, err, "load active host ids")
}
@ -165,7 +169,7 @@ func (t *Task) collectPolicyQueryExecutions(ctx context.Context, ds fleet.Datast
return ds.AsyncBatchUpdatePolicyTimestamp(ctx, ids, ts)
}
insertBatch := make([]fleet.PolicyMembershipResult, 0, t.InsertBatch)
insertBatch := make([]fleet.PolicyMembershipResult, 0, cfg.InsertBatch)
for _, host := range hosts {
hid := host.HostID
ins, err := getKeyTuples(hid)
@ -174,7 +178,7 @@ func (t *Task) collectPolicyQueryExecutions(ctx context.Context, ds fleet.Datast
}
insertBatch = append(insertBatch, ins...)
if len(insertBatch) >= t.InsertBatch {
if len(insertBatch) >= cfg.InsertBatch {
if err := runInsertBatch(insertBatch); err != nil {
return err
}
@ -195,8 +199,8 @@ func (t *Task) collectPolicyQueryExecutions(ctx context.Context, ds fleet.Datast
hostIDs[i] = host.HostID
}
ts := t.Clock.Now()
updateBatch := make([]uint, t.UpdateBatch)
ts := t.clock.Now()
updateBatch := make([]uint, cfg.UpdateBatch)
for {
n := copy(updateBatch, hostIDs)
if n == 0 {
@ -221,8 +225,10 @@ func (t *Task) collectPolicyQueryExecutions(ctx context.Context, ds fleet.Datast
}
func (t *Task) GetHostPolicyReportedAt(ctx context.Context, host *fleet.Host) time.Time {
if t.AsyncEnabled {
conn := redis.ConfigureDoer(t.Pool, t.Pool.Get())
cfg := t.taskConfigs[config.AsyncTaskPolicyMembership]
if cfg.Enabled {
conn := redis.ConfigureDoer(t.pool, t.pool.Get())
defer conn.Close()
key := fmt.Sprintf(policyPassReportedKey, host.ID)

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -227,14 +228,13 @@ func testCollectPolicyQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fl
// run the collection
var stats collectorExecStats
task := Task{
Clock: clock.C,
InsertBatch: batchSizes,
UpdateBatch: batchSizes,
DeleteBatch: batchSizes,
RedisPopCount: batchSizes,
RedisScanKeysCount: 10,
}
task := NewTask(nil, nil, clock.C, config.OsqueryConfig{
AsyncHostInsertBatch: batchSizes,
AsyncHostUpdateBatch: batchSizes,
AsyncHostDeleteBatch: batchSizes,
AsyncHostRedisPopCount: batchSizes,
AsyncHostRedisScanKeysCount: 10,
})
err := task.collectPolicyQueryExecutions(ctx, ds, pool, &stats)
require.NoError(t, err)
// inserts, updates and deletes are a bit tricky to track automatically,
@ -276,14 +276,13 @@ func testCollectPolicyQueryExecutions(t *testing.T, ds *mysql.Datastore, pool fl
// update host 1, policy 1, already existing
setupTest(t, map[int]map[int]*bool{1: {1: nil}})
var stats collectorExecStats
task := Task{
Clock: clock.C,
InsertBatch: batchSizes,
UpdateBatch: batchSizes,
DeleteBatch: batchSizes,
RedisPopCount: batchSizes,
RedisScanKeysCount: 10,
}
task := NewTask(nil, nil, clock.C, config.OsqueryConfig{
AsyncHostInsertBatch: batchSizes,
AsyncHostUpdateBatch: batchSizes,
AsyncHostDeleteBatch: batchSizes,
AsyncHostRedisPopCount: batchSizes,
AsyncHostRedisScanKeysCount: 10,
})
err := task.collectPolicyQueryExecutions(ctx, ds, pool, &stats)
require.NoError(t, err)
@ -312,12 +311,7 @@ func testRecordPolicyQueryExecutionsSync(t *testing.T, ds *mock.Store, pool flee
results := map[uint]*bool{1: &yes, 2: &yes, 3: &no, 4: nil}
keyList, keyTs := fmt.Sprintf(policyPassHostKey, host.ID), fmt.Sprintf(policyPassReportedKey, host.ID)
task := Task{
Datastore: ds,
Pool: pool,
Clock: clock.C,
AsyncEnabled: false,
}
task := NewTask(ds, pool, clock.C, config.OsqueryConfig{})
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
require.True(t, policyReportedAt.Equal(lastYear))
@ -360,18 +354,14 @@ func testRecordPolicyQueryExecutionsAsync(t *testing.T, ds *mock.Store, pool fle
results := map[uint]*bool{1: &yes, 2: &yes, 3: &no, 4: nil}
keyList, keyTs := fmt.Sprintf(policyPassHostKey, host.ID), fmt.Sprintf(policyPassReportedKey, host.ID)
task := Task{
Datastore: ds,
Pool: pool,
Clock: clock.C,
AsyncEnabled: true,
InsertBatch: 3,
UpdateBatch: 3,
DeleteBatch: 3,
RedisPopCount: 3,
RedisScanKeysCount: 10,
}
task := NewTask(ds, pool, clock.C, config.OsqueryConfig{
EnableAsyncHostProcessing: "true",
AsyncHostInsertBatch: 3,
AsyncHostUpdateBatch: 3,
AsyncHostDeleteBatch: 3,
AsyncHostRedisPopCount: 3,
AsyncHostRedisScanKeysCount: 10,
})
policyReportedAt := task.GetHostPolicyReportedAt(ctx, host)
require.True(t, policyReportedAt.Equal(lastYear))

View File

@ -14,6 +14,8 @@ import (
redigo "github.com/gomodule/redigo/redis"
)
type collectorHandlerFunc func(context.Context, fleet.Datastore, fleet.RedisPool, *collectorExecStats) error
type collector struct {
// immutable after creation
name string
@ -22,7 +24,7 @@ type collector struct {
execInterval time.Duration
jitterPct int
lockTimeout time.Duration
handler func(context.Context, fleet.Datastore, fleet.RedisPool, *collectorExecStats) error
handler collectorHandlerFunc
errHandler func(string, error)
// mutable, must be protected by mutex

View File

@ -279,7 +279,7 @@ func TestEnrollAgentDetails(t *testing.T) {
func TestAuthenticateHost(t *testing.T) {
ds := new(mock.Store)
task := &async.Task{Datastore: ds, AsyncEnabled: false}
task := async.NewTask(ds, nil, clock.C, config.OsqueryConfig{})
svc := newTestService(t, ds, nil, nil, &TestServerOpts{Task: task})
var gotKey string

View File

@ -49,10 +49,19 @@ func newTestServiceWithConfig(t *testing.T, ds fleet.Datastore, fleetConfig conf
var failingPolicySet fleet.FailingPolicySet = NewMemFailingPolicySet()
var c clock.Clock = clock.C
if len(opts) > 0 {
if opts[0].Clock != nil {
c = opts[0].Clock
}
}
task := &async.Task{
Datastore: ds,
AsyncEnabled: false,
task := async.NewTask(ds, nil, c, config.OsqueryConfig{})
if len(opts) > 0 {
if opts[0].Task != nil {
task = opts[0].Task
} else {
opts[0].Task = task
}
}
if len(opts) > 0 {
@ -68,17 +77,8 @@ func newTestServiceWithConfig(t *testing.T, ds fleet.Datastore, fleetConfig conf
if opts[0].FailingPolicySet != nil {
failingPolicySet = opts[0].FailingPolicySet
}
if opts[0].Clock != nil {
c = opts[0].Clock
}
if opts[0].Task != nil {
task = opts[0].Task
} else {
opts[0].Task = task
}
}
task.Clock = c
svc, err := NewService(context.Background(), ds, task, rs, logger, osqlogger, fleetConfig, mailer, c, ssoStore, lq, ds, *license, failingPolicySet, &fleet.NoOpGeoIP{})
if err != nil {
panic(err)