From 6c0e56ea73198b0f3155149599f19ce8cfa96583 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Tue, 27 Feb 2024 20:35:27 -0500 Subject: [PATCH] Address multiple redis-related issues observed with live queries (#16855) #16331 Doc updates in a separate PR: https://github.com/fleetdm/fleet/pull/17214 # Checklist for submitter - [x] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [x] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [x] Added/updated tests - [x] Manual QA for all new/changed functionality (smoke-tested locally with osquery-perf simulating 100 hosts, ran a live query, a saved live query, stopped naturally and stopped before the end, and again via fleetctl) --------- Co-authored-by: Victor Lyuboslavsky Co-authored-by: Victor Lyuboslavsky --- changes/16331-address-redis-issues-live-query | 3 + cmd/fleet/cron.go | 65 +++++++++++-- cmd/fleet/serve.go | 20 +++- server/config/config.go | 2 +- server/datastore/mysql/campaigns.go | 65 +++++++++++-- server/datastore/mysql/campaigns_test.go | 96 ++++++++++++++++--- server/datastore/redis/redis.go | 25 ++++- server/datastore/redis/redis_external_test.go | 71 ++++++++++++++ server/datastore/redis/redistest/redistest.go | 31 +++--- server/fleet/cron_schedules.go | 1 + server/fleet/datastore.go | 14 ++- server/fleet/live_query_store.go | 8 ++ .../live_query_mock/live_query_mock.go | 7 ++ server/live_query/live_query_test.go | 78 +++++++++++++++ server/live_query/redis_live_query.go | 65 ++++++++++++- server/mock/datastore_mock.go | 12 +++ server/pubsub/{doc.go => errors.go} | 0 server/pubsub/query_results_test.go | 4 +- server/pubsub/redis_query_results.go | 10 +- server/service/campaigns_test.go | 8 ++ server/utils.go | 4 +- 21 files changed, 531 insertions(+), 58 deletions(-) create mode 100644 changes/16331-address-redis-issues-live-query rename server/pubsub/{doc.go => errors.go} (100%) diff --git a/changes/16331-address-redis-issues-live-query b/changes/16331-address-redis-issues-live-query new file mode 100644 index 000000000..d5796f896 --- /dev/null +++ b/changes/16331-address-redis-issues-live-query @@ -0,0 +1,3 @@ +* Added logging when a Redis connection is blocked for a long time waiting for live query results. +* Added support for the `redis.conn_wait_timeout` configuration setting for Redis standalone (it was previously only supported on Redis cluster). +* Added Redis cleanup of inactive queries in a cron job, so temporary Redis failures to stop a live query doesn't leave such queries around for a long time. diff --git a/cmd/fleet/cron.go b/cmd/fleet/cron.go index 932bb444b..2400a2959 100644 --- a/cmd/fleet/cron.go +++ b/cmd/fleet/cron.go @@ -705,6 +705,7 @@ func newCleanupsAndAggregationSchedule( ctx context.Context, instanceID string, ds fleet.Datastore, + lq fleet.LiveQueryStore, logger kitlog.Logger, enrollHostLimiter fleet.EnrollHostLimiter, config *config.FleetConfig, @@ -720,13 +721,6 @@ func newCleanupsAndAggregationSchedule( schedule.WithAltLockID("leader"), schedule.WithLogger(kitlog.With(logger, "cron", name)), // Run cleanup jobs first. 
- schedule.WithJob( - "distributed_query_campaigns", - func(ctx context.Context) error { - _, err := ds.CleanupDistributedQueryCampaigns(ctx, time.Now()) - return err - }, - ), schedule.WithJob( "incoming_hosts", func(ctx context.Context) error { @@ -838,6 +832,50 @@ func newCleanupsAndAggregationSchedule( return s, nil } +func newFrequentCleanupsSchedule( + ctx context.Context, + instanceID string, + ds fleet.Datastore, + lq fleet.LiveQueryStore, + logger kitlog.Logger, +) (*schedule.Schedule, error) { + const ( + name = string(fleet.CronFrequentCleanups) + defaultInterval = 15 * time.Minute + ) + s := schedule.New( + ctx, name, instanceID, defaultInterval, ds, ds, + // Using leader for the lock to be backwards compatilibity with old deployments. + schedule.WithAltLockID("leader"), + schedule.WithLogger(kitlog.With(logger, "cron", name)), + // Run cleanup jobs first. + schedule.WithJob( + "distributed_query_campaigns", + func(ctx context.Context) error { + _, err := ds.CleanupDistributedQueryCampaigns(ctx, time.Now().UTC()) + if err != nil { + return err + } + names, err := lq.LoadActiveQueryNames() + if err != nil { + return err + } + ids := stringSliceToUintSlice(names, logger) + completed, err := ds.GetCompletedCampaigns(ctx, ids) + if err != nil { + return err + } + if err := lq.CleanupInactiveQueries(ctx, completed); err != nil { + return err + } + return nil + }, + ), + ) + + return s, nil +} + func verifyDiskEncryptionKeys( ctx context.Context, logger kitlog.Logger, @@ -1094,3 +1132,16 @@ func cronActivitiesStreaming( page += 1 } } + +func stringSliceToUintSlice(s []string, logger kitlog.Logger) []uint { + result := make([]uint, 0, len(s)) + for _, v := range s { + i, err := strconv.ParseUint(v, 10, 64) + if err != nil { + level.Warn(logger).Log("msg", "failed to parse string to uint", "string", v, "err", err) + continue + } + result = append(result, uint(i)) + } + return result +} diff --git a/cmd/fleet/serve.go b/cmd/fleet/serve.go index 9755b97bd..01ef84b29 100644 --- a/cmd/fleet/serve.go +++ b/cmd/fleet/serve.go @@ -681,12 +681,22 @@ the way that the Fleet server works. 
}() if err := cronSchedules.StartCronSchedule(func() (fleet.CronSchedule, error) { - var commander *apple_mdm.MDMAppleCommander - if appCfg.MDM.EnabledAndConfigured { - commander = apple_mdm.NewMDMAppleCommander(mdmStorage, mdmPushService) - } - return newCleanupsAndAggregationSchedule(ctx, instanceID, ds, logger, redisWrapperDS, &config, commander) + return newFrequentCleanupsSchedule(ctx, instanceID, ds, liveQueryStore, logger) }); err != nil { + initFatal(err, "failed to register frequent_cleanups schedule") + } + + if err := cronSchedules.StartCronSchedule( + func() (fleet.CronSchedule, error) { + var commander *apple_mdm.MDMAppleCommander + if appCfg.MDM.EnabledAndConfigured { + commander = apple_mdm.NewMDMAppleCommander(mdmStorage, mdmPushService) + } + return newCleanupsAndAggregationSchedule( + ctx, instanceID, ds, liveQueryStore, logger, redisWrapperDS, &config, commander, + ) + }, + ); err != nil { initFatal(err, "failed to register cleanups_then_aggregations schedule") } diff --git a/server/config/config.go b/server/config/config.go index 70f927e6b..e26831ec3 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -818,7 +818,7 @@ func (man Manager) addConfigs() { man.addConfigInt("redis.max_open_conns", 0, "Redis maximum open connections, 0 means no limit") man.addConfigDuration("redis.conn_max_lifetime", 0, "Redis maximum amount of time a connection may be reused, 0 means no limit") man.addConfigDuration("redis.idle_timeout", 240*time.Second, "Redis maximum amount of time a connection may stay idle, 0 means no limit") - man.addConfigDuration("redis.conn_wait_timeout", 0, "Redis maximum amount of time to wait for a connection if the maximum is reached (0 for no wait, ignored in non-cluster Redis)") + man.addConfigDuration("redis.conn_wait_timeout", 0, "Redis maximum amount of time to wait for a connection if the maximum is reached (0 for no wait)") man.addConfigDuration("redis.write_timeout", 10*time.Second, "Redis maximum amount of time to wait for a write (send) on a connection") man.addConfigDuration("redis.read_timeout", 10*time.Second, "Redis maximum amount of time to wait for a read (receive) on a connection") diff --git a/server/datastore/mysql/campaigns.go b/server/datastore/mysql/campaigns.go index 881289cb6..fa0af95f2 100644 --- a/server/datastore/mysql/campaigns.go +++ b/server/datastore/mysql/campaigns.go @@ -2,24 +2,37 @@ package mysql import ( "context" + "fmt" "time" + "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/jmoiron/sqlx" ) func (ds *Datastore) NewDistributedQueryCampaign(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) { + args := []any{camp.QueryID, camp.Status, camp.UserID} - sqlStatement := ` + // for tests, we sometimes provide specific timestamps for CreatedAt, honor + // those if provided. + var createdAtField, createdAtPlaceholder string + if !camp.CreatedAt.IsZero() { + createdAtField = ", created_at" + createdAtPlaceholder = ", ?" + args = append(args, camp.CreatedAt) + } + + sqlStatement := fmt.Sprintf(` INSERT INTO distributed_query_campaigns ( query_id, status, user_id + %s ) - VALUES(?,?,?) - ` - result, err := ds.writer(ctx).ExecContext(ctx, sqlStatement, camp.QueryID, camp.Status, camp.UserID) + VALUES(?,?,?%s) + `, createdAtField, createdAtPlaceholder) + result, err := ds.writer(ctx).ExecContext(ctx, sqlStatement, args...) 
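As an aside on the dynamically built statement above: a minimal, self-contained sketch of the same optional-column INSERT pattern, with illustrative argument values (the real method gets them from the campaign struct):

```go
package main

import (
	"fmt"
	"time"
)

// buildCampaignInsert sketches the pattern used above: the created_at column
// and its placeholder are only appended when an explicit timestamp was
// supplied, as the tests do for fixed-clock campaigns.
func buildCampaignInsert(createdAt time.Time) (string, []any) {
	args := []any{10, 0, 42} // illustrative query_id, status, user_id
	var field, placeholder string
	if !createdAt.IsZero() {
		field, placeholder = ", created_at", ", ?"
		args = append(args, createdAt)
	}
	stmt := fmt.Sprintf(
		"INSERT INTO distributed_query_campaigns (query_id, status, user_id%s) VALUES (?,?,?%s)",
		field, placeholder,
	)
	return stmt, args
}

func main() {
	s, a := buildCampaignInsert(time.Time{})
	fmt.Println(s, len(a)) // 3 placeholders, no created_at
	s, a = buildCampaignInsert(time.Now())
	fmt.Println(s, len(a)) // 4 placeholders, explicit created_at
}
```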
if err != nil { return nil, ctxerr.Wrap(ctx, err, "inserting distributed query campaign") } @@ -121,9 +134,50 @@ func (ds *Datastore) NewDistributedQueryCampaignTarget(ctx context.Context, targ return target, nil } +func (ds *Datastore) GetCompletedCampaigns(ctx context.Context, filter []uint) ([]uint, error) { + // There is a limit of 65,535 (2^16-1) placeholders in MySQL 5.7 + const batchSize = 65535 - 1 + if len(filter) == 0 { + return nil, nil + } + + // We must remove duplicates from the input filter because we process the filter in batches, + // and that could result in duplicated result IDs + filter = server.RemoveDuplicatesFromSlice(filter) + + completed := make([]uint, 0, len(filter)) + for i := 0; i < len(filter); i += batchSize { + end := i + batchSize + if end > len(filter) { + end = len(filter) + } + batch := filter[i:end] + + query, args, err := sqlx.In( + `SELECT id + FROM distributed_query_campaigns + WHERE status = ? + AND id IN (?) + `, fleet.QueryComplete, batch, + ) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "building query for completed campaigns") + } + + var rows []uint + // using the writer, so we catch the ones we just marked as completed + err = sqlx.SelectContext(ctx, ds.writer(ctx), &rows, query, args...) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "selecting completed campaigns") + } + completed = append(completed, rows...) + } + return completed, nil +} + func (ds *Datastore) CleanupDistributedQueryCampaigns(ctx context.Context, now time.Time) (expired uint, err error) { // Expire old waiting/running campaigns - sqlStatement := ` + const sqlStatement = ` UPDATE distributed_query_campaigns SET status = ? WHERE (status = ? AND created_at < ?) @@ -140,6 +194,5 @@ func (ds *Datastore) CleanupDistributedQueryCampaigns(ctx context.Context, now t if err != nil { return 0, ctxerr.Wrap(ctx, err, "rows affected updating distributed query campaign") } - return uint(exp), nil } diff --git a/server/datastore/mysql/campaigns_test.go b/server/datastore/mysql/campaigns_test.go index 863fd8531..b7154d92f 100644 --- a/server/datastore/mysql/campaigns_test.go +++ b/server/datastore/mysql/campaigns_test.go @@ -2,6 +2,8 @@ package mysql import ( "context" + "math/rand" + "sort" "testing" "time" @@ -22,6 +24,7 @@ func TestCampaigns(t *testing.T) { {"DistributedQuery", testCampaignsDistributedQuery}, {"CleanupDistributedQuery", testCampaignsCleanupDistributedQuery}, {"SaveDistributedQuery", testCampaignsSaveDistributedQuery}, + {"CompletedCampaigns", testCompletedCampaigns}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -80,10 +83,9 @@ func testCampaignsDistributedQuery(t *testing.T, ds *Datastore) { } func testCampaignsCleanupDistributedQuery(t *testing.T, ds *Datastore) { + ctx := context.Background() user := test.NewUser(t, ds, "Zach", "zwass@fleet.co", true) - mockClock := clock.NewMockClock() - query := test.NewQuery(t, ds, nil, "test", "select * from time", user.ID, false) c1 := test.NewCampaign(t, ds, query.ID, fleet.QueryWaiting, mockClock.Now()) @@ -91,18 +93,18 @@ func testCampaignsCleanupDistributedQuery(t *testing.T, ds *Datastore) { // Cleanup and verify that nothing changed (because time has not // advanced) - expired, err := ds.CleanupDistributedQueryCampaigns(context.Background(), mockClock.Now()) + expired, err := ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now()) require.Nil(t, err) assert.Equal(t, uint(0), expired) { - retrieved, err := ds.DistributedQueryCampaign(context.Background(), c1.ID) + retrieved, err := 
ds.DistributedQueryCampaign(ctx, c1.ID) require.Nil(t, err) assert.Equal(t, c1.QueryID, retrieved.QueryID) assert.Equal(t, c1.Status, retrieved.Status) } { - retrieved, err := ds.DistributedQueryCampaign(context.Background(), c2.ID) + retrieved, err := ds.DistributedQueryCampaign(ctx, c2.ID) require.Nil(t, err) assert.Equal(t, c2.QueryID, retrieved.QueryID) assert.Equal(t, c2.Status, retrieved.Status) @@ -114,18 +116,18 @@ func testCampaignsCleanupDistributedQuery(t *testing.T, ds *Datastore) { // Cleanup and verify that the campaign was expired and executions // deleted appropriately - expired, err = ds.CleanupDistributedQueryCampaigns(context.Background(), mockClock.Now()) + expired, err = ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now()) require.Nil(t, err) assert.Equal(t, uint(1), expired) { // c1 should now be complete - retrieved, err := ds.DistributedQueryCampaign(context.Background(), c1.ID) + retrieved, err := ds.DistributedQueryCampaign(ctx, c1.ID) require.Nil(t, err) assert.Equal(t, c1.QueryID, retrieved.QueryID) assert.Equal(t, fleet.QueryComplete, retrieved.Status) } { - retrieved, err := ds.DistributedQueryCampaign(context.Background(), c2.ID) + retrieved, err := ds.DistributedQueryCampaign(ctx, c2.ID) require.Nil(t, err) assert.Equal(t, c2.QueryID, retrieved.QueryID) assert.Equal(t, c2.Status, retrieved.Status) @@ -135,22 +137,51 @@ func testCampaignsCleanupDistributedQuery(t *testing.T, ds *Datastore) { // Cleanup and verify that the campaign was expired and executions // deleted appropriately - expired, err = ds.CleanupDistributedQueryCampaigns(context.Background(), mockClock.Now()) + expired, err = ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now()) require.Nil(t, err) assert.Equal(t, uint(1), expired) { - retrieved, err := ds.DistributedQueryCampaign(context.Background(), c1.ID) + retrieved, err := ds.DistributedQueryCampaign(ctx, c1.ID) require.Nil(t, err) assert.Equal(t, c1.QueryID, retrieved.QueryID) assert.Equal(t, fleet.QueryComplete, retrieved.Status) } { // c2 should now be complete - retrieved, err := ds.DistributedQueryCampaign(context.Background(), c2.ID) + retrieved, err := ds.DistributedQueryCampaign(ctx, c2.ID) require.Nil(t, err) assert.Equal(t, c2.QueryID, retrieved.QueryID) assert.Equal(t, fleet.QueryComplete, retrieved.Status) } + + // simulate another old campaign created > 7 days ago + c3 := test.NewCampaign(t, ds, query.ID, fleet.QueryWaiting, mockClock.Now().AddDate(0, 0, -8)) + { + retrieved, err := ds.DistributedQueryCampaign(ctx, c3.ID) + require.Nil(t, err) + assert.Equal(t, c3.QueryID, retrieved.QueryID) + assert.Equal(t, fleet.QueryWaiting, retrieved.Status) + } + + // cleanup will mark c3 as completed because it was waiting for > 1 minute, + // but it won't return it as recently inactive because it's too old a query. 
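For context on `GetCompletedCampaigns` (added earlier in `campaigns.go`), here is a self-contained sketch of the chunked `sqlx.In` pattern it relies on to keep each `IN (...)` clause under MySQL's placeholder limit; the status value below is a stand-in, not the real `fleet.QueryComplete` constant:

```go
package main

import (
	"fmt"

	"github.com/jmoiron/sqlx"
)

// batchedCompletedQueries mirrors the chunking in GetCompletedCampaigns: the
// (deduplicated) ID filter is split so each expanded IN (...) stays well under
// MySQL 5.7's 65,535-placeholder limit.
func batchedCompletedQueries(ids []uint, batchSize int) ([]string, error) {
	const completeStatus = 2 // stand-in for the "complete" status value
	var queries []string
	for i := 0; i < len(ids); i += batchSize {
		end := i + batchSize
		if end > len(ids) {
			end = len(ids)
		}
		q, args, err := sqlx.In(
			`SELECT id FROM distributed_query_campaigns WHERE status = ? AND id IN (?)`,
			completeStatus, ids[i:end],
		)
		if err != nil {
			return nil, err
		}
		_ = args // in the real method these go to sqlx.SelectContext on the writer
		queries = append(queries, q)
	}
	return queries, nil
}

func main() {
	qs, err := batchedCompletedQueries([]uint{1, 2, 3, 4, 5}, 2)
	fmt.Println(len(qs), err) // 3 batches: {1,2}, {3,4}, {5}
}
```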
+ expired, err = ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now()) + require.Nil(t, err) + assert.Equal(t, uint(1), expired) + + // cleanup again does not expire any new campaign and still returns the same + // recently inactive campaigns + expired, err = ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now()) + require.Nil(t, err) + assert.Equal(t, uint(0), expired) + + // move time forward 7 days and cleanup again, this time it returns no recent + // inactive campaigns + mockClock.AddTime(7*24*time.Hour + 1*time.Second) + + expired, err = ds.CleanupDistributedQueryCampaigns(ctx, mockClock.Now()) + require.Nil(t, err) + assert.Equal(t, uint(0), expired) } func testCampaignsSaveDistributedQuery(t *testing.T, ds *Datastore) { @@ -180,3 +211,46 @@ func checkTargets(t *testing.T, ds fleet.Datastore, campaignID uint, expectedTar assert.ElementsMatch(t, expectedTargets.LabelIDs, targets.LabelIDs) assert.ElementsMatch(t, expectedTargets.TeamIDs, targets.TeamIDs) } + +func testCompletedCampaigns(t *testing.T, ds *Datastore) { + // Test nil result + result, err := ds.GetCompletedCampaigns(context.Background(), nil) + assert.NoError(t, err) + assert.Len(t, result, 0) + + result, err = ds.GetCompletedCampaigns(context.Background(), []uint{234, 1, 1, 34455455453}) + assert.NoError(t, err) + assert.Len(t, result, 0) + + // Now test reasonable results + user := test.NewUser(t, ds, t.Name(), t.Name()+"zwass@fleet.co", true) + mockClock := clock.NewMockClock() + query := test.NewQuery(t, ds, nil, t.Name()+"test", "select * from time", user.ID, false) + + numCampaigns := 5 + totalFilterSize := 100000 + filter := make([]uint, 0, totalFilterSize) + complete := make([]uint, 0, numCampaigns) + for i := 0; i < numCampaigns; i++ { + c1 := test.NewCampaign(t, ds, query.ID, fleet.QueryWaiting, mockClock.Now()) + gotC, err := ds.DistributedQueryCampaign(context.Background(), c1.ID) + require.NoError(t, err) + require.Equal(t, fleet.QueryWaiting, gotC.Status) + if rand.Intn(10) < 7 { //nolint:gosec + c1.Status = fleet.QueryComplete + require.NoError(t, ds.SaveDistributedQueryCampaign(context.Background(), c1)) + complete = append(complete, c1.ID) + } + filter = append(filter, c1.ID) + } + for j := filter[len(filter)-1] / 2; j < uint(totalFilterSize); j++ { // some IDs are duplicated + filter = append(filter, j) + } + rand.Shuffle(len(filter), func(i, j int) { filter[i], filter[j] = filter[j], filter[i] }) + + result, err = ds.GetCompletedCampaigns(context.Background(), filter) + assert.NoError(t, err) + sort.Slice(result, func(i, j int) bool { return result[i] < result[j] }) + assert.Equal(t, complete, result) + +} diff --git a/server/datastore/redis/redis.go b/server/datastore/redis/redis.go index 5384d3984..34b5fc940 100644 --- a/server/datastore/redis/redis.go +++ b/server/datastore/redis/redis.go @@ -1,6 +1,7 @@ package redis import ( + "context" "errors" "fmt" "net" @@ -18,7 +19,24 @@ import ( // redisc.Cluster, so both can satisfy the same interface. type standalonePool struct { *redis.Pool - addr string + addr string + connWaitTimeout time.Duration +} + +func (p *standalonePool) Get() redis.Conn { + if p.connWaitTimeout <= 0 { + return p.Pool.Get() + } + + ctx, cancel := context.WithTimeout(context.Background(), p.connWaitTimeout) + defer cancel() + + // GetContext always returns an "errorConn" as valid connection when there is + // an error, so there's no need to care about the second return value (as for + // the no-wait case, the errorConn will fail on first use with the actual + // error). 
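To make the wrapper's wait-with-timeout behavior concrete, a minimal redigo sketch of the same idea, using plain redigo rather than Fleet's pool types and assuming a Redis server at `localhost:6379`:

```go
package main

import (
	"context"
	"fmt"
	"time"

	redigo "github.com/gomodule/redigo/redis"
)

func main() {
	// Bounded pool that waits (rather than failing immediately) when all
	// connections are in use, which is the behavior redis.conn_wait_timeout
	// now enables for standalone Redis.
	pool := &redigo.Pool{
		MaxActive: 1,
		Wait:      true,
		Dial: func() (redigo.Conn, error) {
			return redigo.Dial("tcp", "localhost:6379")
		},
	}

	held := pool.Get() // occupy the only connection
	defer held.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	start := time.Now()
	conn, err := pool.GetContext(ctx) // blocks until a conn frees up or ctx expires
	if err != nil {
		fmt.Println("gave up after", time.Since(start), "waiting for a connection:", err)
		return
	}
	defer conn.Close()
	fmt.Println(conn.Do("PING"))
}
```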
+ conn, _ := p.Pool.GetContext(ctx) + return conn } func (p *standalonePool) Stats() map[string]redis.PoolStats { @@ -83,10 +101,7 @@ func NewPool(config PoolConfig) (fleet.RedisPool, error) { // not a Redis Cluster setup, use a standalone Redis pool pool, _ := cluster.CreatePool(config.Server, cluster.DialOptions...) cluster.Close() - // never wait for a connection in a non-cluster pool as it can block - // indefinitely. - pool.Wait = false - return &standalonePool{pool, config.Server}, nil + return &standalonePool{pool, config.Server, config.ConnWaitTimeout}, nil } return nil, fmt.Errorf("refresh cluster: %w", err) } diff --git a/server/datastore/redis/redis_external_test.go b/server/datastore/redis/redis_external_test.go index fb2de11a8..6a5fc3d31 100644 --- a/server/datastore/redis/redis_external_test.go +++ b/server/datastore/redis/redis_external_test.go @@ -364,3 +364,74 @@ func TestReadTimeout(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "i/o timeout") } + +func TestRedisConnWaitTime(t *testing.T) { + t.Run("standalone", func(t *testing.T) { + pool := redistest.SetupRedisWithConfig(t, "zz", false, false, false, redis.PoolConfig{ConnWaitTimeout: 2 * time.Second, MaxOpenConns: 2}) + + conn1 := pool.Get() + defer conn1.Close() + conn2 := pool.Get() + defer conn2.Close() + + // no more connections available, requesting another will wait and fail after 2s + start := time.Now() + conn3 := pool.Get() + _, err := conn3.Do("PING") + conn3.Close() + require.Error(t, err) + require.GreaterOrEqual(t, time.Since(start), 2*time.Second) + + // request another one but this time close an open connection after a second + go func() { + time.Sleep(time.Second) + conn1.Close() + }() + + start = time.Now() + conn4 := pool.Get() + _, err = conn4.Do("PING") + conn4.Close() + require.NoError(t, err) + require.Less(t, time.Since(start), 2*time.Second) + }) + + t.Run("cluster", func(t *testing.T) { + pool := redistest.SetupRedisWithConfig(t, "wait-timeout-", true, false, false, redis.PoolConfig{ConnWaitTimeout: 2 * time.Second, MaxOpenConns: 2}) + + conn1 := pool.Get() + defer conn1.Close() + conn2 := pool.Get() + defer conn2.Close() + + // bind the connections to the same node by requesting the same key, both connections + // will now be used from the same pool. 
+ _, err := conn1.Do("GET", "wait-timeout-a") + require.NoError(t, err) + _, err = conn2.Do("GET", "wait-timeout-a") + require.NoError(t, err) + + // no more connections available, requesting another will wait and fail after 2s + start := time.Now() + conn3 := pool.Get() + defer conn3.Close() + _, err = conn3.Do("GET", "wait-timeout-a") + conn3.Close() + require.Error(t, err) + require.GreaterOrEqual(t, time.Since(start), 2*time.Second) + + // request another one but this time close an open connection after a second + go func() { + time.Sleep(time.Second) + conn1.Close() + }() + + start = time.Now() + conn4 := pool.Get() + defer conn4.Close() + _, err = conn4.Do("GET", "wait-timeout-a") + conn4.Close() + require.NoError(t, err) + require.Less(t, time.Since(start), 2*time.Second) + }) +} diff --git a/server/datastore/redis/redistest/redistest.go b/server/datastore/redis/redistest/redistest.go index 9c7d15965..689adf20b 100644 --- a/server/datastore/redis/redistest/redistest.go +++ b/server/datastore/redis/redistest/redistest.go @@ -36,6 +36,10 @@ func NopRedis() fleet.RedisPool { } func SetupRedis(tb testing.TB, cleanupKeyPrefix string, cluster, redir, readReplica bool) fleet.RedisPool { + return SetupRedisWithConfig(tb, cleanupKeyPrefix, cluster, redir, readReplica, redis.PoolConfig{}) +} + +func SetupRedisWithConfig(tb testing.TB, cleanupKeyPrefix string, cluster, redir, readReplica bool, config redis.PoolConfig) fleet.RedisPool { if _, ok := os.LookupEnv("REDIS_TEST"); !ok { tb.Skip("set REDIS_TEST environment variable to run redis-based tests") } @@ -59,17 +63,22 @@ func SetupRedis(tb testing.TB, cleanupKeyPrefix string, cluster, redir, readRepl } addr += port - pool, err := redis.NewPool(redis.PoolConfig{ - Server: addr, - Username: username, - Password: password, - Database: database, - UseTLS: useTLS, - ConnTimeout: 5 * time.Second, - KeepAlive: 10 * time.Second, - ClusterFollowRedirections: redir, - ClusterReadFromReplica: readReplica, - }) + // set the mandatory, non-configurable configs for tests + config.Server = addr + config.Username = username + config.Password = password + config.Database = database + config.UseTLS = useTLS + config.ClusterFollowRedirections = redir + config.ClusterReadFromReplica = readReplica + if config.ConnTimeout == 0 { + config.ConnTimeout = 5 * time.Second + } + if config.KeepAlive == 0 { + config.KeepAlive = 10 * time.Second + } + + pool, err := redis.NewPool(config) require.NoError(tb, err) conn := pool.Get() diff --git a/server/fleet/cron_schedules.go b/server/fleet/cron_schedules.go index aeaebf1c4..607f15f85 100644 --- a/server/fleet/cron_schedules.go +++ b/server/fleet/cron_schedules.go @@ -14,6 +14,7 @@ type CronScheduleName string const ( CronAppleMDMDEPProfileAssigner CronScheduleName = "apple_mdm_dep_profile_assigner" CronCleanupsThenAggregation CronScheduleName = "cleanups_then_aggregation" + CronFrequentCleanups CronScheduleName = "frequent_cleanups" CronUsageStatistics CronScheduleName = "usage_statistics" CronVulnerabilities CronScheduleName = "vulnerabilities" CronAutomations CronScheduleName = "automations" diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 7d7108c8f..001349d46 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -120,12 +120,18 @@ type Datastore interface { // NewDistributedQueryCampaignTarget adds a new target to an existing distributed query campaign NewDistributedQueryCampaignTarget(ctx context.Context, target *DistributedQueryCampaignTarget) 
(*DistributedQueryCampaignTarget, error) - // CleanupDistributedQueryCampaigns will clean and trim metadata for old distributed query campaigns. Any campaign - // in the QueryWaiting state will be moved to QueryComplete after one minute. Any campaign in the QueryRunning state - // will be moved to QueryComplete after one day. Times are from creation time. The now parameter makes this method - // easier to test. The return values indicate how many campaigns were expired and any error. + // CleanupDistributedQueryCampaigns will clean and trim metadata for old + // distributed query campaigns. Any campaign in the QueryWaiting state will + // be moved to QueryComplete after one minute. Any campaign in the + // QueryRunning state will be moved to QueryComplete after one day. Times are + // from creation time. The now parameter makes this method easier to test. + // The return values indicate how many campaigns were expired and any error. CleanupDistributedQueryCampaigns(ctx context.Context, now time.Time) (expired uint, err error) + // GetCompletedCampaigns returns the IDs of the campaigns that are in the fleet.QueryComplete state and that are in the + // provided list of IDs. The return value is a slice of the IDs of the completed campaigns and any error. + GetCompletedCampaigns(ctx context.Context, filter []uint) ([]uint, error) + DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*DistributedQueryCampaign, error) /////////////////////////////////////////////////////////////////////////////// diff --git a/server/fleet/live_query_store.go b/server/fleet/live_query_store.go index 5a28c0060..ef46b8ec6 100644 --- a/server/fleet/live_query_store.go +++ b/server/fleet/live_query_store.go @@ -1,5 +1,7 @@ package fleet +import "context" + // LiveQueryStore defines an interface for storing and retrieving the status of // live queries in the Fleet system. type LiveQueryStore interface { @@ -16,4 +18,10 @@ type LiveQueryStore interface { // given host. After calling QueryCompleted, that query will no longer be // sent to the host. QueryCompletedByHost(name string, hostID uint) error + // CleanupInactiveQueries removes any inactive queries. This is used via a + // cron job to regularly cleanup any queries that may have failed to be + // stopped properly in Redis. + CleanupInactiveQueries(ctx context.Context, inactiveCampaignIDs []uint) error + // LoadActiveQueryNames returns the names of all active queries. + LoadActiveQueryNames() ([]string, error) } diff --git a/server/live_query/live_query_mock/live_query_mock.go b/server/live_query/live_query_mock/live_query_mock.go index 59516be50..bd487f90b 100644 --- a/server/live_query/live_query_mock/live_query_mock.go +++ b/server/live_query/live_query_mock/live_query_mock.go @@ -1,6 +1,7 @@ package live_query_mock import ( + "context" "testing" "github.com/fleetdm/fleet/v4/server/fleet" @@ -45,3 +46,9 @@ func (m *MockLiveQuery) QueryCompletedByHost(name string, hostID uint) error { args := m.Called(name, hostID) return args.Error(0) } + +// CleanupInactiveQueries mocks the live query store CleanupInactiveQueries method. 
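A hypothetical test wiring for the new `CleanupInactiveQueries` mock, assuming `MockLiveQuery` embeds testify's `mock.Mock` (its methods call `m.Called`); the expectation values are illustrative only:

```go
package example

import (
	"context"
	"testing"

	"github.com/fleetdm/fleet/v4/server/live_query/live_query_mock"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

func TestCleanupInactiveQueriesIsCalled(t *testing.T) {
	lq := &live_query_mock.MockLiveQuery{}
	// Expect the cleanup path to drop campaigns 2 and 3.
	lq.On("CleanupInactiveQueries", mock.Anything, []uint{2, 3}).Return(nil)

	// The code under test would call the store here; invoked directly for brevity.
	require.NoError(t, lq.CleanupInactiveQueries(context.Background(), []uint{2, 3}))

	lq.AssertExpectations(t)
}
```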
+func (m *MockLiveQuery) CleanupInactiveQueries(ctx context.Context, inactiveCampaignIDs []uint) error { + args := m.Called(ctx, inactiveCampaignIDs) + return args.Error(0) +} diff --git a/server/live_query/live_query_test.go b/server/live_query/live_query_test.go index 9cfd961bc..f7d684963 100644 --- a/server/live_query/live_query_test.go +++ b/server/live_query/live_query_test.go @@ -1,6 +1,7 @@ package live_query import ( + "context" "testing" "github.com/fleetdm/fleet/v4/server/datastore/redis" @@ -16,6 +17,7 @@ var testFunctions = [...]func(*testing.T, fleet.LiveQueryStore){ testLiveQueryStopQuery, testLiveQueryExpiredQuery, testLiveQueryOnlyExpired, + testLiveQueryCleanupInactive, } func testLiveQuery(t *testing.T, store fleet.LiveQueryStore) { @@ -135,3 +137,79 @@ func testLiveQueryOnlyExpired(t *testing.T, store fleet.LiveQueryStore) { require.NoError(t, err) require.Len(t, activeNames, 0) } + +func testLiveQueryCleanupInactive(t *testing.T, store fleet.LiveQueryStore) { + ctx := context.Background() + + // get a raw Redis connection to make direct checks + pool := store.(*redisLiveQuery).pool + conn := redis.ConfigureDoer(pool, pool.Get()) + defer conn.Close() + + // run a few live queries, making them active in Redis + err := store.RunQuery("1", "SELECT 1", []uint{1, 2, 3}) + require.NoError(t, err) + err = store.RunQuery("2", "SELECT 2", []uint{4}) + require.NoError(t, err) + err = store.RunQuery("3", "SELECT 3", []uint{5, 6}) + require.NoError(t, err) + err = store.RunQuery("4", "SELECT 4", []uint{1, 2, 5}) + require.NoError(t, err) + err = store.RunQuery("5", "SELECT 5", []uint{2, 3, 7}) + require.NoError(t, err) + + activeNames, err := store.LoadActiveQueryNames() + require.NoError(t, err) + require.ElementsMatch(t, []string{"1", "2", "3", "4", "5"}, activeNames) + + // sanity-check that the queries are properly stored + m, err := store.QueriesForHost(1) + require.NoError(t, err) + require.Equal(t, map[string]string{"1": "SELECT 1", "4": "SELECT 4"}, m) + + // simulate that only campaigns 2 and 4 are still active, cleanup the rest + err = store.CleanupInactiveQueries(ctx, []uint{1, 3, 5}) + require.NoError(t, err) + + activeNames, err = store.LoadActiveQueryNames() + require.NoError(t, err) + require.ElementsMatch(t, []string{"2", "4"}, activeNames) + + m, err = store.QueriesForHost(1) + require.NoError(t, err) + require.Equal(t, map[string]string{"4": "SELECT 4"}, m) + + // explicitly mark campaign 4 as stopped + err = store.StopQuery("4") + require.NoError(t, err) + + // no more queries for host 1 + m, err = store.QueriesForHost(1) + require.NoError(t, err) + require.Empty(t, m) + + // only campaign 2 remains, for host 4 + m, err = store.QueriesForHost(4) + require.NoError(t, err) + require.Equal(t, map[string]string{"2": "SELECT 2"}, m) + + // simulate that there are no inactive campaigns to cleanup + err = store.CleanupInactiveQueries(ctx, nil) + require.NoError(t, err) + + activeNames, err = store.LoadActiveQueryNames() + require.NoError(t, err) + require.ElementsMatch(t, []string{"2"}, activeNames) + + // simulate that all campaigns are inactive, cleanup all + err = store.CleanupInactiveQueries(ctx, []uint{1, 2, 3, 4, 5}) + require.NoError(t, err) + + activeNames, err = store.LoadActiveQueryNames() + require.NoError(t, err) + require.Empty(t, activeNames) + + m, err = store.QueriesForHost(4) + require.NoError(t, err) + require.Empty(t, m) +} diff --git a/server/live_query/redis_live_query.go b/server/live_query/redis_live_query.go index 6156a1944..8b424ba63 100644 --- 
a/server/live_query/redis_live_query.go +++ b/server/live_query/redis_live_query.go @@ -45,11 +45,14 @@ package live_query import ( + "context" "errors" "fmt" + "strconv" "strings" "time" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/datastore/redis" "github.com/fleetdm/fleet/v4/server/fleet" redigo "github.com/gomodule/redigo/redis" @@ -100,7 +103,7 @@ func extractTargetKeyName(key string) string { // RunQuery stores the live query information in ephemeral storage for the // duration of the query or its TTL. Note that hostIDs *must* be sorted -// in ascending order. +// in ascending order. The name is the campaign ID as a string. func (r *redisLiveQuery) RunQuery(name, sql string, hostIDs []uint) error { if len(hostIDs) == 0 { return errors.New("no hosts targeted") @@ -138,7 +141,7 @@ var cleanupExpiredQueriesModulo int64 = 10 func (r *redisLiveQuery) QueriesForHost(hostID uint) (map[string]string, error) { // Get keys for active queries - names, err := r.loadActiveQueryNames() + names, err := r.LoadActiveQueryNames() if err != nil { return nil, fmt.Errorf("load active queries: %w", err) } @@ -312,7 +315,7 @@ func (r *redisLiveQuery) removeQueryNames(names ...string) error { return err } -func (r *redisLiveQuery) loadActiveQueryNames() ([]string, error) { +func (r *redisLiveQuery) LoadActiveQueryNames() ([]string, error) { conn := redis.ConfigureDoer(r.pool, r.pool.Get()) defer conn.Close() @@ -323,6 +326,62 @@ func (r *redisLiveQuery) loadActiveQueryNames() ([]string, error) { return names, nil } +func (r *redisLiveQuery) CleanupInactiveQueries(ctx context.Context, inactiveCampaignIDs []uint) error { + // the following logic is used to cleanup inactive queries: + // * the inactive campaign IDs are removed from the livequery:active set + // + // At this point, all inactive queries are already effectively deleted - the + // rest is just best effort cleanup to save Redis memory space, but those + // keys would otherwise be ignored and without effect. + // + // * remove the livequery: and sql:livequery: for every inactive + // campaign ID. + + if len(inactiveCampaignIDs) == 0 { + return nil + } + + if err := r.removeInactiveQueries(ctx, inactiveCampaignIDs); err != nil { + return err + } + + keysToDel := make([]string, 0, len(inactiveCampaignIDs)*2) + for _, id := range inactiveCampaignIDs { + targetKey, sqlKey := generateKeys(strconv.FormatUint(uint64(id), 10)) + keysToDel = append(keysToDel, targetKey, sqlKey) + } + + keysBySlot := redis.SplitKeysBySlot(r.pool, keysToDel...) + for _, keys := range keysBySlot { + if err := r.removeBatchInactiveKeys(ctx, keys); err != nil { + return err + } + } + return nil +} + +func (r *redisLiveQuery) removeBatchInactiveKeys(ctx context.Context, keys []string) error { + conn := r.pool.Get() + defer conn.Close() + + args := redigo.Args{}.AddFlat(keys) + if _, err := conn.Do("DEL", args...); err != nil { + return ctxerr.Wrap(ctx, err, "remove batch of inactive keys") + } + return nil +} + +func (r *redisLiveQuery) removeInactiveQueries(ctx context.Context, inactiveCampaignIDs []uint) error { + conn := r.pool.Get() + defer conn.Close() + + args := redigo.Args{}.Add(activeQueriesKey).AddFlat(inactiveCampaignIDs) + if _, err := conn.Do("SREM", args...); err != nil { + return ctxerr.Wrap(ctx, err, "remove inactive campaign IDs") + } + return nil +} + // mapBitfield takes the given host IDs and maps them into a bitfield compatible // with Redis. It is expected that the input IDs are in ascending order. 
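As an illustration of how `CleanupInactiveQueries` above assembles its Redis commands, a small standalone sketch of redigo's `Args` helpers; the key names follow the `livequery:`/`sql:livequery:` prefixes described in the comment above, and `"livequery:active"` is a stand-in for the package's `activeQueriesKey` constant:

```go
package main

import (
	"fmt"
	"strconv"

	redigo "github.com/gomodule/redigo/redis"
)

func main() {
	inactive := []uint{17, 21}

	// SREM livequery:active 17 21: drop the campaigns from the active set.
	srem := redigo.Args{}.Add("livequery:active").AddFlat(inactive)
	fmt.Println("SREM", srem)

	// DEL livequery:17 sql:livequery:17 livequery:21 sql:livequery:21:
	// best-effort removal of the per-campaign target and SQL keys.
	var keys []string
	for _, id := range inactive {
		name := strconv.FormatUint(uint64(id), 10)
		keys = append(keys, "livequery:"+name, "sql:livequery:"+name)
	}
	del := redigo.Args{}.AddFlat(keys)
	fmt.Println("DEL", del)
}
```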
func mapBitfield(hostIDs []uint) []byte { diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 3abc4f199..123d87880 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -98,6 +98,8 @@ type NewDistributedQueryCampaignTargetFunc func(ctx context.Context, target *fle type CleanupDistributedQueryCampaignsFunc func(ctx context.Context, now time.Time) (expired uint, err error) +type GetCompletedCampaignsFunc func(ctx context.Context, filter []uint) ([]uint, error) + type DistributedQueryCampaignsForQueryFunc func(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error) type ApplyPackSpecsFunc func(ctx context.Context, specs []*fleet.PackSpec) error @@ -959,6 +961,9 @@ type DataStore struct { CleanupDistributedQueryCampaignsFunc CleanupDistributedQueryCampaignsFunc CleanupDistributedQueryCampaignsFuncInvoked bool + GetCompletedCampaignsFunc GetCompletedCampaignsFunc + GetCompletedCampaignsFuncInvoked bool + DistributedQueryCampaignsForQueryFunc DistributedQueryCampaignsForQueryFunc DistributedQueryCampaignsForQueryFuncInvoked bool @@ -2352,6 +2357,13 @@ func (s *DataStore) CleanupDistributedQueryCampaigns(ctx context.Context, now ti return s.CleanupDistributedQueryCampaignsFunc(ctx, now) } +func (s *DataStore) GetCompletedCampaigns(ctx context.Context, filter []uint) ([]uint, error) { + s.mu.Lock() + s.GetCompletedCampaignsFuncInvoked = true + s.mu.Unlock() + return s.GetCompletedCampaignsFunc(ctx, filter) +} + func (s *DataStore) DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error) { s.mu.Lock() s.DistributedQueryCampaignsForQueryFuncInvoked = true diff --git a/server/pubsub/doc.go b/server/pubsub/errors.go similarity index 100% rename from server/pubsub/doc.go rename to server/pubsub/errors.go diff --git a/server/pubsub/query_results_test.go b/server/pubsub/query_results_test.go index b7d266e54..47ab85ff4 100644 --- a/server/pubsub/query_results_test.go +++ b/server/pubsub/query_results_test.go @@ -96,7 +96,7 @@ func TestQueryResultsStore(t *testing.T) { ctx1, cancel1 := context.WithCancel(context.Background()) channel1, err := store.ReadChannel(ctx1, campaign1) - assert.Nil(t, err) + require.NoError(t, err) expected1 := []fleet.DistributedQueryResult{ { @@ -126,7 +126,7 @@ func TestQueryResultsStore(t *testing.T) { ctx2, cancel2 := context.WithCancel(context.Background()) channel2, err := store.ReadChannel(ctx2, campaign2) - assert.Nil(t, err) + require.NoError(t, err) expected2 := []fleet.DistributedQueryResult{ { diff --git a/server/pubsub/redis_query_results.go b/server/pubsub/redis_query_results.go index b3e8f800b..27b01258a 100644 --- a/server/pubsub/redis_query_results.go +++ b/server/pubsub/redis_query_results.go @@ -90,9 +90,16 @@ func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan< defer close(outChan) for { - // Add a timeout to try to cleanup in the case the server has somehow gone completely unresponsive. + beforeReceive := time.Now() + + // Add a timeout to try to cleanup in the case the server has somehow gone + // completely unresponsive. msg := conn.ReceiveWithTimeout(1 * time.Hour) + if recvTime := time.Since(beforeReceive); recvTime > time.Minute { + level.Info(logger).Log("msg", "conn.ReceiveWithTimeout connection was blocked for significant time", "duration", recvTime, "connection", fmt.Sprintf("%p", conn)) + } + // Pass the message back to ReadChannel. 
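The blocked-connection logging above boils down to timing the blocking receive and emitting a log line once it crosses a threshold. A minimal sketch of that pattern with go-kit's logger; the threshold and sleep are shortened for illustration, and the `go-kit/log` import path is an assumption rather than necessarily the one Fleet vendors:

```go
package main

import (
	"os"
	"time"

	kitlog "github.com/go-kit/log"
	"github.com/go-kit/log/level"
)

func main() {
	logger := kitlog.NewLogfmtLogger(os.Stderr)

	before := time.Now()
	time.Sleep(50 * time.Millisecond) // stand-in for conn.ReceiveWithTimeout(1 * time.Hour)
	if elapsed := time.Since(before); elapsed > 10*time.Millisecond {
		level.Info(logger).Log("msg", "receive was blocked for a significant time", "duration", elapsed)
	}
}
```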
if writeOrDone(ctx, outChan, msg) { return @@ -188,6 +195,7 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib wg.Wait() psc.Unsubscribe(pubSubName) //nolint:errcheck conn.Close() + level.Debug(logger).Log("msg", "proper close of Redis connection in ReadChannel", "connection", fmt.Sprintf("%p", conn)) }() return outChannel, nil diff --git a/server/service/campaigns_test.go b/server/service/campaigns_test.go index ff3c3f67f..0ba50729c 100644 --- a/server/service/campaigns_test.go +++ b/server/service/campaigns_test.go @@ -30,6 +30,14 @@ func (nopLiveQuery) QueryCompletedByHost(name string, hostID uint) error { return nil } +func (nopLiveQuery) CleanupInactiveQueries(ctx context.Context, inactiveCampaignIDs []uint) error { + return nil +} + +func (q nopLiveQuery) LoadActiveQueryNames() ([]string, error) { + return nil, nil +} + func TestLiveQueryAuth(t *testing.T) { ds := new(mock.Store) qr := pubsub.NewInmemQueryResults() diff --git a/server/utils.go b/server/utils.go index 3e4fe2c7c..76643ee4f 100644 --- a/server/utils.go +++ b/server/utils.go @@ -149,8 +149,8 @@ func Base64DecodePaddingAgnostic(s string) ([]byte, error) { // RemoveDuplicatesFromSlice returns a slice with all the duplicates removed from the input slice. func RemoveDuplicatesFromSlice[T comparable](slice []T) []T { // We are using the allKeys map as a set here - allKeys := make(map[T]struct{}) - var list []T + allKeys := make(map[T]struct{}, len(slice)) + list := make([]T, 0, len(slice)) for _, i := range slice { if _, exists := allKeys[i]; !exists {
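Finally, a tiny usage sketch for `server.RemoveDuplicatesFromSlice`, which `GetCompletedCampaigns` applies to its ID filter before batching; the helper preserves first-occurrence order:

```go
package main

import (
	"fmt"

	"github.com/fleetdm/fleet/v4/server"
)

func main() {
	// Same dedupe GetCompletedCampaigns performs before splitting the filter
	// into IN (...) batches.
	ids := []uint{234, 1, 1, 34455455453}
	fmt.Println(server.RemoveDuplicatesFromSlice(ids)) // [234 1 34455455453]
}
```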