diff --git a/changes/14800-new-live-query-endpoint b/changes/14800-new-live-query-endpoint new file mode 100644 index 000000000..706d3c291 --- /dev/null +++ b/changes/14800-new-live-query-endpoint @@ -0,0 +1 @@ +Added a new synchronous live query endpoint: POST /api/v1/fleet/queries/:id/run diff --git a/server/datastore/mysql/testing_utils.go b/server/datastore/mysql/testing_utils.go index f32d3c4ee..c39c935d0 100644 --- a/server/datastore/mysql/testing_utils.go +++ b/server/datastore/mysql/testing_utils.go @@ -285,6 +285,10 @@ func ExecAdhocSQL(tb testing.TB, ds *Datastore, fn func(q sqlx.ExtContext) error require.NoError(tb, err) } +func ExecAdhocSQLWithError(ds *Datastore, fn func(q sqlx.ExtContext) error) error { + return fn(ds.primary) +} + // TruncateTables truncates the specified tables, in order, using ds.writer. // Note that the order is typically not important because FK checks are // disabled while truncating. If no table is provided, all tables (except diff --git a/server/fleet/campaigns.go b/server/fleet/campaigns.go index aae607c0b..ff0b2e7ad 100644 --- a/server/fleet/campaigns.go +++ b/server/fleet/campaigns.go @@ -70,4 +70,5 @@ type QueryCampaignResult struct { QueryID uint `json:"query_id"` Error *string `json:"error,omitempty"` Results []QueryResult `json:"results"` + Err error `json:"-"` } diff --git a/server/service/handler.go b/server/service/handler.go index 8123d61ec..fa59f1d9e 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -404,7 +404,9 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC ue.GET("/api/_version_/fleet/spec/labels", getLabelSpecsEndpoint, nil) ue.GET("/api/_version_/fleet/spec/labels/{name}", getLabelSpecEndpoint, getGenericSpecRequest{}) - // This GET endpoint runs live queries synchronously (with a configured timeout). + // This endpoint runs live queries synchronously (with a configured timeout). + ue.POST("/api/_version_/fleet/queries/{id:[0-9]+}/run", runOneLiveQueryEndpoint, runOneLiveQueryRequest{}) + // Old endpoint, removed from docs. This GET endpoint runs live queries synchronously (with a configured timeout). ue.GET("/api/_version_/fleet/queries/run", runLiveQueryEndpoint, runLiveQueryRequest{}) // The following two POST APIs are the asynchronous way to run live queries. // The live queries are created with these two endpoints and their results can be queried via diff --git a/server/service/integration_live_queries_test.go b/server/service/integration_live_queries_test.go index 8b8320d32..a427a2250 100644 --- a/server/service/integration_live_queries_test.go +++ b/server/service/integration_live_queries_test.go @@ -4,11 +4,14 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "github.com/jmoiron/sqlx" + "math" "math/rand" "net/http" "sort" + "strings" "sync" "testing" "time" @@ -84,14 +87,15 @@ func (s *liveQueriesTestSuite) TearDownTest() { } func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { - test := func(savedQuery bool, hasStats bool) { + test := func(newEndpoint bool, savedQuery bool, hasStats bool) { t := s.T() host := s.hosts[0] + query := t.Name() + " select 1 from osquery;" q1, err := s.ds.NewQuery( context.Background(), &fleet.Query{ - Query: "select 1 from osquery;", + Query: query, Description: "desc1", Name: t.Name() + "query1", Logging: fleet.LoggingSnapshot, @@ -100,28 +104,54 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { ) require.NoError(t, err) - s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil) + s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): query}, nil) s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil) - s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil) + s.lq.On("RunQuery", mock.Anything, query, []uint{host.ID}).Return(nil) s.lq.On("StopQuery", mock.Anything).Return(nil) - liveQueryRequest := runLiveQueryRequest{ - QueryIDs: []uint{q1.ID}, - HostIDs: []uint{host.ID}, - } - liveQueryResp := runLiveQueryResponse{} - wg := sync.WaitGroup{} wg.Add(1) + + oneLiveQueryResp := runOneLiveQueryResponse{} + liveQueryResp := runLiveQueryResponse{} + if newEndpoint { + liveQueryRequest := runOneLiveQueryRequest{ + HostIDs: []uint{host.ID}, + } + go func() { + defer wg.Done() + s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), liveQueryRequest, http.StatusOK, &oneLiveQueryResp) + }() + } else { + liveQueryRequest := runLiveQueryRequest{ + QueryIDs: []uint{q1.ID}, + HostIDs: []uint{host.ID}, + } + go func() { + defer wg.Done() + s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) + }() + } + + // For loop, waiting for campaign to be created. + var cid string + cidChannel := make(chan string) go func() { - defer wg.Done() - s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) + for { + campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID) + require.NoError(t, err) + + if len(campaigns) == 1 && campaigns[0].Status == fleet.QueryRunning { + cidChannel <- fmt.Sprint(campaigns[0].ID) + return + } + } }() - - // Give the above call a couple of seconds to create the campaign - time.Sleep(2 * time.Second) - - cid := getCIDForQ(s, q1) + select { + case cid = <-cidChannel: + case <-time.After(5 * time.Second): + t.Error("Timeout: campaign not created/running for TestLiveQueriesRestOneHostOneQuery") + } var stats *fleet.Stats if hasStats { @@ -153,15 +183,62 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { wg.Wait() - require.Len(t, liveQueryResp.Results, 1) - assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount) - assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID) - require.Len(t, liveQueryResp.Results[0].Results[0].Rows, 1) - assert.Equal(t, "a", liveQueryResp.Results[0].Results[0].Rows[0]["col1"]) - assert.Equal(t, "b", liveQueryResp.Results[0].Results[0].Rows[0]["col2"]) + var result fleet.QueryResult + if newEndpoint { + assert.Equal(t, q1.ID, oneLiveQueryResp.QueryID) + assert.Equal(t, 1, oneLiveQueryResp.TargetedHostCount) + assert.Equal(t, 1, oneLiveQueryResp.RespondedHostCount) + require.Len(t, oneLiveQueryResp.Results, 1) + result = oneLiveQueryResp.Results[0] + } else { + require.Len(t, liveQueryResp.Results, 1) + assert.Equal(t, 1, liveQueryResp.Summary.TargetedHostCount) + assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount) + assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID) + require.Len(t, liveQueryResp.Results[0].Results, 1) + result = liveQueryResp.Results[0].Results[0] + } + assert.Equal(t, host.ID, result.HostID) + require.Len(t, result.Rows, 1) + assert.Equal(t, "a", result.Rows[0]["col1"]) + assert.Equal(t, "b", result.Rows[0]["col2"]) + + // For loop, waiting for activity feed to update, which happens after aggregated stats update. + var activity *fleet.ActivityTypeLiveQuery + activityUpdated := make(chan *fleet.ActivityTypeLiveQuery) + go func() { + for { + details := json.RawMessage{} + err := mysql.ExecAdhocSQLWithError( + s.ds, func(q sqlx.ExtContext) error { + return sqlx.GetContext( + context.Background(), q, &details, + `SELECT details FROM activities WHERE activity_type = 'live_query' ORDER BY id DESC LIMIT 1`, + ) + }, + ) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + t.Error("Error selecting from activity feed", err) + return + } + if err == nil { + act := fleet.ActivityTypeLiveQuery{} + err = json.Unmarshal(details, &act) + require.NoError(t, err) + if act.QuerySQL == q1.Query { + assert.Equal(t, act.TargetsCount, uint(1)) + activityUpdated <- &act + return + } + } + } + }() + select { + case activity = <-activityUpdated: + case <-time.After(5 * time.Second): + t.Error("Timeout: activity not created for TestLiveQueriesRestOneHostOneQuery") + } - // Allow time for aggregated stats and activity feed to update - time.Sleep(500 * time.Millisecond) aggStats, err := mysql.GetAggregatedStats(context.Background(), s.ds, fleet.AggregatedStatsTypeScheduledQuery, q1.ID) if savedQuery && hasStats { require.NoError(t, err) @@ -171,23 +248,9 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { assert.Equal(t, float64(1), *aggStats.UserTimeP50) assert.Equal(t, float64(1), *aggStats.UserTimeP95) } else { - require.ErrorAs(t, err, &sql.ErrNoRows) + require.ErrorIs(t, err, sql.ErrNoRows) } // Check activity - details := json.RawMessage{} - mysql.ExecAdhocSQL( - t, s.ds, func(q sqlx.ExtContext) error { - return sqlx.GetContext( - context.Background(), q, &details, - `SELECT details FROM activities WHERE activity_type = 'live_query' ORDER BY id DESC LIMIT 1`, - ) - }, - ) - activity := fleet.ActivityTypeLiveQuery{} - err = json.Unmarshal(details, &activity) - require.NoError(t, err) - assert.Equal(t, activity.TargetsCount, uint(1)) - assert.Equal(t, activity.QuerySQL, q1.Query) if savedQuery { assert.Equal(t, q1.Name, *activity.QueryName) if hasStats { @@ -199,9 +262,12 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { } } } - s.Run("not saved query", func() { test(false, true) }) - s.Run("saved query without stats", func() { test(true, false) }) - s.Run("saved query with stats", func() { test(true, true) }) + s.Run("not saved query (old)", func() { test(false, false, true) }) + s.Run("saved query without stats (old)", func() { test(false, true, false) }) + s.Run("saved query with stats (old)", func() { test(false, true, true) }) + s.Run("not saved query", func() { test(true, false, true) }) + s.Run("saved query without stats", func() { test(true, true, false) }) + s.Run("saved query with stats", func() { test(true, true, true) }) } func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() { @@ -548,23 +614,33 @@ func (s *liveQueriesTestSuite) TestLiveQueriesInvalidInputs() { liveQueryResp := runLiveQueryResponse{} s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp) - liveQueryRequest = runLiveQueryRequest{ - QueryIDs: []uint{q1.ID}, - HostIDs: []uint{}, - } - s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp) - liveQueryRequest = runLiveQueryRequest{ QueryIDs: nil, HostIDs: []uint{host.ID}, } s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp) + // No hosts + liveQueryRequest = runLiveQueryRequest{ + QueryIDs: []uint{q1.ID}, + HostIDs: []uint{}, + } + s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp) + oneLiveQueryRequest := runOneLiveQueryRequest{ + HostIDs: []uint{}, + } + oneLiveQueryResp := runOneLiveQueryResponse{} + s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp) + liveQueryRequest = runLiveQueryRequest{ QueryIDs: []uint{q1.ID}, HostIDs: nil, } s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp) + oneLiveQueryRequest = runOneLiveQueryRequest{ + HostIDs: nil, + } + s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp) } // TestLiveQueriesFailsToAuthorize when an observer tries to run a live query @@ -596,6 +672,11 @@ func (s *liveQueriesTestSuite) TestLiveQueriesFailsToAuthorize() { s.token = originalToken }() s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusForbidden, &liveQueryResp) + oneLiveQueryRequest := runOneLiveQueryRequest{ + HostIDs: []uint{host.ID}, + } + oneLiveQueryResp := runOneLiveQueryResponse{} + s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusForbidden, &oneLiveQueryResp) } func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() { @@ -613,87 +694,169 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() { assert.Equal(t, 0, liveQueryResp.Summary.RespondedHostCount) require.NotNil(t, liveQueryResp.Results[0].Error) assert.Contains(t, *liveQueryResp.Results[0].Error, "Query 999 was not found in the datastore") + + oneLiveQueryRequest := runOneLiveQueryRequest{ + HostIDs: []uint{888}, + } + oneLiveQueryResp := runOneLiveQueryResponse{} + s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", 999), oneLiveQueryRequest, http.StatusNotFound, &oneLiveQueryResp) + assert.Equal(t, 0, oneLiveQueryResp.RespondedHostCount) + } -func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() { +func (s *liveQueriesTestSuite) TestLiveQueriesRestInvalidHost() { t := s.T() - h1 := s.hosts[0] - h2 := s.hosts[1] - - q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{ - Query: "select 1 from osquery;", - Description: "desc1", - Name: t.Name() + "query1", - Logging: fleet.LoggingSnapshot, - }) + q1, err := s.ds.NewQuery( + context.Background(), &fleet.Query{ + Query: "select 1 from osquery;", + Description: "desc1", + Name: t.Name() + "query1", + Logging: fleet.LoggingSnapshot, + }, + ) require.NoError(t, err) - s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil) - s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil) - s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil) - s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil) - s.lq.On("StopQuery", mock.Anything).Return(nil) - liveQueryRequest := runLiveQueryRequest{ QueryIDs: []uint{q1.ID}, - HostIDs: []uint{h1.ID, h2.ID}, + HostIDs: []uint{math.MaxUint}, } liveQueryResp := runLiveQueryResponse{} - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - defer wg.Done() - s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) - }() - - // Give the above call a couple of seconds to create the campaign - time.Sleep(2 * time.Second) - cid1 := getCIDForQ(s, q1) - distributedReq := submitDistributedQueryResultsRequestShim{ - NodeKey: *h1.NodeKey, - Results: map[string]json.RawMessage{ - hostDistributedQueryPrefix + cid1: json.RawMessage(`[{"col1": "a", "col2": "b"}]`), - }, - Statuses: map[string]interface{}{ - hostDistributedQueryPrefix + cid1: "0", - }, - Messages: map[string]string{ - hostDistributedQueryPrefix + cid1: "some msg", - }, - } - distributedResp := submitDistributedQueryResultsResponse{} - s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp) - - distributedReq = submitDistributedQueryResultsRequestShim{ - NodeKey: *h2.NodeKey, - Results: map[string]json.RawMessage{ - hostDistributedQueryPrefix + cid1: json.RawMessage(`""`), - }, - Statuses: map[string]interface{}{ - hostDistributedQueryPrefix + cid1: 123, - }, - Messages: map[string]string{ - hostDistributedQueryPrefix + cid1: "some error!", - }, - } - distributedResp = submitDistributedQueryResultsResponse{} - s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp) - - wg.Wait() + s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) require.Len(t, liveQueryResp.Results, 1) - assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount) + assert.Equal(t, 0, liveQueryResp.Summary.RespondedHostCount) + assert.Len(t, liveQueryResp.Results[0].Results, 0) + assert.True(t, strings.Contains(*liveQueryResp.Results[0].Error, "no hosts targeted")) - result := liveQueryResp.Results[0] - require.Len(t, result.Results, 2) - require.Len(t, result.Results[0].Rows, 1) - assert.Equal(t, "a", result.Results[0].Rows[0]["col1"]) - assert.Equal(t, "b", result.Results[0].Rows[0]["col2"]) - require.Len(t, result.Results[1].Rows, 0) - require.NotNil(t, result.Results[1].Error) - assert.Equal(t, "some error!", *result.Results[1].Error) + oneLiveQueryRequest := runOneLiveQueryRequest{ + HostIDs: []uint{math.MaxUint}, + } + oneLiveQueryResp := runOneLiveQueryResponse{} + s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp) + +} + +func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() { + test := func(newEndpoint bool) { + t := s.T() + + h1 := s.hosts[0] + h2 := s.hosts[1] + + q1, err := s.ds.NewQuery( + context.Background(), &fleet.Query{ + Query: "select 1 from osquery;", + Description: "desc1", + Name: t.Name() + "query1", + Logging: fleet.LoggingSnapshot, + }, + ) + require.NoError(t, err) + + s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil) + s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil) + s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil) + s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil) + s.lq.On("StopQuery", mock.Anything).Return(nil) + + wg := sync.WaitGroup{} + wg.Add(1) + oneLiveQueryResp := runOneLiveQueryResponse{} + liveQueryResp := runLiveQueryResponse{} + if newEndpoint { + liveQueryRequest := runOneLiveQueryRequest{ + HostIDs: []uint{h1.ID, h2.ID}, + } + go func() { + defer wg.Done() + s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), liveQueryRequest, http.StatusOK, &oneLiveQueryResp) + }() + } else { + liveQueryRequest := runLiveQueryRequest{ + QueryIDs: []uint{q1.ID}, + HostIDs: []uint{h1.ID, h2.ID}, + } + go func() { + defer wg.Done() + s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) + }() + } + + // For loop waiting to create the campaign + var cid string + cidChannel := make(chan string) + go func() { + for { + campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID) + require.NoError(t, err) + + if len(campaigns) == 1 && campaigns[0].Status == fleet.QueryRunning { + cidChannel <- fmt.Sprint(campaigns[0].ID) + return + } + } + }() + select { + case cid = <-cidChannel: + case <-time.After(5 * time.Second): + t.Error("Timeout: campaign not created/running for TestLiveQueriesRestFailsOnSomeHost") + } + + distributedReq := submitDistributedQueryResultsRequestShim{ + NodeKey: *h1.NodeKey, + Results: map[string]json.RawMessage{ + hostDistributedQueryPrefix + cid: json.RawMessage(`[{"col1": "a", "col2": "b"}]`), + }, + Statuses: map[string]interface{}{ + hostDistributedQueryPrefix + cid: "0", + }, + Messages: map[string]string{ + hostDistributedQueryPrefix + cid: "some msg", + }, + } + distributedResp := submitDistributedQueryResultsResponse{} + s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp) + + distributedReq = submitDistributedQueryResultsRequestShim{ + NodeKey: *h2.NodeKey, + Results: map[string]json.RawMessage{ + hostDistributedQueryPrefix + cid: json.RawMessage(`""`), + }, + Statuses: map[string]interface{}{ + hostDistributedQueryPrefix + cid: 123, + }, + Messages: map[string]string{ + hostDistributedQueryPrefix + cid: "some error!", + }, + } + distributedResp = submitDistributedQueryResultsResponse{} + s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp) + + wg.Wait() + + var qResults []fleet.QueryResult + if newEndpoint { + assert.Equal(t, q1.ID, oneLiveQueryResp.QueryID) + assert.Equal(t, 2, oneLiveQueryResp.TargetedHostCount) + assert.Equal(t, 2, oneLiveQueryResp.RespondedHostCount) + qResults = oneLiveQueryResp.Results + } else { + require.Len(t, liveQueryResp.Results, 1) + assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount) + qResults = liveQueryResp.Results[0].Results + } + require.Len(t, qResults, 2) + require.Len(t, qResults[0].Rows, 1) + assert.Equal(t, "a", qResults[0].Rows[0]["col1"]) + assert.Equal(t, "b", qResults[0].Rows[0]["col2"]) + require.Len(t, qResults[1].Rows, 0) + require.NotNil(t, qResults[1].Error) + assert.Equal(t, "some error!", *qResults[1].Error) + } + s.Run("old endpoint", func() { test(false) }) + s.Run("new endpoint", func() { test(true) }) } func (s *liveQueriesTestSuite) TestCreateDistributedQueryCampaign() { diff --git a/server/service/live_queries.go b/server/service/live_queries.go index a62b0e282..6bb10308e 100644 --- a/server/service/live_queries.go +++ b/server/service/live_queries.go @@ -22,6 +22,11 @@ type runLiveQueryRequest struct { HostIDs []uint `json:"host_ids"` } +type runOneLiveQueryRequest struct { + QueryID uint `url:"id"` + HostIDs []uint `json:"host_ids"` +} + type summaryPayload struct { TargetedHostCount int `json:"targeted_host_count"` RespondedHostCount int `json:"responded_host_count"` @@ -36,9 +41,72 @@ type runLiveQueryResponse struct { func (r runLiveQueryResponse) error() error { return r.Err } +type runOneLiveQueryResponse struct { + QueryID uint `json:"query_id"` + TargetedHostCount int `json:"targeted_host_count"` + RespondedHostCount int `json:"responded_host_count"` + Results []fleet.QueryResult `json:"results"` + Err error `json:"error,omitempty"` +} + +func (r runOneLiveQueryResponse) error() error { return r.Err } + +func runOneLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { + req := request.(*runOneLiveQueryRequest) + + // Only allow a host to be specified once in HostIDs + hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs) + + campaignResults, respondedHostCount, err := runLiveQuery(ctx, svc, []uint{req.QueryID}, hostIDs) + if err != nil { + return nil, err + } + //goland:noinspection GoPreferNilSlice -- use an empty slice here so that API returns an empty array if there are no results + queryResults := []fleet.QueryResult{} + if len(campaignResults) > 0 { + if campaignResults[0].Err != nil { + return nil, campaignResults[0].Err + } + if campaignResults[0].Results != nil { + queryResults = campaignResults[0].Results + } + } + + res := runOneLiveQueryResponse{ + QueryID: req.QueryID, + TargetedHostCount: len(hostIDs), + RespondedHostCount: respondedHostCount, + Results: queryResults, + } + return res, nil +} + func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { req := request.(*runLiveQueryRequest) + // Only allow a query to be specified once + queryIDs := server.RemoveDuplicatesFromSlice(req.QueryIDs) + // Only allow a host to be specified once in HostIDs + hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs) + + queryResults, respondedHostCount, err := runLiveQuery(ctx, svc, queryIDs, hostIDs) + if err != nil { + return nil, err + } + + res := runLiveQueryResponse{ + Summary: summaryPayload{ + TargetedHostCount: len(hostIDs), + RespondedHostCount: respondedHostCount, + }, + Results: queryResults, + } + return res, nil +} + +func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostIDs []uint) ( + []fleet.QueryCampaignResult, int, error, +) { // The period used here should always be less than the request timeout for any load // balancer/proxy between Fleet and the API client. period := os.Getenv("FLEET_LIVE_QUERY_REST_PERIOD") @@ -51,20 +119,9 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se logging.WithExtras(ctx, "live_query_rest_period_err", err) } - // Only allow a query to be specified once - req.QueryIDs = server.RemoveDuplicatesFromSlice(req.QueryIDs) - // Only allow a host to be specified once in HostIDs - req.HostIDs = server.RemoveDuplicatesFromSlice(req.HostIDs) - res := runLiveQueryResponse{ - Summary: summaryPayload{ - TargetedHostCount: len(req.HostIDs), - RespondedHostCount: 0, - }, - } - - queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, req.QueryIDs, req.HostIDs, duration) + queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, queryIDs, hostIDs, duration) if err != nil { - return nil, err + return nil, 0, err } // Check if all query results were forbidden due to lack of authorization. allResultsForbidden := len(queryResults) > 0 && respondedHostCount == 0 @@ -77,12 +134,11 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se } } if allResultsForbidden { - return nil, authz.ForbiddenWithInternal("All Live Query results were forbidden.", authz.UserFromContext(ctx), nil, nil) + return nil, 0, authz.ForbiddenWithInternal( + "All Live Query results were forbidden.", authz.UserFromContext(ctx), nil, nil, + ) } - res.Results = queryResults - res.Summary.RespondedHostCount = respondedHostCount - - return res, nil + return queryResults, respondedHostCount, nil } func (svc *Service) RunLiveQueryDeadline( @@ -106,13 +162,13 @@ func (svc *Service) RunLiveQueryDeadline( defer wg.Done() campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: hostIDs}) if err != nil { - resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())} + resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err} return } readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign) if err != nil { - resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())} + resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err} return } defer cancelFunc() @@ -120,7 +176,7 @@ func (svc *Service) RunLiveQueryDeadline( defer func() { err := svc.CompleteCampaign(ctx, campaign) if err != nil { - resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())} + resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err} } }() @@ -168,12 +224,14 @@ func (svc *Service) RunLiveQueryDeadline( svc.updateStats(ctx, campaign.QueryID, svc.logger, &perfStatsTracker, false) } } + if len(results) == len(hostIDs) { + break loop + } case error: - resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(res.Error())} + resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(res.Error()), Err: res} return } case <-timeout: - // This is the normal path for returning results. break loop case <-ctx.Done(): break loop diff --git a/server/service/service_campaign_test.go b/server/service/service_campaign_test.go index cf13fc17a..e03cf6e55 100644 --- a/server/service/service_campaign_test.go +++ b/server/service/service_campaign_test.go @@ -324,6 +324,7 @@ func TestCalculateOutputSize(t *testing.T) { result.Rows = append(result.Rows, map[string]string{}) result.Rows = append(result.Rows, map[string]string{"a": "b", "a1": "b1"}) result.Rows = append(result.Rows, map[string]string{"c": "d"}) + result.Stats = &fleet.Stats{} return &result } t.Run( diff --git a/server/service/service_campaigns.go b/server/service/service_campaigns.go index f8f813d42..fe1fea794 100644 --- a/server/service/service_campaigns.go +++ b/server/service/service_campaigns.go @@ -278,7 +278,8 @@ func (svc Service) addLiveQueryActivity( func calculateOutputSize(perfStatsTracker *statsTracker, res *fleet.DistributedQueryResult) uint64 { outputSize := uint64(0) - if perfStatsTracker.saveStats { + // We only need the output size if other stats are present. + if perfStatsTracker.saveStats && res.Stats != nil { for _, row := range res.Rows { if row == nil { continue