mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 17:05:18 +00:00
e1aac9c776
# Checklist for submitter If some of the following don't apply, delete the relevant line. <!-- Note that API documentation changes are now addressed by the product design team. --> - [ ] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [ ] Documented any permissions changes (docs/Using Fleet/manage-access.md) - [ ] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [ ] Added support on fleet's osquery simulator `cmd/osquery-perf` for new osquery data ingestion features. - [ ] Added/updated tests - [ ] If database migrations are included, checked table schema to confirm autoupdate - For database migrations: - [ ] Checked schema for all modified tables for columns that will auto-update timestamps during migration. - [ ] Confirmed that updating the timestamps is acceptable, and will not cause unwanted side effects. - [ ] Manual QA for all new/changed functionality - For Orbit and Fleet Desktop changes: - [ ] Manual QA must be performed in the three main OSs, macOS, Windows and Linux. - [ ] Auto-update manual QA, from released version of component to new version (see [tools/tuf/test](../tools/tuf/test/README.md)).
1145 lines
38 KiB
Go
1145 lines
38 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"math/rand"
|
|
"net/http"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/authz"
|
|
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/live_query/live_query_mock"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/fleetdm/fleet/v4/server/pubsub"
|
|
"github.com/google/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/stretchr/testify/suite"
|
|
)
|
|
|
|
func TestIntegrationLiveQueriesTestSuite(t *testing.T) {
|
|
testingSuite := new(liveQueriesTestSuite)
|
|
testingSuite.withServer.s = &testingSuite.Suite
|
|
suite.Run(t, testingSuite)
|
|
}
|
|
|
|
type liveQueriesTestSuite struct {
|
|
withServer
|
|
suite.Suite
|
|
|
|
lq *live_query_mock.MockLiveQuery
|
|
hosts []*fleet.Host
|
|
}
|
|
|
|
// SetupTest partially implements suite.SetupTestSuite.
|
|
func (s *liveQueriesTestSuite) SetupTest() {
|
|
s.lq.Mock.Test(s.T())
|
|
}
|
|
|
|
// SetupSuite partially implements suite.SetupAllSuite.
|
|
func (s *liveQueriesTestSuite) SetupSuite() {
|
|
s.T().Setenv("FLEET_LIVE_QUERY_REST_PERIOD", "5s")
|
|
|
|
s.withDS.SetupSuite("liveQueriesTestSuite")
|
|
|
|
rs := pubsub.NewInmemQueryResults()
|
|
lq := live_query_mock.New(s.T())
|
|
s.lq = lq
|
|
|
|
users, server := RunServerForTestsWithDS(s.T(), s.ds, &TestServerOpts{Lq: lq, Rs: rs})
|
|
s.server = server
|
|
s.users = users
|
|
s.token = getTestAdminToken(s.T(), s.server)
|
|
|
|
t := s.T()
|
|
for i := 0; i < 3; i++ {
|
|
host, err := s.ds.NewHost(context.Background(), &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now().Add(-time.Duration(i) * time.Minute),
|
|
OsqueryHostID: ptr.String(fmt.Sprintf("%s%d", t.Name(), i)),
|
|
NodeKey: ptr.String(fmt.Sprintf("%s%d", t.Name(), i)),
|
|
UUID: fmt.Sprintf("%s%d", t.Name(), i),
|
|
Hostname: fmt.Sprintf("%sfoo.local%d", t.Name(), i),
|
|
})
|
|
require.NoError(s.T(), err)
|
|
s.hosts = append(s.hosts, host)
|
|
}
|
|
}
|
|
|
|
// TearDownTest partially implements suite.TearDownTestSuite.
|
|
func (s *liveQueriesTestSuite) TearDownTest() {
|
|
// reset the mock
|
|
s.lq.Mock = mock.Mock{}
|
|
}
|
|
|
|
// liveQueryEndpoint selects which REST endpoint a test uses to start a live
// query.
type liveQueryEndpoint int

