From bb30ffb21f080a3e828d2c278c3dc3c773b2c822 Mon Sep 17 00:00:00 2001 From: Victor Lyuboslavsky Date: Mon, 18 Dec 2023 09:31:07 -0600 Subject: [PATCH] Added live query stats to activity. (#15682) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 📺 Loom: https://www.loom.com/share/474ac220dd054c389ebed6ba13f943b1?sid=70c2d94e-9fd9-40f5-af73-13e6d2e4c90a Final piece for #467 -- Added live query stats to activity feed. # Checklist for submitter If some of the following don't apply, delete the relevant line. - [x] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [x] Added/updated tests - [x] Manual QA for all new/changed functionality --- changes/467-live-query-stats | 3 +- cmd/fleetctl/query_test.go | 10 ++--- server/fleet/activities.go | 7 ++-- server/service/campaigns.go | 16 +------- server/service/campaigns_test.go | 24 ------------ .../service/integration_live_queries_test.go | 29 +++++++++++++- server/service/live_queries.go | 18 +++++---- server/service/osquery_test.go | 4 -- server/service/service_campaigns.go | 39 +++++++++++++++++-- 9 files changed, 84 insertions(+), 66 deletions(-) diff --git a/changes/467-live-query-stats b/changes/467-live-query-stats index 075fdd519..765f72d75 100644 --- a/changes/467-live-query-stats +++ b/changes/467-live-query-stats @@ -1,4 +1,5 @@ Live Queries now collect stats. Stats are collected for saved queries present in the Queries tab. - After running a live query, user will see updated stats in Queries tab. - Query stats on Host Details page will no longer be cleared after host/agent reboots. -- Query stats are now deleted when query is deleted. \ No newline at end of file +- Query stats are now deleted when query is deleted. +- Activity is now added after the query runs, and includes performance stats for saved queries. diff --git a/cmd/fleetctl/query_test.go b/cmd/fleetctl/query_test.go index ce2188ee7..cf0160364 100644 --- a/cmd/fleetctl/query_test.go +++ b/cmd/fleetctl/query_test.go @@ -76,9 +76,6 @@ func TestSavedLiveQuery(t *testing.T) { ds.CountHostsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) { return fleet.TargetMetrics{TotalHosts: 1, OnlineHosts: 1}, nil } - ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error { - return nil - } lq.On("QueriesForHost", uint(1)).Return( map[string]string{ @@ -149,7 +146,9 @@ func TestSavedLiveQuery(t *testing.T) { expected := `{"host":"somehostname","rows":[{"bing":"fds","host_display_name":"somehostname","host_hostname":"somehostname"}]} ` - assert.Equal(t, expected, runAppForTest(t, []string{"query", "--hosts", "1234", "--query-name", "saved-query"})) + // Note: runAppForTest never closes the WebSocket connection and does not exit, + // so we are unable to see the activity data that is written after WebSocket disconnects. + assert.Equal(t, expected, runAppForTest(t, []string{"query", "--hosts", "1234", "--query-name", queryName})) // We need to use waitGroups to detect whether Database functions were called because this is an asynchronous test which will flag data races otherwise. c := make(chan struct{}) @@ -225,9 +224,6 @@ func TestAdHocLiveQuery(t *testing.T) { ) (fleet.TargetMetrics, error) { return fleet.TargetMetrics{TotalHosts: 1, OnlineHosts: 1}, nil } - ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error { - return nil - } lq.On("QueriesForHost", uint(1)).Return( map[string]string{ diff --git a/server/fleet/activities.go b/server/fleet/activities.go index 5c5fa79e4..2bdcfd91e 100644 --- a/server/fleet/activities.go +++ b/server/fleet/activities.go @@ -458,9 +458,10 @@ func (a ActivityTypeEditedAgentOptions) Documentation() (activity string, detail } type ActivityTypeLiveQuery struct { - TargetsCount uint `json:"targets_count"` - QuerySQL string `json:"query_sql"` - QueryName *string `json:"query_name,omitempty"` + TargetsCount uint `json:"targets_count"` + QuerySQL string `json:"query_sql"` + QueryName *string `json:"query_name,omitempty"` + Stats *AggregatedStats `json:"stats,omitempty"` } func (a ActivityTypeLiveQuery) ActivityName() string { diff --git a/server/service/campaigns.go b/server/service/campaigns.go index d98d35300..bb272ea7b 100644 --- a/server/service/campaigns.go +++ b/server/service/campaigns.go @@ -7,7 +7,6 @@ import ( "strings" "time" - "github.com/fleetdm/fleet/v4/server/authz" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/contexts/logging" "github.com/fleetdm/fleet/v4/server/contexts/viewer" @@ -162,25 +161,12 @@ func (svc *Service) NewDistributedQueryCampaign(ctx context.Context, queryString return nil, ctxerr.Wrap(ctx, err, "run query") } + // Metrics are used for total hosts targeted for the activity feed. campaign.Metrics, err = svc.ds.CountHostsInTargets(ctx, filter, targets, time.Now()) if err != nil { return nil, ctxerr.Wrap(ctx, err, "counting hosts") } - activityData := fleet.ActivityTypeLiveQuery{ - TargetsCount: campaign.Metrics.TotalHosts, - QuerySQL: query.Query, - } - if queryID != nil { - activityData.QueryName = &query.Name - } - if err := svc.ds.NewActivity( - ctx, - authz.UserFromContext(ctx), - activityData, - ); err != nil { - return nil, ctxerr.Wrap(ctx, err, "create activity for campaign creation") - } return campaign, nil } diff --git a/server/service/campaigns_test.go b/server/service/campaigns_test.go index 1942cbeff..e0234645d 100644 --- a/server/service/campaigns_test.go +++ b/server/service/campaigns_test.go @@ -5,8 +5,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mock" @@ -85,18 +83,6 @@ func TestLiveQueryAuth(t *testing.T) { ds.CountHostsInTargetsFunc = func(ctx context.Context, filters fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) { return fleet.TargetMetrics{}, nil } - var queryName, querySQL string - ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error { - act := activity.(fleet.ActivityTypeLiveQuery) - name := act.QueryName - if name == nil { - queryName = "" - } else { - queryName = *name - } - querySQL = act.QuerySQL - return nil - } ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { if id == 1 { return query1ObsCanRun, nil @@ -218,25 +204,15 @@ func TestLiveQueryAuth(t *testing.T) { } _, err := svc.NewDistributedQueryCampaign(ctx, query1ObsCanRun.Query, nil, fleet.HostTargets{TeamIDs: tms}) checkAuthErr(t, tt.shouldFailRunNew, err) - checkActivity := func(t testing.TB, err error, expectName, expectSQL string) { - if err != nil { - return - } - require.Equal(t, expectName, queryName) - require.Equal(t, expectSQL, querySQL) - } - checkActivity(t, err, "", query1ObsCanRun.Query) if tt.teamID != nil { tms = []uint{*tt.teamID} } _, err = svc.NewDistributedQueryCampaign(ctx, query1ObsCanRun.Query, ptr.Uint(query1ObsCanRun.ID), fleet.HostTargets{TeamIDs: tms}) checkAuthErr(t, tt.shouldFailRunObsCan, err) - checkActivity(t, err, query1ObsCanRun.Name, query1ObsCanRun.Query) _, err = svc.NewDistributedQueryCampaign(ctx, query2ObsCannotRun.Query, ptr.Uint(query2ObsCannotRun.ID), fleet.HostTargets{TeamIDs: tms}) checkAuthErr(t, tt.shouldFailRunObsCannot, err) - checkActivity(t, err, query2ObsCannotRun.Name, query2ObsCannotRun.Query) // tests with a team target cannot run the "ByNames" calls, as there's no way // to pass a team target with this call. diff --git a/server/service/integration_live_queries_test.go b/server/service/integration_live_queries_test.go index 9a50bb310..8b8320d32 100644 --- a/server/service/integration_live_queries_test.go +++ b/server/service/integration_live_queries_test.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "github.com/jmoiron/sqlx" "math/rand" "net/http" "sort" @@ -159,7 +160,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { assert.Equal(t, "a", liveQueryResp.Results[0].Results[0].Rows[0]["col1"]) assert.Equal(t, "b", liveQueryResp.Results[0].Results[0].Rows[0]["col2"]) - // Allow time for aggregated stats to update + // Allow time for aggregated stats and activity feed to update time.Sleep(500 * time.Millisecond) aggStats, err := mysql.GetAggregatedStats(context.Background(), s.ds, fleet.AggregatedStatsTypeScheduledQuery, q1.ID) if savedQuery && hasStats { @@ -172,7 +173,31 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { } else { require.ErrorAs(t, err, &sql.ErrNoRows) } - + // Check activity + details := json.RawMessage{} + mysql.ExecAdhocSQL( + t, s.ds, func(q sqlx.ExtContext) error { + return sqlx.GetContext( + context.Background(), q, &details, + `SELECT details FROM activities WHERE activity_type = 'live_query' ORDER BY id DESC LIMIT 1`, + ) + }, + ) + activity := fleet.ActivityTypeLiveQuery{} + err = json.Unmarshal(details, &activity) + require.NoError(t, err) + assert.Equal(t, activity.TargetsCount, uint(1)) + assert.Equal(t, activity.QuerySQL, q1.Query) + if savedQuery { + assert.Equal(t, q1.Name, *activity.QueryName) + if hasStats { + assert.Equal(t, 1, int(*activity.Stats.TotalExecutions)) + assert.Equal(t, float64(2), *activity.Stats.SystemTimeP50) + assert.Equal(t, float64(2), *activity.Stats.SystemTimeP95) + assert.Equal(t, float64(1), *activity.Stats.UserTimeP50) + assert.Equal(t, float64(1), *activity.Stats.UserTimeP95) + } + } } s.Run("not saved query", func() { test(false, true) }) s.Run("saved query without stats", func() { test(true, false) }) diff --git a/server/service/live_queries.go b/server/service/live_queries.go index 683c4eb77..a62b0e282 100644 --- a/server/service/live_queries.go +++ b/server/service/live_queries.go @@ -137,6 +137,16 @@ func (svc *Service) RunLiveQueryDeadline( level.Error(svc.logger).Log("msg", "error checking saved query", "query.id", campaign.QueryID, "err", err) perfStatsTracker.saveStats = false } + // to make sure stats and activity DB operations don't get killed after we return results. + ctxWithoutCancel := context.WithoutCancel(ctx) + totalHosts := campaign.Metrics.TotalHosts + // We update aggregated stats and activity at the end asynchronously. + defer func() { + go func() { + svc.updateStats(ctxWithoutCancel, queryID, svc.logger, &perfStatsTracker, true) + svc.addLiveQueryActivity(ctxWithoutCancel, totalHosts, queryID, svc.logger) + }() + }() loop: for { select { @@ -163,13 +173,7 @@ func (svc *Service) RunLiveQueryDeadline( return } case <-timeout: - // This is the normal path for returning results. We only update aggregated stats here, and without blocking. - if perfStatsTracker.saveStats { - ctxWithoutCancel := context.WithoutCancel(ctx) // to make sure stats DB operations don't get killed after we return results. - go func() { - svc.updateStats(ctxWithoutCancel, campaign.QueryID, svc.logger, &perfStatsTracker, true) - }() - } + // This is the normal path for returning results. break loop case <-ctx.Done(): break loop diff --git a/server/service/osquery_test.go b/server/service/osquery_test.go index 3e029bf89..0fed22ca8 100644 --- a/server/service/osquery_test.go +++ b/server/service/osquery_test.go @@ -2010,13 +2010,9 @@ func TestNewDistributedQueryCampaign(t *testing.T) { }, }) q := "select year, month, day, hour, minutes, seconds from time" - ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error { - return nil - } campaign, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) require.NoError(t, err) assert.Equal(t, gotQuery.ID, gotCampaign.QueryID) - assert.True(t, ds.NewActivityFuncInvoked) assert.Equal(t, []*fleet.DistributedQueryCampaignTarget{ { Type: fleet.TargetHost, diff --git a/server/service/service_campaigns.go b/server/service/service_campaigns.go index 6f8949802..f8f813d42 100644 --- a/server/service/service_campaigns.go +++ b/server/service/service_campaigns.go @@ -183,6 +183,13 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co level.Error(logger).Log("msg", "error checking saved query", "query.id", campaign.QueryID, "err", err) perfStatsTracker.saveStats = false } + // We aggregate stats and add activity at the end. Using context without cancel for precaution. + queryID := campaign.QueryID + ctxWithoutCancel := context.WithoutCancel(ctx) + defer func() { + svc.updateStats(ctxWithoutCancel, queryID, logger, &perfStatsTracker, true) + svc.addLiveQueryActivity(ctxWithoutCancel, lastTotals.Total, queryID, logger) + }() // Loop, pushing updates to results and expected totals for { @@ -207,7 +214,6 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co } } if ctxerr.Cause(err) == sockjs.ErrSessionNotOpen { - svc.updateStats(ctx, campaign.QueryID, logger, &perfStatsTracker, true) // return and stop sending the query if the session was closed // by the client return @@ -225,7 +231,6 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co case <-ticker.C: if conn.GetSessionState() == sockjs.SessionClosed { - svc.updateStats(ctx, campaign.QueryID, logger, &perfStatsTracker, true) // return and stop sending the query if the session was closed // by the client return @@ -233,16 +238,44 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co // Update status if err := updateStatus(); err != nil { level.Error(logger).Log("msg", "error updating status", "err", err) - svc.updateStats(ctx, campaign.QueryID, logger, &perfStatsTracker, true) return } if status.ActualResults == status.ExpectedResults { + // We update stats when all expected results come in. + // The WebSockets connection can remain open indefinitely, so we make sure we update the stats at this critical point. svc.updateStats(ctx, campaign.QueryID, logger, &perfStatsTracker, true) } } } } +// addLiveQueryActivity adds live query activity to the activity feed, including the updated aggregated stats +func (svc Service) addLiveQueryActivity( + ctx context.Context, targetsCount uint, queryID uint, logger log.Logger, +) { + activityData := fleet.ActivityTypeLiveQuery{ + TargetsCount: targetsCount, + } + // Query returns SQL, name, and aggregated stats + q, err := svc.ds.Query(ctx, queryID) + if err != nil { + level.Error(logger).Log("msg", "error getting query", "id", queryID, "err", err) + } else { + activityData.QuerySQL = q.Query + if q.Saved { + activityData.QueryName = &q.Name + activityData.Stats = &q.AggregatedStats + } + } + if err := svc.ds.NewActivity( + ctx, + authz.UserFromContext(ctx), + activityData, + ); err != nil { + level.Error(logger).Log("msg", "error creating activity for live query", "err", err) + } +} + func calculateOutputSize(perfStatsTracker *statsTracker, res *fleet.DistributedQueryResult) uint64 { outputSize := uint64(0) if perfStatsTracker.saveStats {