mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Added a new synchronous live query endpoint: POST /api/v1/fleet/queries/:id/run (#15860)
Added a new synchronous live query endpoint: POST /api/v1/fleet/queries/:id/run #14800 All relevant integration tests have been updated to work with the old endpoint and new endpoint. # Checklist for submitter If some of the following don't apply, delete the relevant line. <!-- Note that API documentation changes are now addressed by the product design team. --> - [x] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [x] Added/updated tests - [x] Manual QA for all new/changed functionality
This commit is contained in:
parent
348bcd4263
commit
912abd211d
1
changes/14800-new-live-query-endpoint
Normal file
1
changes/14800-new-live-query-endpoint
Normal file
@ -0,0 +1 @@
|
||||
Added a new synchronous live query endpoint: POST /api/v1/fleet/queries/:id/run
|
@ -285,6 +285,10 @@ func ExecAdhocSQL(tb testing.TB, ds *Datastore, fn func(q sqlx.ExtContext) error
|
||||
require.NoError(tb, err)
|
||||
}
|
||||
|
||||
func ExecAdhocSQLWithError(ds *Datastore, fn func(q sqlx.ExtContext) error) error {
|
||||
return fn(ds.primary)
|
||||
}
|
||||
|
||||
// TruncateTables truncates the specified tables, in order, using ds.writer.
|
||||
// Note that the order is typically not important because FK checks are
|
||||
// disabled while truncating. If no table is provided, all tables (except
|
||||
|
@ -70,4 +70,5 @@ type QueryCampaignResult struct {
|
||||
QueryID uint `json:"query_id"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
Results []QueryResult `json:"results"`
|
||||
Err error `json:"-"`
|
||||
}
|
||||
|
@ -404,7 +404,9 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
|
||||
ue.GET("/api/_version_/fleet/spec/labels", getLabelSpecsEndpoint, nil)
|
||||
ue.GET("/api/_version_/fleet/spec/labels/{name}", getLabelSpecEndpoint, getGenericSpecRequest{})
|
||||
|
||||
// This GET endpoint runs live queries synchronously (with a configured timeout).
|
||||
// This endpoint runs live queries synchronously (with a configured timeout).
|
||||
ue.POST("/api/_version_/fleet/queries/{id:[0-9]+}/run", runOneLiveQueryEndpoint, runOneLiveQueryRequest{})
|
||||
// Old endpoint, removed from docs. This GET endpoint runs live queries synchronously (with a configured timeout).
|
||||
ue.GET("/api/_version_/fleet/queries/run", runLiveQueryEndpoint, runLiveQueryRequest{})
|
||||
// The following two POST APIs are the asynchronous way to run live queries.
|
||||
// The live queries are created with these two endpoints and their results can be queried via
|
||||
|
@ -4,11 +4,14 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@ -84,14 +87,15 @@ func (s *liveQueriesTestSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
||||
test := func(savedQuery bool, hasStats bool) {
|
||||
test := func(newEndpoint bool, savedQuery bool, hasStats bool) {
|
||||
t := s.T()
|
||||
|
||||
host := s.hosts[0]
|
||||
|
||||
query := t.Name() + " select 1 from osquery;"
|
||||
q1, err := s.ds.NewQuery(
|
||||
context.Background(), &fleet.Query{
|
||||
Query: "select 1 from osquery;",
|
||||
Query: query,
|
||||
Description: "desc1",
|
||||
Name: t.Name() + "query1",
|
||||
Logging: fleet.LoggingSnapshot,
|
||||
@ -100,28 +104,54 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
|
||||
s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): query}, nil)
|
||||
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
||||
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil)
|
||||
s.lq.On("RunQuery", mock.Anything, query, []uint{host.ID}).Return(nil)
|
||||
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
||||
|
||||
liveQueryRequest := runLiveQueryRequest{
|
||||
QueryIDs: []uint{q1.ID},
|
||||
HostIDs: []uint{host.ID},
|
||||
}
|
||||
liveQueryResp := runLiveQueryResponse{}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
oneLiveQueryResp := runOneLiveQueryResponse{}
|
||||
liveQueryResp := runLiveQueryResponse{}
|
||||
if newEndpoint {
|
||||
liveQueryRequest := runOneLiveQueryRequest{
|
||||
HostIDs: []uint{host.ID},
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), liveQueryRequest, http.StatusOK, &oneLiveQueryResp)
|
||||
}()
|
||||
} else {
|
||||
liveQueryRequest := runLiveQueryRequest{
|
||||
QueryIDs: []uint{q1.ID},
|
||||
HostIDs: []uint{host.ID},
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||
}()
|
||||
}
|
||||
|
||||
// For loop, waiting for campaign to be created.
|
||||
var cid string
|
||||
cidChannel := make(chan string)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||
for {
|
||||
campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(campaigns) == 1 && campaigns[0].Status == fleet.QueryRunning {
|
||||
cidChannel <- fmt.Sprint(campaigns[0].ID)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Give the above call a couple of seconds to create the campaign
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
cid := getCIDForQ(s, q1)
|
||||
select {
|
||||
case cid = <-cidChannel:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("Timeout: campaign not created/running for TestLiveQueriesRestOneHostOneQuery")
|
||||
}
|
||||
|
||||
var stats *fleet.Stats
|
||||
if hasStats {
|
||||
@ -153,15 +183,62 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
||||
|
||||
wg.Wait()
|
||||
|
||||
require.Len(t, liveQueryResp.Results, 1)
|
||||
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
|
||||
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
|
||||
require.Len(t, liveQueryResp.Results[0].Results[0].Rows, 1)
|
||||
assert.Equal(t, "a", liveQueryResp.Results[0].Results[0].Rows[0]["col1"])
|
||||
assert.Equal(t, "b", liveQueryResp.Results[0].Results[0].Rows[0]["col2"])
|
||||
var result fleet.QueryResult
|
||||
if newEndpoint {
|
||||
assert.Equal(t, q1.ID, oneLiveQueryResp.QueryID)
|
||||
assert.Equal(t, 1, oneLiveQueryResp.TargetedHostCount)
|
||||
assert.Equal(t, 1, oneLiveQueryResp.RespondedHostCount)
|
||||
require.Len(t, oneLiveQueryResp.Results, 1)
|
||||
result = oneLiveQueryResp.Results[0]
|
||||
} else {
|
||||
require.Len(t, liveQueryResp.Results, 1)
|
||||
assert.Equal(t, 1, liveQueryResp.Summary.TargetedHostCount)
|
||||
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
|
||||
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
|
||||
require.Len(t, liveQueryResp.Results[0].Results, 1)
|
||||
result = liveQueryResp.Results[0].Results[0]
|
||||
}
|
||||
assert.Equal(t, host.ID, result.HostID)
|
||||
require.Len(t, result.Rows, 1)
|
||||
assert.Equal(t, "a", result.Rows[0]["col1"])
|
||||
assert.Equal(t, "b", result.Rows[0]["col2"])
|
||||
|
||||
// For loop, waiting for activity feed to update, which happens after aggregated stats update.
|
||||
var activity *fleet.ActivityTypeLiveQuery
|
||||
activityUpdated := make(chan *fleet.ActivityTypeLiveQuery)
|
||||
go func() {
|
||||
for {
|
||||
details := json.RawMessage{}
|
||||
err := mysql.ExecAdhocSQLWithError(
|
||||
s.ds, func(q sqlx.ExtContext) error {
|
||||
return sqlx.GetContext(
|
||||
context.Background(), q, &details,
|
||||
`SELECT details FROM activities WHERE activity_type = 'live_query' ORDER BY id DESC LIMIT 1`,
|
||||
)
|
||||
},
|
||||
)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Error("Error selecting from activity feed", err)
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
act := fleet.ActivityTypeLiveQuery{}
|
||||
err = json.Unmarshal(details, &act)
|
||||
require.NoError(t, err)
|
||||
if act.QuerySQL == q1.Query {
|
||||
assert.Equal(t, act.TargetsCount, uint(1))
|
||||
activityUpdated <- &act
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case activity = <-activityUpdated:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("Timeout: activity not created for TestLiveQueriesRestOneHostOneQuery")
|
||||
}
|
||||
|
||||
// Allow time for aggregated stats and activity feed to update
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
aggStats, err := mysql.GetAggregatedStats(context.Background(), s.ds, fleet.AggregatedStatsTypeScheduledQuery, q1.ID)
|
||||
if savedQuery && hasStats {
|
||||
require.NoError(t, err)
|
||||
@ -171,23 +248,9 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
||||
assert.Equal(t, float64(1), *aggStats.UserTimeP50)
|
||||
assert.Equal(t, float64(1), *aggStats.UserTimeP95)
|
||||
} else {
|
||||
require.ErrorAs(t, err, &sql.ErrNoRows)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
}
|
||||
// Check activity
|
||||
details := json.RawMessage{}
|
||||
mysql.ExecAdhocSQL(
|
||||
t, s.ds, func(q sqlx.ExtContext) error {
|
||||
return sqlx.GetContext(
|
||||
context.Background(), q, &details,
|
||||
`SELECT details FROM activities WHERE activity_type = 'live_query' ORDER BY id DESC LIMIT 1`,
|
||||
)
|
||||
},
|
||||
)
|
||||
activity := fleet.ActivityTypeLiveQuery{}
|
||||
err = json.Unmarshal(details, &activity)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, activity.TargetsCount, uint(1))
|
||||
assert.Equal(t, activity.QuerySQL, q1.Query)
|
||||
if savedQuery {
|
||||
assert.Equal(t, q1.Name, *activity.QueryName)
|
||||
if hasStats {
|
||||
@ -199,9 +262,12 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
||||
}
|
||||
}
|
||||
}
|
||||
s.Run("not saved query", func() { test(false, true) })
|
||||
s.Run("saved query without stats", func() { test(true, false) })
|
||||
s.Run("saved query with stats", func() { test(true, true) })
|
||||
s.Run("not saved query (old)", func() { test(false, false, true) })
|
||||
s.Run("saved query without stats (old)", func() { test(false, true, false) })
|
||||
s.Run("saved query with stats (old)", func() { test(false, true, true) })
|
||||
s.Run("not saved query", func() { test(true, false, true) })
|
||||
s.Run("saved query without stats", func() { test(true, true, false) })
|
||||
s.Run("saved query with stats", func() { test(true, true, true) })
|
||||
}
|
||||
|
||||
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
|
||||
@ -548,23 +614,33 @@ func (s *liveQueriesTestSuite) TestLiveQueriesInvalidInputs() {
|
||||
liveQueryResp := runLiveQueryResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
|
||||
|
||||
liveQueryRequest = runLiveQueryRequest{
|
||||
QueryIDs: []uint{q1.ID},
|
||||
HostIDs: []uint{},
|
||||
}
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
|
||||
|
||||
liveQueryRequest = runLiveQueryRequest{
|
||||
QueryIDs: nil,
|
||||
HostIDs: []uint{host.ID},
|
||||
}
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
|
||||
|
||||
// No hosts
|
||||
liveQueryRequest = runLiveQueryRequest{
|
||||
QueryIDs: []uint{q1.ID},
|
||||
HostIDs: []uint{},
|
||||
}
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
|
||||
oneLiveQueryRequest := runOneLiveQueryRequest{
|
||||
HostIDs: []uint{},
|
||||
}
|
||||
oneLiveQueryResp := runOneLiveQueryResponse{}
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp)
|
||||
|
||||
liveQueryRequest = runLiveQueryRequest{
|
||||
QueryIDs: []uint{q1.ID},
|
||||
HostIDs: nil,
|
||||
}
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
|
||||
oneLiveQueryRequest = runOneLiveQueryRequest{
|
||||
HostIDs: nil,
|
||||
}
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp)
|
||||
}
|
||||
|
||||
// TestLiveQueriesFailsToAuthorize when an observer tries to run a live query
|
||||
@ -596,6 +672,11 @@ func (s *liveQueriesTestSuite) TestLiveQueriesFailsToAuthorize() {
|
||||
s.token = originalToken
|
||||
}()
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusForbidden, &liveQueryResp)
|
||||
oneLiveQueryRequest := runOneLiveQueryRequest{
|
||||
HostIDs: []uint{host.ID},
|
||||
}
|
||||
oneLiveQueryResp := runOneLiveQueryResponse{}
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusForbidden, &oneLiveQueryResp)
|
||||
}
|
||||
|
||||
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() {
|
||||
@ -613,87 +694,169 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() {
|
||||
assert.Equal(t, 0, liveQueryResp.Summary.RespondedHostCount)
|
||||
require.NotNil(t, liveQueryResp.Results[0].Error)
|
||||
assert.Contains(t, *liveQueryResp.Results[0].Error, "Query 999 was not found in the datastore")
|
||||
|
||||
oneLiveQueryRequest := runOneLiveQueryRequest{
|
||||
HostIDs: []uint{888},
|
||||
}
|
||||
oneLiveQueryResp := runOneLiveQueryResponse{}
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", 999), oneLiveQueryRequest, http.StatusNotFound, &oneLiveQueryResp)
|
||||
assert.Equal(t, 0, oneLiveQueryResp.RespondedHostCount)
|
||||
|
||||
}
|
||||
|
||||
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() {
|
||||
func (s *liveQueriesTestSuite) TestLiveQueriesRestInvalidHost() {
|
||||
t := s.T()
|
||||
|
||||
h1 := s.hosts[0]
|
||||
h2 := s.hosts[1]
|
||||
|
||||
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{
|
||||
Query: "select 1 from osquery;",
|
||||
Description: "desc1",
|
||||
Name: t.Name() + "query1",
|
||||
Logging: fleet.LoggingSnapshot,
|
||||
})
|
||||
q1, err := s.ds.NewQuery(
|
||||
context.Background(), &fleet.Query{
|
||||
Query: "select 1 from osquery;",
|
||||
Description: "desc1",
|
||||
Name: t.Name() + "query1",
|
||||
Logging: fleet.LoggingSnapshot,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
|
||||
s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
|
||||
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
||||
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
|
||||
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
||||
|
||||
liveQueryRequest := runLiveQueryRequest{
|
||||
QueryIDs: []uint{q1.ID},
|
||||
HostIDs: []uint{h1.ID, h2.ID},
|
||||
HostIDs: []uint{math.MaxUint},
|
||||
}
|
||||
liveQueryResp := runLiveQueryResponse{}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||
}()
|
||||
|
||||
// Give the above call a couple of seconds to create the campaign
|
||||
time.Sleep(2 * time.Second)
|
||||
cid1 := getCIDForQ(s, q1)
|
||||
distributedReq := submitDistributedQueryResultsRequestShim{
|
||||
NodeKey: *h1.NodeKey,
|
||||
Results: map[string]json.RawMessage{
|
||||
hostDistributedQueryPrefix + cid1: json.RawMessage(`[{"col1": "a", "col2": "b"}]`),
|
||||
},
|
||||
Statuses: map[string]interface{}{
|
||||
hostDistributedQueryPrefix + cid1: "0",
|
||||
},
|
||||
Messages: map[string]string{
|
||||
hostDistributedQueryPrefix + cid1: "some msg",
|
||||
},
|
||||
}
|
||||
distributedResp := submitDistributedQueryResultsResponse{}
|
||||
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
||||
|
||||
distributedReq = submitDistributedQueryResultsRequestShim{
|
||||
NodeKey: *h2.NodeKey,
|
||||
Results: map[string]json.RawMessage{
|
||||
hostDistributedQueryPrefix + cid1: json.RawMessage(`""`),
|
||||
},
|
||||
Statuses: map[string]interface{}{
|
||||
hostDistributedQueryPrefix + cid1: 123,
|
||||
},
|
||||
Messages: map[string]string{
|
||||
hostDistributedQueryPrefix + cid1: "some error!",
|
||||
},
|
||||
}
|
||||
distributedResp = submitDistributedQueryResultsResponse{}
|
||||
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
||||
|
||||
wg.Wait()
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||
|
||||
require.Len(t, liveQueryResp.Results, 1)
|
||||
assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount)
|
||||
assert.Equal(t, 0, liveQueryResp.Summary.RespondedHostCount)
|
||||
assert.Len(t, liveQueryResp.Results[0].Results, 0)
|
||||
assert.True(t, strings.Contains(*liveQueryResp.Results[0].Error, "no hosts targeted"))
|
||||
|
||||
result := liveQueryResp.Results[0]
|
||||
require.Len(t, result.Results, 2)
|
||||
require.Len(t, result.Results[0].Rows, 1)
|
||||
assert.Equal(t, "a", result.Results[0].Rows[0]["col1"])
|
||||
assert.Equal(t, "b", result.Results[0].Rows[0]["col2"])
|
||||
require.Len(t, result.Results[1].Rows, 0)
|
||||
require.NotNil(t, result.Results[1].Error)
|
||||
assert.Equal(t, "some error!", *result.Results[1].Error)
|
||||
oneLiveQueryRequest := runOneLiveQueryRequest{
|
||||
HostIDs: []uint{math.MaxUint},
|
||||
}
|
||||
oneLiveQueryResp := runOneLiveQueryResponse{}
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp)
|
||||
|
||||
}
|
||||
|
||||
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() {
|
||||
test := func(newEndpoint bool) {
|
||||
t := s.T()
|
||||
|
||||
h1 := s.hosts[0]
|
||||
h2 := s.hosts[1]
|
||||
|
||||
q1, err := s.ds.NewQuery(
|
||||
context.Background(), &fleet.Query{
|
||||
Query: "select 1 from osquery;",
|
||||
Description: "desc1",
|
||||
Name: t.Name() + "query1",
|
||||
Logging: fleet.LoggingSnapshot,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
|
||||
s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
|
||||
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
||||
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
|
||||
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
oneLiveQueryResp := runOneLiveQueryResponse{}
|
||||
liveQueryResp := runLiveQueryResponse{}
|
||||
if newEndpoint {
|
||||
liveQueryRequest := runOneLiveQueryRequest{
|
||||
HostIDs: []uint{h1.ID, h2.ID},
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), liveQueryRequest, http.StatusOK, &oneLiveQueryResp)
|
||||
}()
|
||||
} else {
|
||||
liveQueryRequest := runLiveQueryRequest{
|
||||
QueryIDs: []uint{q1.ID},
|
||||
HostIDs: []uint{h1.ID, h2.ID},
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||
}()
|
||||
}
|
||||
|
||||
// For loop waiting to create the campaign
|
||||
var cid string
|
||||
cidChannel := make(chan string)
|
||||
go func() {
|
||||
for {
|
||||
campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(campaigns) == 1 && campaigns[0].Status == fleet.QueryRunning {
|
||||
cidChannel <- fmt.Sprint(campaigns[0].ID)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case cid = <-cidChannel:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("Timeout: campaign not created/running for TestLiveQueriesRestFailsOnSomeHost")
|
||||
}
|
||||
|
||||
distributedReq := submitDistributedQueryResultsRequestShim{
|
||||
NodeKey: *h1.NodeKey,
|
||||
Results: map[string]json.RawMessage{
|
||||
hostDistributedQueryPrefix + cid: json.RawMessage(`[{"col1": "a", "col2": "b"}]`),
|
||||
},
|
||||
Statuses: map[string]interface{}{
|
||||
hostDistributedQueryPrefix + cid: "0",
|
||||
},
|
||||
Messages: map[string]string{
|
||||
hostDistributedQueryPrefix + cid: "some msg",
|
||||
},
|
||||
}
|
||||
distributedResp := submitDistributedQueryResultsResponse{}
|
||||
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
||||
|
||||
distributedReq = submitDistributedQueryResultsRequestShim{
|
||||
NodeKey: *h2.NodeKey,
|
||||
Results: map[string]json.RawMessage{
|
||||
hostDistributedQueryPrefix + cid: json.RawMessage(`""`),
|
||||
},
|
||||
Statuses: map[string]interface{}{
|
||||
hostDistributedQueryPrefix + cid: 123,
|
||||
},
|
||||
Messages: map[string]string{
|
||||
hostDistributedQueryPrefix + cid: "some error!",
|
||||
},
|
||||
}
|
||||
distributedResp = submitDistributedQueryResultsResponse{}
|
||||
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
var qResults []fleet.QueryResult
|
||||
if newEndpoint {
|
||||
assert.Equal(t, q1.ID, oneLiveQueryResp.QueryID)
|
||||
assert.Equal(t, 2, oneLiveQueryResp.TargetedHostCount)
|
||||
assert.Equal(t, 2, oneLiveQueryResp.RespondedHostCount)
|
||||
qResults = oneLiveQueryResp.Results
|
||||
} else {
|
||||
require.Len(t, liveQueryResp.Results, 1)
|
||||
assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount)
|
||||
qResults = liveQueryResp.Results[0].Results
|
||||
}
|
||||
require.Len(t, qResults, 2)
|
||||
require.Len(t, qResults[0].Rows, 1)
|
||||
assert.Equal(t, "a", qResults[0].Rows[0]["col1"])
|
||||
assert.Equal(t, "b", qResults[0].Rows[0]["col2"])
|
||||
require.Len(t, qResults[1].Rows, 0)
|
||||
require.NotNil(t, qResults[1].Error)
|
||||
assert.Equal(t, "some error!", *qResults[1].Error)
|
||||
}
|
||||
s.Run("old endpoint", func() { test(false) })
|
||||
s.Run("new endpoint", func() { test(true) })
|
||||
}
|
||||
|
||||
func (s *liveQueriesTestSuite) TestCreateDistributedQueryCampaign() {
|
||||
|
@ -22,6 +22,11 @@ type runLiveQueryRequest struct {
|
||||
HostIDs []uint `json:"host_ids"`
|
||||
}
|
||||
|
||||
type runOneLiveQueryRequest struct {
|
||||
QueryID uint `url:"id"`
|
||||
HostIDs []uint `json:"host_ids"`
|
||||
}
|
||||
|
||||
type summaryPayload struct {
|
||||
TargetedHostCount int `json:"targeted_host_count"`
|
||||
RespondedHostCount int `json:"responded_host_count"`
|
||||
@ -36,9 +41,72 @@ type runLiveQueryResponse struct {
|
||||
|
||||
func (r runLiveQueryResponse) error() error { return r.Err }
|
||||
|
||||
type runOneLiveQueryResponse struct {
|
||||
QueryID uint `json:"query_id"`
|
||||
TargetedHostCount int `json:"targeted_host_count"`
|
||||
RespondedHostCount int `json:"responded_host_count"`
|
||||
Results []fleet.QueryResult `json:"results"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r runOneLiveQueryResponse) error() error { return r.Err }
|
||||
|
||||
func runOneLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
|
||||
req := request.(*runOneLiveQueryRequest)
|
||||
|
||||
// Only allow a host to be specified once in HostIDs
|
||||
hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs)
|
||||
|
||||
campaignResults, respondedHostCount, err := runLiveQuery(ctx, svc, []uint{req.QueryID}, hostIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//goland:noinspection GoPreferNilSlice -- use an empty slice here so that API returns an empty array if there are no results
|
||||
queryResults := []fleet.QueryResult{}
|
||||
if len(campaignResults) > 0 {
|
||||
if campaignResults[0].Err != nil {
|
||||
return nil, campaignResults[0].Err
|
||||
}
|
||||
if campaignResults[0].Results != nil {
|
||||
queryResults = campaignResults[0].Results
|
||||
}
|
||||
}
|
||||
|
||||
res := runOneLiveQueryResponse{
|
||||
QueryID: req.QueryID,
|
||||
TargetedHostCount: len(hostIDs),
|
||||
RespondedHostCount: respondedHostCount,
|
||||
Results: queryResults,
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
|
||||
req := request.(*runLiveQueryRequest)
|
||||
|
||||
// Only allow a query to be specified once
|
||||
queryIDs := server.RemoveDuplicatesFromSlice(req.QueryIDs)
|
||||
// Only allow a host to be specified once in HostIDs
|
||||
hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs)
|
||||
|
||||
queryResults, respondedHostCount, err := runLiveQuery(ctx, svc, queryIDs, hostIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := runLiveQueryResponse{
|
||||
Summary: summaryPayload{
|
||||
TargetedHostCount: len(hostIDs),
|
||||
RespondedHostCount: respondedHostCount,
|
||||
},
|
||||
Results: queryResults,
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostIDs []uint) (
|
||||
[]fleet.QueryCampaignResult, int, error,
|
||||
) {
|
||||
// The period used here should always be less than the request timeout for any load
|
||||
// balancer/proxy between Fleet and the API client.
|
||||
period := os.Getenv("FLEET_LIVE_QUERY_REST_PERIOD")
|
||||
@ -51,20 +119,9 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
|
||||
logging.WithExtras(ctx, "live_query_rest_period_err", err)
|
||||
}
|
||||
|
||||
// Only allow a query to be specified once
|
||||
req.QueryIDs = server.RemoveDuplicatesFromSlice(req.QueryIDs)
|
||||
// Only allow a host to be specified once in HostIDs
|
||||
req.HostIDs = server.RemoveDuplicatesFromSlice(req.HostIDs)
|
||||
res := runLiveQueryResponse{
|
||||
Summary: summaryPayload{
|
||||
TargetedHostCount: len(req.HostIDs),
|
||||
RespondedHostCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, req.QueryIDs, req.HostIDs, duration)
|
||||
queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, queryIDs, hostIDs, duration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
// Check if all query results were forbidden due to lack of authorization.
|
||||
allResultsForbidden := len(queryResults) > 0 && respondedHostCount == 0
|
||||
@ -77,12 +134,11 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
|
||||
}
|
||||
}
|
||||
if allResultsForbidden {
|
||||
return nil, authz.ForbiddenWithInternal("All Live Query results were forbidden.", authz.UserFromContext(ctx), nil, nil)
|
||||
return nil, 0, authz.ForbiddenWithInternal(
|
||||
"All Live Query results were forbidden.", authz.UserFromContext(ctx), nil, nil,
|
||||
)
|
||||
}
|
||||
res.Results = queryResults
|
||||
res.Summary.RespondedHostCount = respondedHostCount
|
||||
|
||||
return res, nil
|
||||
return queryResults, respondedHostCount, nil
|
||||
}
|
||||
|
||||
func (svc *Service) RunLiveQueryDeadline(
|
||||
@ -106,13 +162,13 @@ func (svc *Service) RunLiveQueryDeadline(
|
||||
defer wg.Done()
|
||||
campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: hostIDs})
|
||||
if err != nil {
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err}
|
||||
return
|
||||
}
|
||||
|
||||
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
|
||||
if err != nil {
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err}
|
||||
return
|
||||
}
|
||||
defer cancelFunc()
|
||||
@ -120,7 +176,7 @@ func (svc *Service) RunLiveQueryDeadline(
|
||||
defer func() {
|
||||
err := svc.CompleteCampaign(ctx, campaign)
|
||||
if err != nil {
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err}
|
||||
}
|
||||
}()
|
||||
|
||||
@ -168,12 +224,14 @@ func (svc *Service) RunLiveQueryDeadline(
|
||||
svc.updateStats(ctx, campaign.QueryID, svc.logger, &perfStatsTracker, false)
|
||||
}
|
||||
}
|
||||
if len(results) == len(hostIDs) {
|
||||
break loop
|
||||
}
|
||||
case error:
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(res.Error())}
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(res.Error()), Err: res}
|
||||
return
|
||||
}
|
||||
case <-timeout:
|
||||
// This is the normal path for returning results.
|
||||
break loop
|
||||
case <-ctx.Done():
|
||||
break loop
|
||||
|
@ -324,6 +324,7 @@ func TestCalculateOutputSize(t *testing.T) {
|
||||
result.Rows = append(result.Rows, map[string]string{})
|
||||
result.Rows = append(result.Rows, map[string]string{"a": "b", "a1": "b1"})
|
||||
result.Rows = append(result.Rows, map[string]string{"c": "d"})
|
||||
result.Stats = &fleet.Stats{}
|
||||
return &result
|
||||
}
|
||||
t.Run(
|
||||
|
@ -278,7 +278,8 @@ func (svc Service) addLiveQueryActivity(
|
||||
|
||||
func calculateOutputSize(perfStatsTracker *statsTracker, res *fleet.DistributedQueryResult) uint64 {
|
||||
outputSize := uint64(0)
|
||||
if perfStatsTracker.saveStats {
|
||||
// We only need the output size if other stats are present.
|
||||
if perfStatsTracker.saveStats && res.Stats != nil {
|
||||
for _, row := range res.Rows {
|
||||
if row == nil {
|
||||
continue
|
||||
|
Loading…
Reference in New Issue
Block a user