const (
	// oldEndpoint is GET /api/latest/fleet/queries/run (multiple query IDs).
	oldEndpoint liveQueryEndpoint = iota
	// oneQueryEndpoint is POST /api/latest/fleet/queries/{id}/run.
	oneQueryEndpoint
	// customQueryOneHostIdEndpoint is POST /api/latest/fleet/hosts/{id}/query.
	customQueryOneHostIdEndpoint
	// customQueryOneHostIdentifierEndpoint is
	// POST /api/latest/fleet/hosts/identifier/{identifier}/query.
	customQueryOneHostIdentifierEndpoint
)
|
|
|
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
|
test := func(endpoint liveQueryEndpoint, savedQuery bool, hasStats bool) {
|
|
t := s.T()
|
|
|
|
host := s.hosts[0]
|
|
|
|
query := t.Name() + " select 1 from osquery;"
|
|
q1, err := s.ds.NewQuery(
|
|
context.Background(), &fleet.Query{
|
|
Query: query,
|
|
Description: "desc1",
|
|
Name: t.Name() + "query1",
|
|
Logging: fleet.LoggingSnapshot,
|
|
Saved: savedQuery,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): query}, nil)
|
|
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
|
s.lq.On("RunQuery", mock.Anything, query, []uint{host.ID}).Return(nil)
|
|
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
|
|
oneLiveQueryResp := runOneLiveQueryResponse{}
|
|
liveQueryResp := runLiveQueryResponse{}
|
|
liveQueryOnHostResp := runLiveQueryOnHostResponse{}
|
|
if endpoint == oneQueryEndpoint {
|
|
liveQueryRequest := runOneLiveQueryRequest{
|
|
HostIDs: []uint{host.ID},
|
|
}
|
|
go func() {
|
|
defer wg.Done()
|
|
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), liveQueryRequest, http.StatusOK, &oneLiveQueryResp)
|
|
}()
|
|
} else if endpoint == oldEndpoint {
|
|
liveQueryRequest := runLiveQueryRequest{
|
|
QueryIDs: []uint{q1.ID},
|
|
HostIDs: []uint{host.ID},
|
|
}
|
|
go func() {
|
|
defer wg.Done()
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
|
}()
|
|
} else { // customQueryOneHostId(.*)Endpoint
|
|
liveQueryRequest := runLiveQueryOnHostRequest{
|
|
Query: query,
|
|
}
|
|
url := fmt.Sprintf("/api/latest/fleet/hosts/%d/query", host.ID)
|
|
if endpoint == customQueryOneHostIdentifierEndpoint {
|
|
url = fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID)
|
|
}
|
|
go func() {
|
|
defer wg.Done()
|
|
s.DoJSON(
|
|
"POST", url, liveQueryRequest, http.StatusOK,
|
|
&liveQueryOnHostResp,
|
|
)
|
|
}()
|
|
}
|
|
|
|
// For loop, waiting for campaign to be created.
|
|
var cid string
|
|
cidChannel := make(chan string)
|
|
go func() {
|
|
ticker := time.NewTicker(100 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
if endpoint == customQueryOneHostIdentifierEndpoint || endpoint == customQueryOneHostIdEndpoint {
|
|
campaign := fleet.DistributedQueryCampaign{}
|
|
err := mysql.ExecAdhocSQLWithError(
|
|
s.ds, func(q sqlx.ExtContext) error {
|
|
return sqlx.GetContext(
|
|
context.Background(), q, &campaign,
|
|
`SELECT * FROM distributed_query_campaigns WHERE status = ? ORDER BY id DESC LIMIT 1`,
|
|
fleet.QueryRunning,
|
|
)
|
|
},
|
|
)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
continue
|
|
}
|
|
if err != nil {
|
|
t.Error("Error selecting from distributed_query_campaigns", err)
|
|
return
|
|
}
|
|
q1.ID = campaign.QueryID
|
|
cidChannel <- fmt.Sprint(campaign.ID)
|
|
return
|
|
}
|
|
campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID)
|
|
require.NoError(t, err)
|
|
|
|
if len(campaigns) == 1 && campaigns[0].Status == fleet.QueryRunning {
|
|
cidChannel <- fmt.Sprint(campaigns[0].ID)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
select {
|
|
case cid = <-cidChannel:
|
|
case <-time.After(5 * time.Second):
|
|
t.Error("Timeout: campaign not created/running for TestLiveQueriesRestOneHostOneQuery")
|
|
}
|
|
|
|
var stats *fleet.Stats
|
|
if hasStats {
|
|
stats = &fleet.Stats{
|
|
UserTime: uint64(1),
|
|
SystemTime: uint64(2),
|
|
}
|
|
}
|
|
distributedReq := submitDistributedQueryResultsRequestShim{
|
|
NodeKey: *host.NodeKey,
|
|
Results: map[string]json.RawMessage{
|
|
hostDistributedQueryPrefix + cid: json.RawMessage(`[{"col1": "a", "col2": "b"}]`),
|
|
hostDistributedQueryPrefix + "invalidcid": json.RawMessage(`""`), // empty string is sometimes sent for no results
|
|
hostDistributedQueryPrefix + "9999": json.RawMessage(`""`),
|
|
},
|
|
Statuses: map[string]interface{}{
|
|
hostDistributedQueryPrefix + cid: 0,
|
|
hostDistributedQueryPrefix + "9999": "0",
|
|
},
|
|
Messages: map[string]string{},
|
|
Stats: map[string]*fleet.Stats{
|
|
hostDistributedQueryPrefix + cid: stats,
|
|
},
|
|
}
|
|
distributedResp := submitDistributedQueryResultsResponse{}
|
|
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
|
|
|
wg.Wait()
|
|
|
|
var result fleet.QueryResult
|
|
if endpoint == oneQueryEndpoint {
|
|
assert.Equal(t, q1.ID, oneLiveQueryResp.QueryID)
|
|
assert.Equal(t, 1, oneLiveQueryResp.TargetedHostCount)
|
|
assert.Equal(t, 1, oneLiveQueryResp.RespondedHostCount)
|
|
require.Len(t, oneLiveQueryResp.Results, 1)
|
|
result = oneLiveQueryResp.Results[0]
|
|
} else if endpoint == oldEndpoint {
|
|
require.Len(t, liveQueryResp.Results, 1)
|
|
assert.Equal(t, 1, liveQueryResp.Summary.TargetedHostCount)
|
|
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
|
|
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
|
|
require.Len(t, liveQueryResp.Results[0].Results, 1)
|
|
result = liveQueryResp.Results[0].Results[0]
|
|
} else { // customQueryOneHostId(.*)Endpoint
|
|
assert.Empty(t, liveQueryOnHostResp.Error)
|
|
assert.Equal(t, host.ID, liveQueryOnHostResp.HostID)
|
|
assert.Equal(t, fleet.StatusOnline, liveQueryOnHostResp.Status)
|
|
assert.Equal(t, query, liveQueryOnHostResp.Query)
|
|
result = fleet.QueryResult{
|
|
HostID: liveQueryOnHostResp.HostID,
|
|
Rows: liveQueryOnHostResp.Rows,
|
|
}
|
|
}
|
|
assert.Equal(t, host.ID, result.HostID)
|
|
require.Len(t, result.Rows, 1)
|
|
assert.Equal(t, "a", result.Rows[0]["col1"])
|
|
assert.Equal(t, "b", result.Rows[0]["col2"])
|
|
|
|
// For loop, waiting for activity feed to update, which happens after aggregated stats update.
|
|
var activity *fleet.ActivityTypeLiveQuery
|
|
activityUpdated := make(chan *fleet.ActivityTypeLiveQuery)
|
|
go func() {
|
|
ticker := time.NewTicker(100 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
details := json.RawMessage{}
|
|
err := mysql.ExecAdhocSQLWithError(
|
|
s.ds, func(q sqlx.ExtContext) error {
|
|
return sqlx.GetContext(
|
|
context.Background(), q, &details,
|
|
`SELECT details FROM activities WHERE activity_type = 'live_query' ORDER BY id DESC LIMIT 1`,
|
|
)
|
|
},
|
|
)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
t.Error("Error selecting from activity feed", err)
|
|
return
|
|
}
|
|
if err == nil {
|
|
act := fleet.ActivityTypeLiveQuery{}
|
|
err = json.Unmarshal(details, &act)
|
|
require.NoError(t, err)
|
|
if act.QuerySQL == q1.Query {
|
|
assert.Equal(t, act.TargetsCount, uint(1))
|
|
activityUpdated <- &act
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
select {
|
|
case activity = <-activityUpdated:
|
|
case <-time.After(5 * time.Second):
|
|
t.Error("Timeout: activity not created for TestLiveQueriesRestOneHostOneQuery")
|
|
}
|
|
|
|
aggStats, err := mysql.GetAggregatedStats(context.Background(), s.ds, fleet.AggregatedStatsTypeScheduledQuery, q1.ID)
|
|
if savedQuery && hasStats {
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, int(*aggStats.TotalExecutions))
|
|
assert.Equal(t, float64(2), *aggStats.SystemTimeP50)
|
|
assert.Equal(t, float64(2), *aggStats.SystemTimeP95)
|
|
assert.Equal(t, float64(1), *aggStats.UserTimeP50)
|
|
assert.Equal(t, float64(1), *aggStats.UserTimeP95)
|
|
} else {
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
}
|
|
// Check activity
|
|
if savedQuery {
|
|
assert.Equal(t, q1.Name, *activity.QueryName)
|
|
if hasStats {
|
|
assert.Equal(t, 1, int(*activity.Stats.TotalExecutions))
|
|
assert.Equal(t, float64(2), *activity.Stats.SystemTimeP50)
|
|
assert.Equal(t, float64(2), *activity.Stats.SystemTimeP95)
|
|
assert.Equal(t, float64(1), *activity.Stats.UserTimeP50)
|
|
assert.Equal(t, float64(1), *activity.Stats.UserTimeP95)
|
|
}
|
|
}
|
|
}
|
|
s.Run("not saved query (old)", func() { test(oldEndpoint, false, true) })
|
|
s.Run("saved query without stats (old)", func() { test(oldEndpoint, true, false) })
|
|
s.Run("saved query with stats (old)", func() { test(oldEndpoint, true, true) })
|
|
s.Run("not saved query", func() { test(oneQueryEndpoint, false, true) })
|
|
s.Run("saved query without stats", func() { test(oneQueryEndpoint, true, false) })
|
|
s.Run("saved query with stats", func() { test(oneQueryEndpoint, true, true) })
|
|
s.Run("custom query by host id", func() { test(customQueryOneHostIdEndpoint, false, false) })
|
|
s.Run("custom query by host identifier", func() { test(customQueryOneHostIdentifierEndpoint, false, false) })
|
|
}
|
|
|
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
|
|
t := s.T()
|
|
|
|
host := s.hosts[0]
|
|
|
|
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{
|
|
Query: "select 1 from osquery;",
|
|
Description: "desc1",
|
|
Name: t.Name() + "query1",
|
|
Logging: fleet.LoggingSnapshot,
|
|
Saved: rand.Intn(2) == 1, //nolint:gosec
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
q2, err := s.ds.NewQuery(context.Background(), &fleet.Query{
|
|
Query: "select 2 from osquery;",
|
|
Description: "desc2",
|
|
Name: t.Name() + "query2",
|
|
Logging: fleet.LoggingSnapshot,
|
|
Saved: rand.Intn(2) == 1, //nolint:gosec
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
s.lq.On("QueriesForHost", host.ID).Return(map[string]string{
|
|
fmt.Sprint(q1.ID): "select 1 from osquery;",
|
|
fmt.Sprint(q2.ID): "select 2 from osquery;",
|
|
}, nil)
|
|
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
|
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil)
|
|
s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{host.ID}).Return(nil)
|
|
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
|
|
|
liveQueryRequest := runLiveQueryRequest{
|
|
QueryIDs: []uint{q1.ID, q2.ID},
|
|
HostIDs: []uint{host.ID},
|
|
}
|
|
liveQueryResp := runLiveQueryResponse{}
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
|
}()
|
|
|
|
// Give the above call a couple of seconds to create the campaign
|
|
time.Sleep(2 * time.Second)
|
|
|
|
cid1 := getCIDForQ(s, q1)
|
|
cid2 := getCIDForQ(s, q2)
|
|
|
|
distributedReq := SubmitDistributedQueryResultsRequest{
|
|
NodeKey: *host.NodeKey,
|
|
Results: map[string][]map[string]string{
|
|
hostDistributedQueryPrefix + cid1: {{"col1": "a", "col2": "b"}},
|
|
hostDistributedQueryPrefix + cid2: {{"col3": "c", "col4": "d"}, {"col3": "e", "col4": "f"}},
|
|
},
|
|
Statuses: map[string]fleet.OsqueryStatus{
|
|
hostDistributedQueryPrefix + cid1: 0,
|
|
hostDistributedQueryPrefix + cid2: 0,
|
|
},
|
|
Messages: map[string]string{
|
|
hostDistributedQueryPrefix + cid1: "some msg",
|
|
hostDistributedQueryPrefix + cid2: "some other msg",
|
|
},
|
|
Stats: map[string]*fleet.Stats{
|
|
hostDistributedQueryPrefix + cid1: {
|
|
UserTime: uint64(1),
|
|
SystemTime: uint64(2),
|
|
},
|
|
},
|
|
}
|
|
distributedResp := submitDistributedQueryResultsResponse{}
|
|
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
|
|
|
wg.Wait()
|
|
|
|
require.Len(t, liveQueryResp.Results, 2)
|
|
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
|
|
|
|
sort.Slice(liveQueryResp.Results, func(i, j int) bool {
|
|
return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID
|
|
})
|
|
|
|
require.True(t, q1.ID < q2.ID)
|
|
|
|
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
|
|
require.Len(t, liveQueryResp.Results[0].Results, 1)
|
|
q1Results := liveQueryResp.Results[0].Results[0]
|
|
require.Len(t, q1Results.Rows, 1)
|
|
assert.Equal(t, "a", q1Results.Rows[0]["col1"])
|
|
assert.Equal(t, "b", q1Results.Rows[0]["col2"])
|
|
|
|
assert.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID)
|
|
require.Len(t, liveQueryResp.Results[1].Results, 1)
|
|
q2Results := liveQueryResp.Results[1].Results[0]
|
|
require.Len(t, q2Results.Rows, 2)
|
|
assert.Equal(t, "c", q2Results.Rows[0]["col3"])
|
|
assert.Equal(t, "d", q2Results.Rows[0]["col4"])
|
|
assert.Equal(t, "e", q2Results.Rows[1]["col3"])
|
|
assert.Equal(t, "f", q2Results.Rows[1]["col4"])
|
|
}
|
|
|
|
func getCIDForQ(s *liveQueriesTestSuite, q1 *fleet.Query) string {
|
|
t := s.T()
|
|
campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, campaigns, 1)
|
|
cid1 := fmt.Sprint(campaigns[0].ID)
|
|
return cid1
|
|
}
|
|
|
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
|
|
t := s.T()
|
|
|
|
h1 := s.hosts[0]
|
|
h2 := s.hosts[1]
|
|
|
|
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{
|
|
Query: "select 1 from osquery;",
|
|
Description: "desc1",
|
|
Name: t.Name() + "query1",
|
|
Logging: fleet.LoggingSnapshot,
|
|
Saved: rand.Intn(2) == 1, //nolint:gosec
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
q2, err := s.ds.NewQuery(context.Background(), &fleet.Query{
|
|
Query: "select 2 from osquery;",
|
|
Description: "desc2",
|
|
Name: t.Name() + "query2",
|
|
Logging: fleet.LoggingSnapshot,
|
|
Saved: rand.Intn(2) == 1, //nolint:gosec
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{
|
|
fmt.Sprint(q1.ID): "select 1 from osquery;",
|
|
fmt.Sprint(q2.ID): "select 2 from osquery;",
|
|
}, nil)
|
|
s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{
|
|
fmt.Sprint(q1.ID): "select 1 from osquery;",
|
|
fmt.Sprint(q2.ID): "select 2 from osquery;",
|
|
}, nil)
|
|
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
|
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
|
|
s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
|
|
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
|
|
|
liveQueryRequest := runLiveQueryRequest{
|
|
QueryIDs: []uint{q1.ID, q2.ID},
|
|
HostIDs: []uint{h1.ID, h2.ID},
|
|
}
|
|
liveQueryResp := runLiveQueryResponse{}
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
|
}()
|
|
|
|
// Give the above call a couple of seconds to create the campaign
|
|
time.Sleep(2 * time.Second)
|
|
cid1 := getCIDForQ(s, q1)
|
|
cid2 := getCIDForQ(s, q2)
|
|
for i, h := range []*fleet.Host{h1, h2} {
|
|
distributedReq := SubmitDistributedQueryResultsRequest{
|
|
NodeKey: *h.NodeKey,
|
|
Results: map[string][]map[string]string{
|
|
hostDistributedQueryPrefix + cid1: {{"col1": fmt.Sprintf("a%d", i), "col2": fmt.Sprintf("b%d", i)}},
|
|
hostDistributedQueryPrefix + cid2: {{"col3": fmt.Sprintf("c%d", i), "col4": fmt.Sprintf("d%d", i)}, {"col3": fmt.Sprintf("e%d", i), "col4": fmt.Sprintf("f%d", i)}},
|
|
},
|
|
Statuses: map[string]fleet.OsqueryStatus{
|
|
hostDistributedQueryPrefix + cid1: 0,
|
|
hostDistributedQueryPrefix + cid2: 0,
|
|
},
|
|
Messages: map[string]string{
|
|
hostDistributedQueryPrefix + cid1: "some msg",
|
|
hostDistributedQueryPrefix + cid2: "some other msg",
|
|
},
|
|
Stats: map[string]*fleet.Stats{
|
|
hostDistributedQueryPrefix + cid1: {
|
|
UserTime: uint64(1),
|
|
SystemTime: uint64(2),
|
|
},
|
|
},
|
|
}
|
|
distributedResp := submitDistributedQueryResultsResponse{}
|
|
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
require.Len(t, liveQueryResp.Results, 2) // 2 queries
|
|
assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount)
|
|
|
|
sort.Slice(liveQueryResp.Results, func(i, j int) bool {
|
|
return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID
|
|
})
|
|
|
|
require.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
|
|
require.Len(t, liveQueryResp.Results[0].Results, 2)
|
|
for i, r := range liveQueryResp.Results[0].Results {
|
|
require.Len(t, r.Rows, 1)
|
|
assert.Equal(t, fmt.Sprintf("a%d", i), r.Rows[0]["col1"])
|
|
assert.Equal(t, fmt.Sprintf("b%d", i), r.Rows[0]["col2"])
|
|
}
|
|
|
|
require.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID)
|
|
require.Len(t, liveQueryResp.Results[1].Results, 2)
|
|
for i, r := range liveQueryResp.Results[1].Results {
|
|
require.Len(t, r.Rows, 2)
|
|
assert.Equal(t, fmt.Sprintf("c%d", i), r.Rows[0]["col3"])
|
|
assert.Equal(t, fmt.Sprintf("d%d", i), r.Rows[0]["col4"])
|
|
assert.Equal(t, fmt.Sprintf("e%d", i), r.Rows[1]["col3"])
|
|
assert.Equal(t, fmt.Sprintf("f%d", i), r.Rows[1]["col4"])
|
|
}
|
|
}
|
|
|
|
// TestLiveQueriesSomeFailToAuthorize when a user requests to run a mix of authorized and unauthorized queries
|
|
func (s *liveQueriesTestSuite) TestLiveQueriesSomeFailToAuthorize() {
|
|
t := s.T()
|
|
|
|
host := s.hosts[0]
|
|
|
|
// Unauthorized query
|
|
q1, err := s.ds.NewQuery(
|
|
context.Background(), &fleet.Query{
|
|
Query: "select 1 from osquery;",
|
|
Description: "desc1",
|
|
Name: t.Name() + "query1",
|
|
Logging: fleet.LoggingSnapshot,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
// Authorized query
|
|
q2, err := s.ds.NewQuery(
|
|
context.Background(), &fleet.Query{
|
|
Query: "select 2 from osquery;",
|
|
Description: "desc2",
|
|
Name: t.Name() + "query2",
|
|
Logging: fleet.LoggingSnapshot,
|
|
ObserverCanRun: true,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 2 from osquery;"}, nil)
|
|
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
|
s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{host.ID}).Return(nil)
|
|
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
|
|
|
// Switch to observer user.
|
|
originalToken := s.token
|
|
s.token = getTestUserToken(t, s.server, "user2")
|
|
defer func() {
|
|
s.token = originalToken
|
|
}()
|
|
|
|
liveQueryRequest := runLiveQueryRequest{
|
|
QueryIDs: []uint{q1.ID, q2.ID},
|
|
HostIDs: []uint{host.ID},
|
|
}
|
|
liveQueryResp := runLiveQueryResponse{}
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
|
}()
|
|
|
|
// Give the above call a couple of seconds to create the campaign
|
|
time.Sleep(2 * time.Second)
|
|
|
|
cid2 := getCIDForQ(s, q2)
|
|
|
|
distributedReq := SubmitDistributedQueryResultsRequest{
|
|
NodeKey: *host.NodeKey,
|
|
Results: map[string][]map[string]string{
|
|
hostDistributedQueryPrefix + cid2: {{"col3": "c", "col4": "d"}, {"col3": "e", "col4": "f"}},
|
|
},
|
|
Statuses: map[string]fleet.OsqueryStatus{
|
|
hostDistributedQueryPrefix + cid2: 0,
|
|
},
|
|
Messages: map[string]string{
|
|
hostDistributedQueryPrefix + cid2: "some other msg",
|
|
},
|
|
}
|
|
distributedResp := submitDistributedQueryResultsResponse{}
|
|
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
|
|
|
wg.Wait()
|
|
|
|
require.Len(t, liveQueryResp.Results, 2)
|
|
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
|
|
|
|
sort.Slice(
|
|
liveQueryResp.Results, func(i, j int) bool {
|
|
return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID
|
|
},
|
|
)
|
|
|
|
require.True(t, q1.ID < q2.ID)
|
|
|
|
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
|
|
assert.Nil(t, liveQueryResp.Results[0].Results)
|
|
assert.Equal(t, authz.ForbiddenErrorMessage, *liveQueryResp.Results[0].Error)
|
|
|
|
assert.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID)
|
|
require.Len(t, liveQueryResp.Results[1].Results, 1)
|
|
q2Results := liveQueryResp.Results[1].Results[0]
|
|
require.Len(t, q2Results.Rows, 2)
|
|
assert.Equal(t, "c", q2Results.Rows[0]["col3"])
|
|
assert.Equal(t, "d", q2Results.Rows[0]["col4"])
|
|
assert.Equal(t, "e", q2Results.Rows[1]["col3"])
|
|
assert.Equal(t, "f", q2Results.Rows[1]["col4"])
|
|
}
|
|
|
|
// TestLiveQueriesInvalidInput without query/host IDs
|
|
func (s *liveQueriesTestSuite) TestLiveQueriesInvalidInputs() {
|
|
t := s.T()
|
|
|
|
host := s.hosts[0]
|
|
|
|
q1, err := s.ds.NewQuery(
|
|
context.Background(), &fleet.Query{
|
|
Query: "select 1 from osquery;",
|
|
Description: "desc1",
|
|
Name: t.Name() + "query1",
|
|
Logging: fleet.LoggingSnapshot,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
liveQueryRequest := runLiveQueryRequest{
|
|
QueryIDs: []uint{},
|
|
HostIDs: []uint{host.ID},
|
|
}
|
|
liveQueryResp := runLiveQueryResponse{}
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
|
|
|
|
liveQueryRequest = runLiveQueryRequest{
|
|
QueryIDs: nil,
|
|
HostIDs: []uint{host.ID},
|
|
}
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
|
|
|
|
// No hosts
|
|
liveQueryRequest = runLiveQueryRequest{
|
|
QueryIDs: []uint{q1.ID},
|
|
HostIDs: []uint{},
|
|
}
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
|
|
oneLiveQueryRequest := runOneLiveQueryRequest{
|
|
HostIDs: []uint{},
|
|
}
|
|
oneLiveQueryResp := runOneLiveQueryResponse{}
|
|
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp)
|
|
|
|
liveQueryRequest = runLiveQueryRequest{
|
|
QueryIDs: []uint{q1.ID},
|
|
HostIDs: nil,
|
|
}
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusBadRequest, &liveQueryResp)
|
|
oneLiveQueryRequest = runOneLiveQueryRequest{
|
|
HostIDs: nil,
|
|
}
|
|
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp)
|
|
|
|
// Invalid raw query
|
|
liveQueryOnHostRequest := runLiveQueryOnHostRequest{
|
|
Query: " ",
|
|
}
|
|
liveQueryOnHostResp := runLiveQueryOnHostResponse{}
|
|
s.DoJSON(
|
|
"POST", fmt.Sprintf("/api/latest/fleet/hosts/%d/query", host.ID), liveQueryOnHostRequest, http.StatusBadRequest,
|
|
&liveQueryOnHostResp,
|
|
)
|
|
s.DoJSON(
|
|
"POST", fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID), liveQueryOnHostRequest, http.StatusBadRequest,
|
|
&liveQueryOnHostResp,
|
|
)
|
|
}
|
|
|
|
// TestLiveQueriesFailsToAuthorize when an observer tries to run a live query
|
|
func (s *liveQueriesTestSuite) TestLiveQueriesFailsToAuthorize() {
|
|
t := s.T()
|
|
|
|
host := s.hosts[0]
|
|
|
|
q1, err := s.ds.NewQuery(
|
|
context.Background(), &fleet.Query{
|
|
Query: "select 1 from osquery;",
|
|
Description: "desc1",
|
|
Name: t.Name() + "query1",
|
|
Logging: fleet.LoggingSnapshot,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
liveQueryRequest := runLiveQueryRequest{
|
|
QueryIDs: []uint{q1.ID},
|
|
HostIDs: []uint{host.ID},
|
|
}
|
|
liveQueryResp := runLiveQueryResponse{}
|
|
|
|
// Switch to observer user.
|
|
originalToken := s.token
|
|
s.token = getTestUserToken(t, s.server, "user2")
|
|
defer func() {
|
|
s.token = originalToken
|
|
}()
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusForbidden, &liveQueryResp)
|
|
oneLiveQueryRequest := runOneLiveQueryRequest{
|
|
HostIDs: []uint{host.ID},
|
|
}
|
|
oneLiveQueryResp := runOneLiveQueryResponse{}
|
|
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusForbidden, &oneLiveQueryResp)
|
|
}
|
|
|
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() {
|
|
t := s.T()
|
|
|
|
liveQueryRequest := runLiveQueryRequest{
|
|
QueryIDs: []uint{999},
|
|
HostIDs: []uint{888},
|
|
}
|
|
liveQueryResp := runLiveQueryResponse{}
|
|
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
|
|
|
require.Len(t, liveQueryResp.Results, 1)
|
|
assert.Equal(t, 0, liveQueryResp.Summary.RespondedHostCount)
|
|
require.NotNil(t, liveQueryResp.Results[0].Error)
|
|
assert.Contains(t, *liveQueryResp.Results[0].Error, "Query 999 was not found in the datastore")
|
|
|
|
oneLiveQueryRequest := runOneLiveQueryRequest{
|
|
HostIDs: []uint{888},
|
|
}
|
|
oneLiveQueryResp := runOneLiveQueryResponse{}
|
|
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", 999), oneLiveQueryRequest, http.StatusNotFound, &oneLiveQueryResp)
|
|
assert.Equal(t, 0, oneLiveQueryResp.RespondedHostCount)
|
|
}
|
|
|
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestInvalidHost() {
|
|
t := s.T()
|
|
|
|
q1, err := s.ds.NewQuery(
|
|
context.Background(), &fleet.Query{
|
|
Query: "select 1 from osquery;",
|
|
Description: "desc1",
|
|
Name: t.Name() + "query1",
|
|
Logging: fleet.LoggingSnapshot,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
liveQueryRequest := runLiveQueryRequest{
|
|
QueryIDs: []uint{q1.ID},
|
|
HostIDs: []uint{math.MaxUint},
|
|
}
|
|
liveQueryResp := runLiveQueryResponse{}
|
|
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
|
|
|
require.Len(t, liveQueryResp.Results, 1)
|
|
assert.Equal(t, 0, liveQueryResp.Summary.RespondedHostCount)
|
|
assert.Len(t, liveQueryResp.Results[0].Results, 0)
|
|
assert.True(t, strings.Contains(*liveQueryResp.Results[0].Error, "no hosts targeted"))
|
|
|
|
oneLiveQueryRequest := runOneLiveQueryRequest{
|
|
HostIDs: []uint{math.MaxUint},
|
|
}
|
|
oneLiveQueryResp := runOneLiveQueryResponse{}
|
|
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp)
|
|
}
|
|
|
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() {
|
|
test := func(newEndpoint bool) {
|
|
t := s.T()
|
|
|
|
h1 := s.hosts[0]
|
|
h2 := s.hosts[1]
|
|
|
|
q1, err := s.ds.NewQuery(
|
|
context.Background(), &fleet.Query{
|
|
Query: "select 1 from osquery;",
|
|
Description: "desc1",
|
|
Name: t.Name() + "query1",
|
|
Logging: fleet.LoggingSnapshot,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
|
|
s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
|
|
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
|
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
|
|
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
oneLiveQueryResp := runOneLiveQueryResponse{}
|
|
liveQueryResp := runLiveQueryResponse{}
|
|
if newEndpoint {
|
|
liveQueryRequest := runOneLiveQueryRequest{
|
|
HostIDs: []uint{h1.ID, h2.ID},
|
|
}
|
|
go func() {
|
|
defer wg.Done()
|
|
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), liveQueryRequest, http.StatusOK, &oneLiveQueryResp)
|
|
}()
|
|
} else {
|
|
liveQueryRequest := runLiveQueryRequest{
|
|
QueryIDs: []uint{q1.ID},
|
|
HostIDs: []uint{h1.ID, h2.ID},
|
|
}
|
|
go func() {
|
|
defer wg.Done()
|
|
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
|
}()
|
|
}
|
|
|
|
// For loop waiting to create the campaign
|
|
var cid string
|
|
cidChannel := make(chan string)
|
|
go func() {
|
|
for {
|
|
campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID)
|
|
require.NoError(t, err)
|
|
|
|
if len(campaigns) == 1 && campaigns[0].Status == fleet.QueryRunning {
|
|
cidChannel <- fmt.Sprint(campaigns[0].ID)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
select {
|
|
case cid = <-cidChannel:
|
|
case <-time.After(5 * time.Second):
|
|
t.Error("Timeout: campaign not created/running for TestLiveQueriesRestFailsOnSomeHost")
|
|
}
|
|
|
|
distributedReq := submitDistributedQueryResultsRequestShim{
|
|
NodeKey: *h1.NodeKey,
|
|
Results: map[string]json.RawMessage{
|
|
hostDistributedQueryPrefix + cid: json.RawMessage(`[{"col1": "a", "col2": "b"}]`),
|
|
},
|
|
Statuses: map[string]interface{}{
|
|
hostDistributedQueryPrefix + cid: "0",
|
|
},
|
|
Messages: map[string]string{
|
|
hostDistributedQueryPrefix + cid: "some msg",
|
|
},
|
|
}
|
|
distributedResp := submitDistributedQueryResultsResponse{}
|
|
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
|
|
|
distributedReq = submitDistributedQueryResultsRequestShim{
|
|
NodeKey: *h2.NodeKey,
|
|
Results: map[string]json.RawMessage{
|
|
hostDistributedQueryPrefix + cid: json.RawMessage(`""`),
|
|
},
|
|
Statuses: map[string]interface{}{
|
|
hostDistributedQueryPrefix + cid: 123,
|
|
},
|
|
Messages: map[string]string{
|
|
hostDistributedQueryPrefix + cid: "some error!",
|
|
},
|
|
}
|
|
distributedResp = submitDistributedQueryResultsResponse{}
|
|
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
|
|
|
wg.Wait()
|
|
|
|
var qResults []fleet.QueryResult
|
|
if newEndpoint {
|
|
assert.Equal(t, q1.ID, oneLiveQueryResp.QueryID)
|
|
assert.Equal(t, 2, oneLiveQueryResp.TargetedHostCount)
|
|
assert.Equal(t, 2, oneLiveQueryResp.RespondedHostCount)
|
|
qResults = oneLiveQueryResp.Results
|
|
} else {
|
|
require.Len(t, liveQueryResp.Results, 1)
|
|
assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount)
|
|
qResults = liveQueryResp.Results[0].Results
|
|
}
|
|
require.Len(t, qResults, 2)
|
|
require.Len(t, qResults[0].Rows, 1)
|
|
assert.Equal(t, "a", qResults[0].Rows[0]["col1"])
|
|
assert.Equal(t, "b", qResults[0].Rows[0]["col2"])
|
|
require.Len(t, qResults[1].Rows, 0)
|
|
require.NotNil(t, qResults[1].Error)
|
|
assert.Equal(t, "some error!", *qResults[1].Error)
|
|
}
|
|
s.Run("old endpoint", func() { test(false) })
|
|
s.Run("new endpoint", func() { test(true) })
|
|
}
|
|
|
|
func (s *liveQueriesTestSuite) TestCreateDistributedQueryCampaign() {
|
|
t := s.T()
|
|
|
|
// NOTE: this only tests creating the campaigns, as running them is tested
|
|
// extensively in other test functions.
|
|
|
|
h1 := s.hosts[0]
|
|
h2 := s.hosts[1]
|
|
s.lq.On("RunQuery", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
|
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
|
|
|
// create with no payload
|
|
var createResp createDistributedQueryCampaignResponse
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run", nil, http.StatusUnprocessableEntity, &createResp)
|
|
|
|
// create with unknown query
|
|
req := createDistributedQueryCampaignRequest{
|
|
QueryID: ptr.Uint(9999),
|
|
Selected: fleet.HostTargets{
|
|
HostIDs: []uint{1},
|
|
},
|
|
}
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run", req, http.StatusNotFound, &createResp)
|
|
|
|
// create with no hosts
|
|
req = createDistributedQueryCampaignRequest{
|
|
QuerySQL: "SELECT 1",
|
|
}
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run", req, http.StatusBadRequest, &createResp)
|
|
|
|
// wait a second to prevent duplicate name for new query
|
|
time.Sleep(time.Second)
|
|
|
|
// create with new query for specific hosts
|
|
req = createDistributedQueryCampaignRequest{
|
|
QuerySQL: "SELECT 2",
|
|
Selected: fleet.HostTargets{
|
|
HostIDs: []uint{h1.ID, h2.ID},
|
|
},
|
|
}
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run", req, http.StatusOK, &createResp)
|
|
camp1 := *createResp.Campaign
|
|
assert.Equal(t, uint(2), createResp.Campaign.Metrics.TotalHosts)
|
|
|
|
// wait a second to prevent duplicate name for new query
|
|
time.Sleep(time.Second)
|
|
|
|
// create by host name
|
|
req2 := createDistributedQueryCampaignByNamesRequest{
|
|
QuerySQL: "SELECT 3",
|
|
Selected: distributedQueryCampaignTargetsByNames{
|
|
Hosts: []string{h1.Hostname},
|
|
},
|
|
}
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run_by_names", req2, http.StatusOK, &createResp)
|
|
assert.NotEqual(t, camp1.ID, createResp.Campaign.ID)
|
|
assert.Equal(t, uint(1), createResp.Campaign.Metrics.TotalHosts)
|
|
|
|
// wait a second to prevent duplicate name for new query
|
|
time.Sleep(time.Second)
|
|
|
|
// create by unknown host name - it ignores the unknown names. Must have at least 1 valid host
|
|
req2 = createDistributedQueryCampaignByNamesRequest{
|
|
QuerySQL: "SELECT 3",
|
|
Selected: distributedQueryCampaignTargetsByNames{
|
|
Hosts: []string{h1.Hostname, h2.Hostname + "ZZZZZ"},
|
|
},
|
|
}
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run_by_names", req2, http.StatusOK, &createResp)
|
|
}
|
|
|
|
func (s *liveQueriesTestSuite) TestOsqueryDistributedRead() {
|
|
t := s.T()
|
|
|
|
hostID := s.hosts[1].ID
|
|
s.lq.On("QueriesForHost", hostID).Return(map[string]string{fmt.Sprintf("%d", hostID): "select 1 from osquery;"}, nil)
|
|
|
|
req := getDistributedQueriesRequest{NodeKey: *s.hosts[1].NodeKey}
|
|
var resp getDistributedQueriesResponse
|
|
s.DoJSON("POST", "/api/osquery/distributed/read", req, http.StatusOK, &resp)
|
|
assert.Contains(t, resp.Queries, hostDistributedQueryPrefix+fmt.Sprintf("%d", hostID))
|
|
|
|
// test with invalid node key
|
|
var errRes map[string]interface{}
|
|
req.NodeKey += "zzzz"
|
|
s.DoJSON("POST", "/api/osquery/distributed/read", req, http.StatusUnauthorized, &errRes)
|
|
assert.Contains(t, errRes["error"], "invalid node key")
|
|
}
|
|
|
|
func (s *liveQueriesTestSuite) TestOsqueryDistributedReadWithFeatures() {
|
|
t := s.T()
|
|
|
|
spec := []byte(`
|
|
features:
|
|
additional_queries:
|
|
time: SELECT * FROM time
|
|
enable_host_users: true
|
|
enable_software_inventory: true
|
|
`)
|
|
s.applyConfig(spec)
|
|
|
|
a := json.RawMessage(`{"time": "SELECT * FROM time"}`)
|
|
team, err := s.ds.NewTeam(context.Background(), &fleet.Team{
|
|
ID: 42,
|
|
Name: "team1",
|
|
Description: "desc team1",
|
|
Config: fleet.TeamConfig{
|
|
Features: fleet.Features{
|
|
EnableHostUsers: false,
|
|
EnableSoftwareInventory: false,
|
|
AdditionalQueries: &a,
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
host, err := s.ds.NewHost(context.Background(), &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now().Add(-1 * time.Minute),
|
|
OsqueryHostID: ptr.String(t.Name()),
|
|
NodeKey: ptr.String(t.Name()),
|
|
UUID: uuid.New().String(),
|
|
Hostname: fmt.Sprintf("%sfoo.local", t.Name()),
|
|
Platform: "darwin",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
s.lq.On("QueriesForHost", host.ID).Return(map[string]string{fmt.Sprintf("%d", host.ID): "select 1 from osquery;"}, nil)
|
|
|
|
err = s.ds.UpdateHostRefetchRequested(context.Background(), host.ID, true)
|
|
require.NoError(t, err)
|
|
req := getDistributedQueriesRequest{NodeKey: *host.NodeKey}
|
|
var dqResp getDistributedQueriesResponse
|
|
s.DoJSON("POST", "/api/osquery/distributed/read", req, http.StatusOK, &dqResp)
|
|
require.Contains(t, dqResp.Queries, "fleet_detail_query_users")
|
|
require.Contains(t, dqResp.Queries, "fleet_detail_query_software_macos")
|
|
|
|
err = s.ds.AddHostsToTeam(context.Background(), &team.ID, []uint{host.ID})
|
|
require.NoError(t, err)
|
|
err = s.ds.UpdateHostRefetchRequested(context.Background(), host.ID, true)
|
|
require.NoError(t, err)
|
|
req = getDistributedQueriesRequest{NodeKey: *host.NodeKey}
|
|
dqResp = getDistributedQueriesResponse{}
|
|
s.DoJSON("POST", "/api/osquery/distributed/read", req, http.StatusOK, &dqResp)
|
|
require.Contains(t, dqResp.Queries, "fleet_detail_query_users")
|
|
require.Contains(t, dqResp.Queries, "fleet_detail_query_software_macos")
|
|
}
|
|
|
|
func (s *liveQueriesTestSuite) TestCreateDistributedQueryCampaignBadRequest() {
|
|
t := s.T()
|
|
|
|
host, err := s.ds.NewHost(context.Background(), &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now().Add(-1 * time.Minute),
|
|
OsqueryHostID: ptr.String(t.Name()),
|
|
NodeKey: ptr.String(t.Name()),
|
|
UUID: uuid.New().String(),
|
|
Hostname: fmt.Sprintf("%sfoo.local", t.Name()),
|
|
Platform: "darwin",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Query provided but no targets provided
|
|
req := createDistributedQueryCampaignRequest{
|
|
QuerySQL: "SELECT * FROM osquery_info;",
|
|
}
|
|
var createResp createDistributedQueryCampaignResponse
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run", req, http.StatusBadRequest, &createResp)
|
|
|
|
// Target provided but no query provided.
|
|
req = createDistributedQueryCampaignRequest{
|
|
Selected: fleet.HostTargets{HostIDs: []uint{host.ID}},
|
|
}
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run", req, http.StatusUnprocessableEntity, &createResp)
|
|
|
|
// Query "provided" but empty.
|
|
req = createDistributedQueryCampaignRequest{
|
|
QuerySQL: " ",
|
|
Selected: fleet.HostTargets{HostIDs: []uint{host.ID}},
|
|
}
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run", req, http.StatusUnprocessableEntity, &createResp)
|
|
|
|
s.lq.On("RunQuery", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
|
|
|
// Query and targets provided.
|
|
req = createDistributedQueryCampaignRequest{
|
|
QuerySQL: "select \"With automounting enabled anyone with physical access could attach a USB drive or disc and have its contents available in system even if they lacked permissions to mount it themselves.\" as Rationale;",
|
|
Selected: fleet.HostTargets{HostIDs: []uint{host.ID}},
|
|
}
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run", req, http.StatusOK, &createResp)
|
|
|
|
// Disable live queries; should get 403, not 500.
|
|
var acResp appConfigResponse
|
|
s.DoJSON("GET", "/api/latest/fleet/config", nil, http.StatusOK, &acResp)
|
|
appCfg := fleet.AppConfig{ServerSettings: fleet.ServerSettings{LiveQueryDisabled: true, ServerURL: acResp.ServerSettings.ServerURL}, OrgInfo: fleet.OrgInfo{OrgName: acResp.OrgInfo.OrgName}}
|
|
s.DoRaw("PATCH", "/api/latest/fleet/config", jsonMustMarshal(t, appCfg), http.StatusOK)
|
|
|
|
req = createDistributedQueryCampaignRequest{
|
|
QuerySQL: "select \"With automounting enabled anyone with physical access could attach a USB drive or disc and have its contents available in system even if they lacked permissions to mount it themselves.\" as Rationale;",
|
|
Selected: fleet.HostTargets{HostIDs: []uint{host.ID}},
|
|
}
|
|
s.DoJSON("POST", "/api/latest/fleet/queries/run", req, http.StatusForbidden, &createResp)
|
|
|
|
s.DoJSON("GET", "/api/latest/fleet/config", nil, http.StatusOK, &acResp)
|
|
appCfg = fleet.AppConfig{ServerSettings: fleet.ServerSettings{LiveQueryDisabled: false, ServerURL: acResp.ServerSettings.ServerURL}, OrgInfo: fleet.OrgInfo{OrgName: acResp.OrgInfo.OrgName}}
|
|
s.DoRaw("PATCH", "/api/latest/fleet/config", jsonMustMarshal(t, appCfg), http.StatusOK)
|
|
}
|