Address multiple redis-related issues observed with live queries (#16855)

#16331 

Doc updates in a separate PR:
https://github.com/fleetdm/fleet/pull/17214

# Checklist for submitter

- [x] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [x] Added/updated tests
- [x] Manual QA for all new/changed functionality (smoke-tested locally
with osquery-perf simulating 100 hosts: ran a live query and a saved live
query, let them stop naturally and stopped them before the end, and
repeated the same via fleetctl)

---------

Co-authored-by: Victor Lyuboslavsky <victor@fleetdm.com>
Co-authored-by: Victor Lyuboslavsky <victor.lyuboslavsky@gmail.com>
Martin Angers 2024-02-27 20:35:27 -05:00 committed by GitHub
parent d50ee081f2
commit 6c0e56ea73
21 changed files with 531 additions and 58 deletions

View File

@ -0,0 +1,3 @@
* Added logging when a Redis connection is blocked for a long time waiting for live query results.
* Added support for the `redis.conn_wait_timeout` configuration setting for standalone Redis (previously it was only supported on Redis Cluster; see the sketch below this list).
* Added a cron job that cleans up inactive queries in Redis, so that temporary Redis failures to stop a live query don't leave such queries around for a long time.
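Below is a minimal, illustrative sketch (not code from this PR; the address and pool sizes are made up) of how `conn_wait_timeout` changes `Get()` behavior on a standalone pool, using the `PoolConfig` fields and `NewPool` constructor touched in this PR:

```go
package main

import (
	"time"

	"github.com/fleetdm/fleet/v4/server/datastore/redis"
)

func main() {
	// Illustrative values only: with MaxOpenConns exhausted, Get() now blocks
	// up to ConnWaitTimeout for a free connection on standalone Redis too,
	// instead of never waiting.
	pool, err := redis.NewPool(redis.PoolConfig{
		Server:          "127.0.0.1:6379",
		MaxOpenConns:    2,
		ConnWaitTimeout: 2 * time.Second,
	})
	if err != nil {
		panic(err)
	}

	conn := pool.Get() // may wait up to 2s if both connections are in use
	defer conn.Close()

	// A timed-out Get returns an "errorConn" that fails on first use.
	if _, err := conn.Do("PING"); err != nil {
		panic(err)
	}
}
```

Previously the standalone pool forced `pool.Wait = false` because the wait could block indefinitely; the wait is now bounded by the configured timeout via `GetContext`, as shown in the `standalonePool.Get` change further down.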

View File

@ -705,6 +705,7 @@ func newCleanupsAndAggregationSchedule(
ctx context.Context,
instanceID string,
ds fleet.Datastore,
lq fleet.LiveQueryStore,
logger kitlog.Logger,
enrollHostLimiter fleet.EnrollHostLimiter,
config *config.FleetConfig,
@ -720,13 +721,6 @@ func newCleanupsAndAggregationSchedule(
schedule.WithAltLockID("leader"),
schedule.WithLogger(kitlog.With(logger, "cron", name)),
// Run cleanup jobs first.
schedule.WithJob(
"distributed_query_campaigns",
func(ctx context.Context) error {
_, err := ds.CleanupDistributedQueryCampaigns(ctx, time.Now())
return err
},
),
schedule.WithJob(
"incoming_hosts",
func(ctx context.Context) error {
@ -838,6 +832,50 @@ func newCleanupsAndAggregationSchedule(
return s, nil
}
func newFrequentCleanupsSchedule(
ctx context.Context,
instanceID string,
ds fleet.Datastore,
lq fleet.LiveQueryStore,
logger kitlog.Logger,
) (*schedule.Schedule, error) {
const (
name = string(fleet.CronFrequentCleanups)
defaultInterval = 15 * time.Minute
)
s := schedule.New(
ctx, name, instanceID, defaultInterval, ds, ds,
// Use the "leader" lock ID for backwards compatibility with old deployments.
schedule.WithAltLockID("leader"),
schedule.WithLogger(kitlog.With(logger, "cron", name)),
// Run cleanup jobs first.
schedule.WithJob(
"distributed_query_campaigns",
func(ctx context.Context) error {
_, err := ds.CleanupDistributedQueryCampaigns(ctx, time.Now().UTC())
if err != nil {
return err
}
names, err := lq.LoadActiveQueryNames()
if err != nil {
return err
}
ids := stringSliceToUintSlice(names, logger)
completed, err := ds.GetCompletedCampaigns(ctx, ids)
if err != nil {
return err
}
if err := lq.CleanupInactiveQueries(ctx, completed); err != nil {
return err
}
return nil
},
),
)
return s, nil
}
func verifyDiskEncryptionKeys(
ctx context.Context,
logger kitlog.Logger,
@ -1094,3 +1132,16 @@ func cronActivitiesStreaming(
page += 1
}
}
func stringSliceToUintSlice(s []string, logger kitlog.Logger) []uint {
result := make([]uint, 0, len(s))
for _, v := range s {
i, err := strconv.ParseUint(v, 10, 64)
if err != nil {
level.Warn(logger).Log("msg", "failed to parse string to uint", "string", v, "err", err)
continue
}
result = append(result, uint(i))
}
return result
}

View File

@ -681,12 +681,22 @@ the way that the Fleet server works.
}()
if err := cronSchedules.StartCronSchedule(func() (fleet.CronSchedule, error) {
var commander *apple_mdm.MDMAppleCommander
if appCfg.MDM.EnabledAndConfigured {
commander = apple_mdm.NewMDMAppleCommander(mdmStorage, mdmPushService)
}
return newCleanupsAndAggregationSchedule(ctx, instanceID, ds, logger, redisWrapperDS, &config, commander)
return newFrequentCleanupsSchedule(ctx, instanceID, ds, liveQueryStore, logger)
}); err != nil {
initFatal(err, "failed to register frequent_cleanups schedule")
}
if err := cronSchedules.StartCronSchedule(
func() (fleet.CronSchedule, error) {
var commander *apple_mdm.MDMAppleCommander
if appCfg.MDM.EnabledAndConfigured {
commander = apple_mdm.NewMDMAppleCommander(mdmStorage, mdmPushService)
}
return newCleanupsAndAggregationSchedule(
ctx, instanceID, ds, liveQueryStore, logger, redisWrapperDS, &config, commander,
)
},
); err != nil {
initFatal(err, "failed to register cleanups_then_aggregations schedule")
}

View File

@ -818,7 +818,7 @@ func (man Manager) addConfigs() {
man.addConfigInt("redis.max_open_conns", 0, "Redis maximum open connections, 0 means no limit")
man.addConfigDuration("redis.conn_max_lifetime", 0, "Redis maximum amount of time a connection may be reused, 0 means no limit")
man.addConfigDuration("redis.idle_timeout", 240*time.Second, "Redis maximum amount of time a connection may stay idle, 0 means no limit")
man.addConfigDuration("redis.conn_wait_timeout", 0, "Redis maximum amount of time to wait for a connection if the maximum is reached (0 for no wait, ignored in non-cluster Redis)")
man.addConfigDuration("redis.conn_wait_timeout", 0, "Redis maximum amount of time to wait for a connection if the maximum is reached (0 for no wait)")
man.addConfigDuration("redis.write_timeout", 10*time.Second, "Redis maximum amount of time to wait for a write (send) on a connection")
man.addConfigDuration("redis.read_timeout", 10*time.Second, "Redis maximum amount of time to wait for a read (receive) on a connection")

View File

@ -2,24 +2,37 @@ package mysql
import (
"context"
"fmt"
"time"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/jmoiron/sqlx"
)
func (ds *Datastore) NewDistributedQueryCampaign(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) {
args := []any{camp.QueryID, camp.Status, camp.UserID}
sqlStatement := `
// for tests, we sometimes provide specific timestamps for CreatedAt, honor
// those if provided.
var createdAtField, createdAtPlaceholder string
if !camp.CreatedAt.IsZero() {
createdAtField = ", created_at"
createdAtPlaceholder = ", ?"
args = append(args, camp.CreatedAt)
}
sqlStatement := fmt.Sprintf(`
INSERT INTO distributed_query_campaigns (
query_id,
status,
user_id
%s
)
VALUES(?,?,?)
`
result, err := ds.writer(ctx).ExecContext(ctx, sqlStatement, camp.QueryID, camp.Status, camp.UserID)
VALUES(?,?,?%s)
`, createdAtField, createdAtPlaceholder)
result, err := ds.writer(ctx).ExecContext(ctx, sqlStatement, args...)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "inserting distributed query campaign")
}
@ -121,9 +134,50 @@ func (ds *Datastore) NewDistributedQueryCampaignTarget(ctx context.Context, targ
return target, nil
}
func (ds *Datastore) GetCompletedCampaigns(ctx context.Context, filter []uint) ([]uint, error) {
// There is a limit of 65,535 (2^16-1) placeholders in MySQL 5.7
const batchSize = 65535 - 1
if len(filter) == 0 {
return nil, nil
}
// We must remove duplicates from the input filter because we process the filter in batches,
// and that could result in duplicated result IDs
filter = server.RemoveDuplicatesFromSlice(filter)
completed := make([]uint, 0, len(filter))
for i := 0; i < len(filter); i += batchSize {
end := i + batchSize
if end > len(filter) {
end = len(filter)
}
batch := filter[i:end]
query, args, err := sqlx.In(
`SELECT id
FROM distributed_query_campaigns
WHERE status = ?
AND id IN (?)
`, fleet.QueryComplete, batch,
)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "building query for completed campaigns")
}
var rows []uint
// using the writer, so we catch the ones we just marked as completed
err = sqlx.SelectContext(ctx, ds.writer(ctx), &rows, query, args...)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "selecting completed campaigns")
}
completed = append(completed, rows...)
}
return completed, nil
}
func (ds *Datastore) CleanupDistributedQueryCampaigns(ctx context.Context, now time.Time) (expired uint, err error) {
// Expire old waiting/running campaigns
sqlStatement := `
const sqlStatement = `
UPDATE distributed_query_campaigns
SET status = ?
WHERE (status = ? AND created_at < ?)
@ -140,6 +194,5 @@ func (ds *Datastore) CleanupDistributedQueryCampaigns(ctx context.Context, now t
if err != nil {
return 0, ctxerr.Wrap(ctx, err, "rows affected updating distributed query campaign")
}
return uint(exp), nil
}

View File

@ -2,6 +2,8 @@ package mysql
import (
"context"
"math/rand"
"sort"
"testing"
"time"
@ -22,6 +24,7 @@ func TestCampaigns(t *testing.T) {
{"DistributedQuery", testCampaignsDistributedQuery},
{"CleanupDistributedQuery", testCampaignsCleanupDistributedQuery},
{"SaveDistributedQuery", testCampaignsSaveDistributedQuery},
{"CompletedCampaigns", testCompletedCampaigns},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@ -80,10 +83,9 @@ func testCampaignsDistributedQuery(t *testing.T, ds *Datastore) {
}
func testCampaignsCleanupDistributedQuery(t *testing.T, ds *Datastore) {
ctx := context.Background()
user := test.NewUser(t, ds, "Zach", "zwass@fleet.co", true)
mockClock := clock.NewMockClock()
query := test.NewQuery(t, ds, nil, "test", "select * from time", user.ID, false)
c1 := test.NewCampaign(t, ds, query.ID, fleet.QueryWaiting, mockClock.Now())
@ -91,18 +93,18 @@ func testCampaignsCleanupDistributedQuery(t *testing.T, ds *Datastore) {
// Cleanup and verify that nothing changed (because time has not
// advanced)
expired, err := ds.CleanupDistributedQueryCampaigns(context.Background(), mockClock.Now())
expired, err := ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now())
require.Nil(t, err)
assert.Equal(t, uint(0), expired)
{
retrieved, err := ds.DistributedQueryCampaign(context.Background(), c1.ID)
retrieved, err := ds.DistributedQueryCampaign(ctx, c1.ID)
require.Nil(t, err)
assert.Equal(t, c1.QueryID, retrieved.QueryID)
assert.Equal(t, c1.Status, retrieved.Status)
}
{
retrieved, err := ds.DistributedQueryCampaign(context.Background(), c2.ID)
retrieved, err := ds.DistributedQueryCampaign(ctx, c2.ID)
require.Nil(t, err)
assert.Equal(t, c2.QueryID, retrieved.QueryID)
assert.Equal(t, c2.Status, retrieved.Status)
@ -114,18 +116,18 @@ func testCampaignsCleanupDistributedQuery(t *testing.T, ds *Datastore) {
// Cleanup and verify that the campaign was expired and executions
// deleted appropriately
expired, err = ds.CleanupDistributedQueryCampaigns(context.Background(), mockClock.Now())
expired, err = ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now())
require.Nil(t, err)
assert.Equal(t, uint(1), expired)
{
// c1 should now be complete
retrieved, err := ds.DistributedQueryCampaign(context.Background(), c1.ID)
retrieved, err := ds.DistributedQueryCampaign(ctx, c1.ID)
require.Nil(t, err)
assert.Equal(t, c1.QueryID, retrieved.QueryID)
assert.Equal(t, fleet.QueryComplete, retrieved.Status)
}
{
retrieved, err := ds.DistributedQueryCampaign(context.Background(), c2.ID)
retrieved, err := ds.DistributedQueryCampaign(ctx, c2.ID)
require.Nil(t, err)
assert.Equal(t, c2.QueryID, retrieved.QueryID)
assert.Equal(t, c2.Status, retrieved.Status)
@ -135,22 +137,51 @@ func testCampaignsCleanupDistributedQuery(t *testing.T, ds *Datastore) {
// Cleanup and verify that the campaign was expired and executions
// deleted appropriately
expired, err = ds.CleanupDistributedQueryCampaigns(context.Background(), mockClock.Now())
expired, err = ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now())
require.Nil(t, err)
assert.Equal(t, uint(1), expired)
{
retrieved, err := ds.DistributedQueryCampaign(context.Background(), c1.ID)
retrieved, err := ds.DistributedQueryCampaign(ctx, c1.ID)
require.Nil(t, err)
assert.Equal(t, c1.QueryID, retrieved.QueryID)
assert.Equal(t, fleet.QueryComplete, retrieved.Status)
}
{
// c2 should now be complete
retrieved, err := ds.DistributedQueryCampaign(context.Background(), c2.ID)
retrieved, err := ds.DistributedQueryCampaign(ctx, c2.ID)
require.Nil(t, err)
assert.Equal(t, c2.QueryID, retrieved.QueryID)
assert.Equal(t, fleet.QueryComplete, retrieved.Status)
}
// simulate another old campaign created > 7 days ago
c3 := test.NewCampaign(t, ds, query.ID, fleet.QueryWaiting, mockClock.Now().AddDate(0, 0, -8))
{
retrieved, err := ds.DistributedQueryCampaign(ctx, c3.ID)
require.Nil(t, err)
assert.Equal(t, c3.QueryID, retrieved.QueryID)
assert.Equal(t, fleet.QueryWaiting, retrieved.Status)
}
// cleanup will mark c3 as completed because it was waiting for > 1 minute,
// but it won't return it as recently inactive because it's too old a query.
expired, err = ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now())
require.Nil(t, err)
assert.Equal(t, uint(1), expired)
// cleanup again does not expire any new campaign and still returns the same
// recently inactive campaigns
expired, err = ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now())
require.Nil(t, err)
assert.Equal(t, uint(0), expired)
// move time forward 7 days and cleanup again, this time it returns no recent
// inactive campaigns
mockClock.AddTime(7*24*time.Hour + 1*time.Second)
expired, err = ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now())
require.Nil(t, err)
assert.Equal(t, uint(0), expired)
}
func testCampaignsSaveDistributedQuery(t *testing.T, ds *Datastore) {
@ -180,3 +211,46 @@ func checkTargets(t *testing.T, ds fleet.Datastore, campaignID uint, expectedTar
assert.ElementsMatch(t, expectedTargets.LabelIDs, targets.LabelIDs)
assert.ElementsMatch(t, expectedTargets.TeamIDs, targets.TeamIDs)
}
func testCompletedCampaigns(t *testing.T, ds *Datastore) {
// Test nil result
result, err := ds.GetCompletedCampaigns(context.Background(), nil)
assert.NoError(t, err)
assert.Len(t, result, 0)
result, err = ds.GetCompletedCampaigns(context.Background(), []uint{234, 1, 1, 34455455453})
assert.NoError(t, err)
assert.Len(t, result, 0)
// Now test reasonable results
user := test.NewUser(t, ds, t.Name(), t.Name()+"zwass@fleet.co", true)
mockClock := clock.NewMockClock()
query := test.NewQuery(t, ds, nil, t.Name()+"test", "select * from time", user.ID, false)
numCampaigns := 5
totalFilterSize := 100000
filter := make([]uint, 0, totalFilterSize)
complete := make([]uint, 0, numCampaigns)
for i := 0; i < numCampaigns; i++ {
c1 := test.NewCampaign(t, ds, query.ID, fleet.QueryWaiting, mockClock.Now())
gotC, err := ds.DistributedQueryCampaign(context.Background(), c1.ID)
require.NoError(t, err)
require.Equal(t, fleet.QueryWaiting, gotC.Status)
if rand.Intn(10) < 7 { //nolint:gosec
c1.Status = fleet.QueryComplete
require.NoError(t, ds.SaveDistributedQueryCampaign(context.Background(), c1))
complete = append(complete, c1.ID)
}
filter = append(filter, c1.ID)
}
for j := filter[len(filter)-1] / 2; j < uint(totalFilterSize); j++ { // some IDs are duplicated
filter = append(filter, j)
}
rand.Shuffle(len(filter), func(i, j int) { filter[i], filter[j] = filter[j], filter[i] })
result, err = ds.GetCompletedCampaigns(context.Background(), filter)
assert.NoError(t, err)
sort.Slice(result, func(i, j int) bool { return result[i] < result[j] })
assert.Equal(t, complete, result)
}

View File

@ -1,6 +1,7 @@
package redis
import (
"context"
"errors"
"fmt"
"net"
@ -18,7 +19,24 @@ import (
// redisc.Cluster, so both can satisfy the same interface.
type standalonePool struct {
*redis.Pool
addr string
addr string
connWaitTimeout time.Duration
}
func (p *standalonePool) Get() redis.Conn {
if p.connWaitTimeout <= 0 {
return p.Pool.Get()
}
ctx, cancel := context.WithTimeout(context.Background(), p.connWaitTimeout)
defer cancel()
// GetContext always returns an "errorConn" as a valid connection when there
// is an error, so there's no need to check the second return value (as in the
// no-wait case, the errorConn fails on first use with the actual error).
conn, _ := p.Pool.GetContext(ctx)
return conn
}
func (p *standalonePool) Stats() map[string]redis.PoolStats {
@ -83,10 +101,7 @@ func NewPool(config PoolConfig) (fleet.RedisPool, error) {
// not a Redis Cluster setup, use a standalone Redis pool
pool, _ := cluster.CreatePool(config.Server, cluster.DialOptions...)
cluster.Close()
// never wait for a connection in a non-cluster pool as it can block
// indefinitely.
pool.Wait = false
return &standalonePool{pool, config.Server}, nil
return &standalonePool{pool, config.Server, config.ConnWaitTimeout}, nil
}
return nil, fmt.Errorf("refresh cluster: %w", err)
}

View File

@ -364,3 +364,74 @@ func TestReadTimeout(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "i/o timeout")
}
func TestRedisConnWaitTime(t *testing.T) {
t.Run("standalone", func(t *testing.T) {
pool := redistest.SetupRedisWithConfig(t, "zz", false, false, false, redis.PoolConfig{ConnWaitTimeout: 2 * time.Second, MaxOpenConns: 2})
conn1 := pool.Get()
defer conn1.Close()
conn2 := pool.Get()
defer conn2.Close()
// no more connections available, requesting another will wait and fail after 2s
start := time.Now()
conn3 := pool.Get()
_, err := conn3.Do("PING")
conn3.Close()
require.Error(t, err)
require.GreaterOrEqual(t, time.Since(start), 2*time.Second)
// request another one but this time close an open connection after a second
go func() {
time.Sleep(time.Second)
conn1.Close()
}()
start = time.Now()
conn4 := pool.Get()
_, err = conn4.Do("PING")
conn4.Close()
require.NoError(t, err)
require.Less(t, time.Since(start), 2*time.Second)
})
t.Run("cluster", func(t *testing.T) {
pool := redistest.SetupRedisWithConfig(t, "wait-timeout-", true, false, false, redis.PoolConfig{ConnWaitTimeout: 2 * time.Second, MaxOpenConns: 2})
conn1 := pool.Get()
defer conn1.Close()
conn2 := pool.Get()
defer conn2.Close()
// bind both connections to the same node by requesting the same key; both
// connections now come from that node's pool.
_, err := conn1.Do("GET", "wait-timeout-a")
require.NoError(t, err)
_, err = conn2.Do("GET", "wait-timeout-a")
require.NoError(t, err)
// no more connections available, requesting another will wait and fail after 2s
start := time.Now()
conn3 := pool.Get()
defer conn3.Close()
_, err = conn3.Do("GET", "wait-timeout-a")
conn3.Close()
require.Error(t, err)
require.GreaterOrEqual(t, time.Since(start), 2*time.Second)
// request another one but this time close an open connection after a second
go func() {
time.Sleep(time.Second)
conn1.Close()
}()
start = time.Now()
conn4 := pool.Get()
defer conn4.Close()
_, err = conn4.Do("GET", "wait-timeout-a")
conn4.Close()
require.NoError(t, err)
require.Less(t, time.Since(start), 2*time.Second)
})
}

View File

@ -36,6 +36,10 @@ func NopRedis() fleet.RedisPool {
}
func SetupRedis(tb testing.TB, cleanupKeyPrefix string, cluster, redir, readReplica bool) fleet.RedisPool {
return SetupRedisWithConfig(tb, cleanupKeyPrefix, cluster, redir, readReplica, redis.PoolConfig{})
}
func SetupRedisWithConfig(tb testing.TB, cleanupKeyPrefix string, cluster, redir, readReplica bool, config redis.PoolConfig) fleet.RedisPool {
if _, ok := os.LookupEnv("REDIS_TEST"); !ok {
tb.Skip("set REDIS_TEST environment variable to run redis-based tests")
}
@ -59,17 +63,22 @@ func SetupRedis(tb testing.TB, cleanupKeyPrefix string, cluster, redir, readRepl
}
addr += port
pool, err := redis.NewPool(redis.PoolConfig{
Server: addr,
Username: username,
Password: password,
Database: database,
UseTLS: useTLS,
ConnTimeout: 5 * time.Second,
KeepAlive: 10 * time.Second,
ClusterFollowRedirections: redir,
ClusterReadFromReplica: readReplica,
})
// set the mandatory, non-configurable configs for tests
config.Server = addr
config.Username = username
config.Password = password
config.Database = database
config.UseTLS = useTLS
config.ClusterFollowRedirections = redir
config.ClusterReadFromReplica = readReplica
if config.ConnTimeout == 0 {
config.ConnTimeout = 5 * time.Second
}
if config.KeepAlive == 0 {
config.KeepAlive = 10 * time.Second
}
pool, err := redis.NewPool(config)
require.NoError(tb, err)
conn := pool.Get()

View File

@ -14,6 +14,7 @@ type CronScheduleName string
const (
CronAppleMDMDEPProfileAssigner CronScheduleName = "apple_mdm_dep_profile_assigner"
CronCleanupsThenAggregation CronScheduleName = "cleanups_then_aggregation"
CronFrequentCleanups CronScheduleName = "frequent_cleanups"
CronUsageStatistics CronScheduleName = "usage_statistics"
CronVulnerabilities CronScheduleName = "vulnerabilities"
CronAutomations CronScheduleName = "automations"

View File

@ -120,12 +120,18 @@ type Datastore interface {
// NewDistributedQueryCampaignTarget adds a new target to an existing distributed query campaign
NewDistributedQueryCampaignTarget(ctx context.Context, target *DistributedQueryCampaignTarget) (*DistributedQueryCampaignTarget, error)
// CleanupDistributedQueryCampaigns will clean and trim metadata for old distributed query campaigns. Any campaign
// in the QueryWaiting state will be moved to QueryComplete after one minute. Any campaign in the QueryRunning state
// will be moved to QueryComplete after one day. Times are from creation time. The now parameter makes this method
// easier to test. The return values indicate how many campaigns were expired and any error.
// CleanupDistributedQueryCampaigns will clean and trim metadata for old
// distributed query campaigns. Any campaign in the QueryWaiting state will
// be moved to QueryComplete after one minute. Any campaign in the
// QueryRunning state will be moved to QueryComplete after one day. Times are
// from creation time. The now parameter makes this method easier to test.
// The return values indicate how many campaigns were expired and any error.
CleanupDistributedQueryCampaigns(ctx context.Context, now time.Time) (expired uint, err error)
// GetCompletedCampaigns returns the IDs of the campaigns that are in the fleet.QueryComplete state and that are in the
// provided list of IDs. The return value is a slice of the IDs of the completed campaigns and any error.
GetCompletedCampaigns(ctx context.Context, filter []uint) ([]uint, error)
DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*DistributedQueryCampaign, error)
///////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,7 @@
package fleet
import "context"
// LiveQueryStore defines an interface for storing and retrieving the status of
// live queries in the Fleet system.
type LiveQueryStore interface {
@ -16,4 +18,10 @@ type LiveQueryStore interface {
// given host. After calling QueryCompleted, that query will no longer be
// sent to the host.
QueryCompletedByHost(name string, hostID uint) error
// CleanupInactiveQueries removes any inactive queries. It is called from a
// cron job to regularly clean up queries that failed to be stopped properly
// in Redis.
CleanupInactiveQueries(ctx context.Context, inactiveCampaignIDs []uint) error
// LoadActiveQueryNames returns the names of all active queries.
LoadActiveQueryNames() ([]string, error)
}

View File

@ -1,6 +1,7 @@
package live_query_mock
import (
"context"
"testing"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -45,3 +46,9 @@ func (m *MockLiveQuery) QueryCompletedByHost(name string, hostID uint) error {
args := m.Called(name, hostID)
return args.Error(0)
}
// CleanupInactiveQueries mocks the live query store CleanupInactiveQueries method.
func (m *MockLiveQuery) CleanupInactiveQueries(ctx context.Context, inactiveCampaignIDs []uint) error {
args := m.Called(ctx, inactiveCampaignIDs)
return args.Error(0)
}

View File

@ -1,6 +1,7 @@
package live_query
import (
"context"
"testing"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
@ -16,6 +17,7 @@ var testFunctions = [...]func(*testing.T, fleet.LiveQueryStore){
testLiveQueryStopQuery,
testLiveQueryExpiredQuery,
testLiveQueryOnlyExpired,
testLiveQueryCleanupInactive,
}
func testLiveQuery(t *testing.T, store fleet.LiveQueryStore) {
@ -135,3 +137,79 @@ func testLiveQueryOnlyExpired(t *testing.T, store fleet.LiveQueryStore) {
require.NoError(t, err)
require.Len(t, activeNames, 0)
}
func testLiveQueryCleanupInactive(t *testing.T, store fleet.LiveQueryStore) {
ctx := context.Background()
// get a raw Redis connection to make direct checks
pool := store.(*redisLiveQuery).pool
conn := redis.ConfigureDoer(pool, pool.Get())
defer conn.Close()
// run a few live queries, making them active in Redis
err := store.RunQuery("1", "SELECT 1", []uint{1, 2, 3})
require.NoError(t, err)
err = store.RunQuery("2", "SELECT 2", []uint{4})
require.NoError(t, err)
err = store.RunQuery("3", "SELECT 3", []uint{5, 6})
require.NoError(t, err)
err = store.RunQuery("4", "SELECT 4", []uint{1, 2, 5})
require.NoError(t, err)
err = store.RunQuery("5", "SELECT 5", []uint{2, 3, 7})
require.NoError(t, err)
activeNames, err := store.LoadActiveQueryNames()
require.NoError(t, err)
require.ElementsMatch(t, []string{"1", "2", "3", "4", "5"}, activeNames)
// sanity-check that the queries are properly stored
m, err := store.QueriesForHost(1)
require.NoError(t, err)
require.Equal(t, map[string]string{"1": "SELECT 1", "4": "SELECT 4"}, m)
// simulate that only campaigns 2 and 4 are still active, cleanup the rest
err = store.CleanupInactiveQueries(ctx, []uint{1, 3, 5})
require.NoError(t, err)
activeNames, err = store.LoadActiveQueryNames()
require.NoError(t, err)
require.ElementsMatch(t, []string{"2", "4"}, activeNames)
m, err = store.QueriesForHost(1)
require.NoError(t, err)
require.Equal(t, map[string]string{"4": "SELECT 4"}, m)
// explicitly mark campaign 4 as stopped
err = store.StopQuery("4")
require.NoError(t, err)
// no more queries for host 1
m, err = store.QueriesForHost(1)
require.NoError(t, err)
require.Empty(t, m)
// only campaign 2 remains, for host 4
m, err = store.QueriesForHost(4)
require.NoError(t, err)
require.Equal(t, map[string]string{"2": "SELECT 2"}, m)
// simulate that there are no inactive campaigns to cleanup
err = store.CleanupInactiveQueries(ctx, nil)
require.NoError(t, err)
activeNames, err = store.LoadActiveQueryNames()
require.NoError(t, err)
require.ElementsMatch(t, []string{"2"}, activeNames)
// simulate that all campaigns are inactive, cleanup all
err = store.CleanupInactiveQueries(ctx, []uint{1, 2, 3, 4, 5})
require.NoError(t, err)
activeNames, err = store.LoadActiveQueryNames()
require.NoError(t, err)
require.Empty(t, activeNames)
m, err = store.QueriesForHost(4)
require.NoError(t, err)
require.Empty(t, m)
}

View File

@ -45,11 +45,14 @@
package live_query
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/redis"
"github.com/fleetdm/fleet/v4/server/fleet"
redigo "github.com/gomodule/redigo/redis"
@ -100,7 +103,7 @@ func extractTargetKeyName(key string) string {
// RunQuery stores the live query information in ephemeral storage for the
// duration of the query or its TTL. Note that hostIDs *must* be sorted
// in ascending order.
// in ascending order. The name is the campaign ID as a string.
func (r *redisLiveQuery) RunQuery(name, sql string, hostIDs []uint) error {
if len(hostIDs) == 0 {
return errors.New("no hosts targeted")
@ -138,7 +141,7 @@ var cleanupExpiredQueriesModulo int64 = 10
func (r *redisLiveQuery) QueriesForHost(hostID uint) (map[string]string, error) {
// Get keys for active queries
names, err := r.loadActiveQueryNames()
names, err := r.LoadActiveQueryNames()
if err != nil {
return nil, fmt.Errorf("load active queries: %w", err)
}
@ -312,7 +315,7 @@ func (r *redisLiveQuery) removeQueryNames(names ...string) error {
return err
}
func (r *redisLiveQuery) loadActiveQueryNames() ([]string, error) {
func (r *redisLiveQuery) LoadActiveQueryNames() ([]string, error) {
conn := redis.ConfigureDoer(r.pool, r.pool.Get())
defer conn.Close()
@ -323,6 +326,62 @@ func (r *redisLiveQuery) loadActiveQueryNames() ([]string, error) {
return names, nil
}
func (r *redisLiveQuery) CleanupInactiveQueries(ctx context.Context, inactiveCampaignIDs []uint) error {
// The following logic is used to clean up inactive queries:
// * the inactive campaign IDs are removed from the livequery:active set.
//
// At this point, all inactive queries are effectively deleted - the rest is
// best-effort cleanup to reclaim Redis memory; those keys would otherwise be
// ignored and have no effect.
//
// * remove the livequery:<ID> and sql:livequery:<ID> keys for every inactive
// campaign ID.
if len(inactiveCampaignIDs) == 0 {
return nil
}
if err := r.removeInactiveQueries(ctx, inactiveCampaignIDs); err != nil {
return err
}
keysToDel := make([]string, 0, len(inactiveCampaignIDs)*2)
for _, id := range inactiveCampaignIDs {
targetKey, sqlKey := generateKeys(strconv.FormatUint(uint64(id), 10))
keysToDel = append(keysToDel, targetKey, sqlKey)
}
keysBySlot := redis.SplitKeysBySlot(r.pool, keysToDel...)
for _, keys := range keysBySlot {
if err := r.removeBatchInactiveKeys(ctx, keys); err != nil {
return err
}
}
return nil
}
func (r *redisLiveQuery) removeBatchInactiveKeys(ctx context.Context, keys []string) error {
conn := r.pool.Get()
defer conn.Close()
args := redigo.Args{}.AddFlat(keys)
if _, err := conn.Do("DEL", args...); err != nil {
return ctxerr.Wrap(ctx, err, "remove batch of inactive keys")
}
return nil
}
func (r *redisLiveQuery) removeInactiveQueries(ctx context.Context, inactiveCampaignIDs []uint) error {
conn := r.pool.Get()
defer conn.Close()
args := redigo.Args{}.Add(activeQueriesKey).AddFlat(inactiveCampaignIDs)
if _, err := conn.Do("SREM", args...); err != nil {
return ctxerr.Wrap(ctx, err, "remove inactive campaign IDs")
}
return nil
}
// mapBitfield takes the given host IDs and maps them into a bitfield compatible
// with Redis. It is expected that the input IDs are in ascending order.
func mapBitfield(hostIDs []uint) []byte {

View File

@ -98,6 +98,8 @@ type NewDistributedQueryCampaignTargetFunc func(ctx context.Context, target *fle
type CleanupDistributedQueryCampaignsFunc func(ctx context.Context, now time.Time) (expired uint, err error)
type GetCompletedCampaignsFunc func(ctx context.Context, filter []uint) ([]uint, error)
type DistributedQueryCampaignsForQueryFunc func(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error)
type ApplyPackSpecsFunc func(ctx context.Context, specs []*fleet.PackSpec) error
@ -959,6 +961,9 @@ type DataStore struct {
CleanupDistributedQueryCampaignsFunc CleanupDistributedQueryCampaignsFunc
CleanupDistributedQueryCampaignsFuncInvoked bool
GetCompletedCampaignsFunc GetCompletedCampaignsFunc
GetCompletedCampaignsFuncInvoked bool
DistributedQueryCampaignsForQueryFunc DistributedQueryCampaignsForQueryFunc
DistributedQueryCampaignsForQueryFuncInvoked bool
@ -2352,6 +2357,13 @@ func (s *DataStore) CleanupDistributedQueryCampaigns(ctx context.Context, now ti
return s.CleanupDistributedQueryCampaignsFunc(ctx, now)
}
func (s *DataStore) GetCompletedCampaigns(ctx context.Context, filter []uint) ([]uint, error) {
s.mu.Lock()
s.GetCompletedCampaignsFuncInvoked = true
s.mu.Unlock()
return s.GetCompletedCampaignsFunc(ctx, filter)
}
func (s *DataStore) DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error) {
s.mu.Lock()
s.DistributedQueryCampaignsForQueryFuncInvoked = true

View File

@ -96,7 +96,7 @@ func TestQueryResultsStore(t *testing.T) {
ctx1, cancel1 := context.WithCancel(context.Background())
channel1, err := store.ReadChannel(ctx1, campaign1)
assert.Nil(t, err)
require.NoError(t, err)
expected1 := []fleet.DistributedQueryResult{
{
@ -126,7 +126,7 @@ func TestQueryResultsStore(t *testing.T) {
ctx2, cancel2 := context.WithCancel(context.Background())
channel2, err := store.ReadChannel(ctx2, campaign2)
assert.Nil(t, err)
require.NoError(t, err)
expected2 := []fleet.DistributedQueryResult{
{

View File

@ -90,9 +90,16 @@ func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan<
defer close(outChan)
for {
// Add a timeout to try to cleanup in the case the server has somehow gone completely unresponsive.
beforeReceive := time.Now()
// Add a timeout to try to cleanup in the case the server has somehow gone
// completely unresponsive.
msg := conn.ReceiveWithTimeout(1 * time.Hour)
if recvTime := time.Since(beforeReceive); recvTime > time.Minute {
level.Info(logger).Log("msg", "conn.ReceiveWithTimeout connection was blocked for significant time", "duration", recvTime, "connection", fmt.Sprintf("%p", conn))
}
// Pass the message back to ReadChannel.
if writeOrDone(ctx, outChan, msg) {
return
@ -188,6 +195,7 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib
wg.Wait()
psc.Unsubscribe(pubSubName) //nolint:errcheck
conn.Close()
level.Debug(logger).Log("msg", "proper close of Redis connection in ReadChannel", "connection", fmt.Sprintf("%p", conn))
}()
return outChannel, nil

View File

@ -30,6 +30,14 @@ func (nopLiveQuery) QueryCompletedByHost(name string, hostID uint) error {
return nil
}
func (nopLiveQuery) CleanupInactiveQueries(ctx context.Context, inactiveCampaignIDs []uint) error {
return nil
}
func (q nopLiveQuery) LoadActiveQueryNames() ([]string, error) {
return nil, nil
}
func TestLiveQueryAuth(t *testing.T) {
ds := new(mock.Store)
qr := pubsub.NewInmemQueryResults()

View File

@ -149,8 +149,8 @@ func Base64DecodePaddingAgnostic(s string) ([]byte, error) {
// RemoveDuplicatesFromSlice returns a slice with all the duplicates removed from the input slice.
func RemoveDuplicatesFromSlice[T comparable](slice []T) []T {
// We are using the allKeys map as a set here
allKeys := make(map[T]struct{})
var list []T
allKeys := make(map[T]struct{}, len(slice))
list := make([]T, 0, len(slice))
for _, i := range slice {
if _, exists := allKeys[i]; !exists {