Make /refetch run labels and policies for the host (#2651)

* Make /refetch run labels and policies for the host

* PR rework changes
This commit is contained in:
Lucas Manuel Rodriguez 2021-10-25 15:46:49 -03:00 committed by GitHub
parent cab1099890
commit fb1e78d039
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 262 additions and 140 deletions

View File

@ -3,7 +3,7 @@
If some of the following don't apply, delete the relevant line.
- [ ] Changes file added (for user-visible changes)
- [ ] Documented any API changes
- [ ] Documented any API changes (docs/01-Using-Fleet/03-REST-API.md)
- [ ] Documented any permissions changes
- [ ] Added/updated tests
- [ ] Manual QA for all new/changed functionality

View File

@ -0,0 +1 @@
* Refetch will now make a host run policy and label queries next time it checks in for distributed queries.

View File

@ -893,7 +893,7 @@ Deletes the specified host from Fleet. Note that a deleted host will fail authen
### Refetch host
Flags the host details to be refetched the next time the host checks in for live queries. Note that we cannot be certain when the host will actually check in and update these details. Further requests to the host APIs will indicate that the refetch has been requested through the `refetch_requested` field on the host object.
Flags the host details, labels and policies to be refetched the next time the host checks in for distributed queries. Note that we cannot be certain when the host will actually check in and update the query results. Further requests to the host APIs will indicate that the refetch has been requested through the `refetch_requested` field on the host object.
`POST /api/v1/fleet/hosts/{id}/refetch`

View File

@ -14,8 +14,15 @@ type OsqueryService interface {
) (nodeKey string, err error)
AuthenticateHost(ctx context.Context, nodeKey string) (host *Host, debug bool, err error)
GetClientConfig(ctx context.Context) (config map[string]interface{}, err error)
// GetDistributedQueries retrieves the distributed queries to run for the host in the provided context. These may be
// detail queries, label queries, or user-initiated distributed queries. A map from query name to query is returned.
// GetDistributedQueries retrieves the distributed queries to run for the host in
// the provided context. These may be (depending on update intervals):
// - detail queries (including additional queries, if any),
// - label queries,
// - user-initiated distributed queries (aka live queries),
// - policy queries.
//
// A map from query name to query is returned.
//
// To enable the osquery "accelerated checkins" feature, a positive integer (number of seconds to activate for)
// should be returned. Returning 0 for this will not activate the feature.
GetDistributedQueries(ctx context.Context) (queries map[string]string, accelerate uint, err error)

View File

@ -402,19 +402,19 @@ const hostPolicyQueryPrefix = "fleet_policy_query_"
// run from a distributed query campaign
const hostDistributedQueryPrefix = "fleet_distributed_query_"
// hostDetailQueries returns the map of queries that should be executed by
// osqueryd to fill in the host details
func (svc *Service) hostDetailQueries(ctx context.Context, host fleet.Host) (map[string]string, error) {
queries := make(map[string]string)
// detailQueriesForHost returns the map of detail+additional queries that should be executed by
// osqueryd to fill in the host details.
func (svc *Service) detailQueriesForHost(ctx context.Context, host fleet.Host) (map[string]string, error) {
if !svc.shouldUpdate(host.DetailUpdatedAt, svc.config.Osquery.DetailUpdateInterval) && !host.RefetchRequested {
// No need to update already fresh details
return queries, nil
}
config, err := svc.ds.AppConfig(ctx)
if err != nil {
return nil, osqueryError{message: "get additional queries: " + err.Error()}
return nil, nil
}
config, err := svc.ds.AppConfig(ctx)
if err != nil {
return nil, errors.Wrap(err, "read app config")
}
queries := make(map[string]string)
detailQueries := osquery_utils.GetDetailQueries(config)
for name, query := range detailQueries {
if query.RunsForPlatform(host.Platform) {
@ -422,7 +422,6 @@ func (svc *Service) hostDetailQueries(ctx context.Context, host fleet.Host) (map
}
}
// Get additional queries
if config.HostSettings.AdditionalQueries == nil {
// No additional queries set
return queries, nil
@ -430,7 +429,7 @@ func (svc *Service) hostDetailQueries(ctx context.Context, host fleet.Host) (map
var additionalQueries map[string]string
if err := json.Unmarshal(*config.HostSettings.AdditionalQueries, &additionalQueries); err != nil {
return nil, osqueryError{message: "unmarshal additional queries: " + err.Error()}
return nil, errors.Wrap(err, "unmarshal additional queries")
}
for name, query := range additionalQueries {
@ -453,6 +452,28 @@ func (svc *Service) shouldUpdate(lastUpdated time.Time, interval time.Duration)
return lastUpdated.Before(cutoff)
}
func (svc *Service) labelQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) {
if !svc.shouldUpdate(host.LabelUpdatedAt, svc.config.Osquery.LabelUpdateInterval) && !host.RefetchRequested {
return nil, nil
}
labelQueries, err := svc.ds.LabelQueriesForHost(ctx, host)
if err != nil {
return nil, errors.Wrap(err, "retrieve label queries")
}
return labelQueries, nil
}
func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) {
if !svc.shouldUpdate(host.PolicyUpdatedAt, svc.config.Osquery.PolicyUpdateInterval) && !host.RefetchRequested {
return nil, nil
}
policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, host)
if err != nil {
return nil, errors.Wrap(err, "retrieve policy queries")
}
return policyQueries, nil
}
func (svc *Service) GetDistributedQueries(ctx context.Context) (map[string]string, uint, error) {
// skipauth: Authorization is currently for user endpoints only.
svc.authz.SkipAuthorization(ctx)
@ -464,41 +485,38 @@ func (svc *Service) GetDistributedQueries(ctx context.Context) (map[string]strin
return nil, 0, osqueryError{message: "internal error: missing host from request context"}
}
queries, err := svc.hostDetailQueries(ctx, host)
queries := make(map[string]string)
detailQueries, err := svc.detailQueriesForHost(ctx, host)
if err != nil {
return nil, 0, err
return nil, 0, osqueryError{message: err.Error()}
}
for name, query := range detailQueries {
queries[name] = query
}
// Retrieve the label queries that should be updated
if svc.shouldUpdate(host.LabelUpdatedAt, svc.config.Osquery.LabelUpdateInterval) {
labelQueries, err := svc.ds.LabelQueriesForHost(ctx, &host)
if err != nil {
return nil, 0, osqueryError{message: "retrieving label queries: " + err.Error()}
}
for name, query := range labelQueries {
queries[hostLabelQueryPrefix+name] = query
}
labelQueries, err := svc.labelQueriesForHost(ctx, &host)
if err != nil {
return nil, 0, osqueryError{message: err.Error()}
}
for name, query := range labelQueries {
queries[hostLabelQueryPrefix+name] = query
}
liveQueries, err := svc.liveQueryStore.QueriesForHost(host.ID)
if err != nil {
return nil, 0, osqueryError{message: "retrieve live queries: " + err.Error()}
}
for name, query := range liveQueries {
queries[hostDistributedQueryPrefix+name] = query
}
if svc.shouldUpdate(host.PolicyUpdatedAt, svc.config.Osquery.PolicyUpdateInterval) {
policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, &host)
if err != nil {
return nil, 0, osqueryError{message: "retrieving policy queries: " + err.Error()}
}
for name, query := range policyQueries {
queries[hostPolicyQueryPrefix+name] = query
}
policyQueries, err := svc.policyQueriesForHost(ctx, &host)
if err != nil {
return nil, 0, osqueryError{message: err.Error()}
}
for name, query := range policyQueries {
queries[hostPolicyQueryPrefix+name] = query
}
accelerate := uint(0)
@ -535,9 +553,6 @@ func (svc *Service) ingestDetailQuery(ctx context.Context, host *fleet.Host, nam
}
}
// Refetch is no longer needed after ingesting details.
host.RefetchRequested = false
return nil
}
@ -644,7 +659,6 @@ func (svc *Service) SubmitDistributedQueryResults(
logIPs(ctx)
host, ok := hostctx.FromContext(ctx)
if !ok {
return osqueryError{message: "internal error: missing host from request context"}
}
@ -733,6 +747,11 @@ func (svc *Service) SubmitDistributedQueryResults(
svc.maybeDebugHost(ctx, host, results, statuses, messages)
if host.RefetchRequested {
host.RefetchRequested = false
host.Modified = true
}
if host.Modified {
err = svc.ds.SaveHost(ctx, &host)
if err != nil {

View File

@ -58,7 +58,7 @@ func TestEnrollAgent(t *testing.T) {
svc := newTestService(ds, nil, nil)
nodeKey, err := svc.EnrollAgent(context.Background(), "valid_secret", "host123", nil)
require.Nil(t, err)
require.NoError(t, err)
assert.NotEmpty(t, nodeKey)
}
@ -114,7 +114,7 @@ func TestEnrollAgentDetails(t *testing.T) {
"foo": {"foo": "bar"},
}
nodeKey, err := svc.EnrollAgent(context.Background(), "", "host123", details)
require.Nil(t, err)
require.NoError(t, err)
assert.NotEmpty(t, nodeKey)
assert.Equal(t, "Mac OS X 10.14.5", gotHost.OSVersion)
@ -144,19 +144,19 @@ func TestAuthenticateHost(t *testing.T) {
}
_, _, err := svc.AuthenticateHost(context.Background(), "test")
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "test", gotKey)
assert.False(t, ds.MarkHostsSeenFuncInvoked)
host = fleet.Host{ID: 7, Hostname: "foobar"}
_, _, err = svc.AuthenticateHost(context.Background(), "floobar")
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "floobar", gotKey)
assert.False(t, ds.MarkHostsSeenFuncInvoked)
// Host checks in twice
host = fleet.Host{ID: 7, Hostname: "foobar"}
_, _, err = svc.AuthenticateHost(context.Background(), "floobar")
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "floobar", gotKey)
assert.False(t, ds.MarkHostsSeenFuncInvoked)
@ -168,7 +168,7 @@ func TestAuthenticateHost(t *testing.T) {
err = svc.FlushSeenHosts(context.Background())
require.NoError(t, err)
assert.True(t, ds.MarkHostsSeenFuncInvoked)
assert.Len(t, gotHostIDs, 0)
require.Len(t, gotHostIDs, 0)
}
func TestAuthenticateHostFailure(t *testing.T) {
@ -210,12 +210,12 @@ func TestSubmitStatusLogs(t *testing.T) {
var status []json.RawMessage
err := json.Unmarshal([]byte(logJSON), &status)
require.Nil(t, err)
require.NoError(t, err)
host := fleet.Host{}
ctx := hostctx.NewContext(context.Background(), host)
err = serv.SubmitStatusLogs(ctx, status)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, status, testLogger.logs)
}
@ -242,12 +242,12 @@ func TestSubmitResultLogs(t *testing.T) {
var results []json.RawMessage
err := json.Unmarshal([]byte(logJSON), &results)
require.Nil(t, err)
require.NoError(t, err)
host := fleet.Host{}
ctx := hostctx.NewContext(context.Background(), host)
err = serv.SubmitResultLogs(ctx, results)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, results, testLogger.logs)
}
@ -280,23 +280,23 @@ func TestHostDetailQueries(t *testing.T) {
svc := &Service{clock: mockClock, config: config.TestConfig(), ds: ds}
queries, err := svc.hostDetailQueries(context.Background(), host)
assert.Nil(t, err)
queries, err := svc.detailQueriesForHost(context.Background(), host)
require.NoError(t, err)
assert.Empty(t, queries)
// With refetch requested queries should be returned
// With refetch requested detail queries should be returned
host.RefetchRequested = true
queries, err = svc.hostDetailQueries(context.Background(), host)
assert.Nil(t, err)
queries, err = svc.detailQueriesForHost(context.Background(), host)
require.NoError(t, err)
assert.NotEmpty(t, queries)
host.RefetchRequested = false
// Advance the time
mockClock.AddTime(1*time.Hour + 1*time.Minute)
queries, err = svc.hostDetailQueries(context.Background(), host)
assert.Nil(t, err)
assert.Len(t, queries, expectedDetailQueries+2)
queries, err = svc.detailQueriesForHost(context.Background(), host)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries+2)
for name := range queries {
assert.True(t,
strings.HasPrefix(name, hostDetailQueryPrefix) || strings.HasPrefix(name, hostAdditionalQueryPrefix),
@ -330,7 +330,8 @@ func TestLabelQueries(t *testing.T) {
ds.HostFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
return host, nil
}
ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error {
ds.SaveHostFunc = func(ctx context.Context, gotHost *fleet.Host) error {
host = gotHost
return nil
}
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
@ -347,18 +348,18 @@ func TestLabelQueries(t *testing.T) {
// With a new host, we should get the detail queries (and accelerate
// should be turned on so that we can quickly fill labels)
queries, acc, err := svc.GetDistributedQueries(ctx)
assert.Nil(t, err)
assert.Len(t, queries, expectedDetailQueries)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries)
assert.NotZero(t, acc)
// Simulate the detail queries being added
// Simulate the detail queries being added.
host.DetailUpdatedAt = mockClock.Now().Add(-1 * time.Minute)
host.Hostname = "zwass.local"
ctx = hostctx.NewContext(ctx, *host)
queries, acc, err = svc.GetDistributedQueries(ctx)
assert.Nil(t, err)
assert.Len(t, queries, 0)
require.NoError(t, err)
require.Len(t, queries, 0)
assert.Zero(t, acc)
ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) {
@ -371,8 +372,8 @@ func TestLabelQueries(t *testing.T) {
// Now we should get the label queries
queries, acc, err = svc.GetDistributedQueries(ctx)
assert.Nil(t, err)
assert.Len(t, queries, 3)
require.NoError(t, err)
require.Len(t, queries, 3)
assert.Zero(t, acc)
var gotHost *fleet.Host
@ -394,13 +395,12 @@ func TestLabelQueries(t *testing.T) {
map[string]fleet.OsqueryStatus{},
map[string]string{},
)
assert.Nil(t, err)
require.NoError(t, err)
host.LabelUpdatedAt = mockClock.Now()
assert.Equal(t, host, gotHost)
assert.Equal(t, mockClock.Now(), gotTime)
if assert.Len(t, gotResults, 1) {
assert.Equal(t, true, *gotResults[1])
}
require.Len(t, gotResults, 1)
assert.Equal(t, true, *gotResults[1])
mockClock.AddTime(1 * time.Second)
@ -414,14 +414,57 @@ func TestLabelQueries(t *testing.T) {
map[string]fleet.OsqueryStatus{},
map[string]string{},
)
assert.Nil(t, err)
require.NoError(t, err)
host.LabelUpdatedAt = mockClock.Now()
assert.Equal(t, host, gotHost)
assert.Equal(t, mockClock.Now(), gotTime)
if assert.Len(t, gotResults, 2) {
assert.Equal(t, true, *gotResults[2])
assert.Equal(t, false, *gotResults[3])
}
require.Len(t, gotResults, 2)
assert.Equal(t, true, *gotResults[2])
assert.Equal(t, false, *gotResults[3])
// We should get no labels now.
host.LabelUpdatedAt = mockClock.Now()
ctx = hostctx.NewContext(ctx, *host)
queries, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Len(t, queries, 0)
assert.Zero(t, acc)
// With refetch requested details+label queries should be returned.
host.RefetchRequested = true
ctx = hostctx.NewContext(ctx, *host)
queries, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries+3)
assert.Zero(t, acc)
// Record a query execution
err = svc.SubmitDistributedQueryResults(
ctx,
map[string][]map[string]string{
hostLabelQueryPrefix + "2": {{"col1": "val1"}},
hostLabelQueryPrefix + "3": {},
},
map[string]fleet.OsqueryStatus{},
map[string]string{},
)
require.NoError(t, err)
host.LabelUpdatedAt = mockClock.Now()
assert.Equal(t, host, gotHost)
assert.Equal(t, mockClock.Now(), gotTime)
require.Len(t, gotResults, 2)
assert.Equal(t, true, *gotResults[2])
assert.Equal(t, false, *gotResults[3])
// SubmitDistributedQueryResults will set RefetchRequested to false.
require.False(t, host.RefetchRequested)
// There shouldn't be any labels now.
ctx = hostctx.NewContext(context.Background(), *host)
queries, acc, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Len(t, queries, 0)
assert.Zero(t, acc)
}
func TestGetClientConfig(t *testing.T) {
@ -469,11 +512,11 @@ func TestGetClientConfig(t *testing.T) {
// No packs loaded yet
conf, err := svc.GetClientConfig(ctx1)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, expectedConfig, conf)
conf, err = svc.GetClientConfig(ctx2)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, expectedConfig, conf)
// Now add packs
@ -494,7 +537,7 @@ func TestGetClientConfig(t *testing.T) {
}
conf, err = svc.GetClientConfig(ctx1)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, expectedOptions, conf["options"])
assert.JSONEq(t, `{
"pack_by_other_label": {
@ -513,7 +556,7 @@ func TestGetClientConfig(t *testing.T) {
)
conf, err = svc.GetClientConfig(ctx2)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, expectedOptions, conf["options"])
assert.JSONEq(t, `{
"pack_by_label": {
@ -550,8 +593,8 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
// With a new host, we should get the detail queries (and accelerated
// queries)
queries, acc, err := svc.GetDistributedQueries(ctx)
assert.Nil(t, err)
assert.Len(t, queries, expectedDetailQueries)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries)
assert.NotZero(t, acc)
resultJSON := `
@ -644,7 +687,7 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
var results fleet.OsqueryDistributedQueryResults
err = json.Unmarshal([]byte(resultJSON), &results)
require.Nil(t, err)
require.NoError(t, err)
var gotHost *fleet.Host
ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error {
@ -687,16 +730,16 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
// Now no detail queries should be required
ctx = hostctx.NewContext(context.Background(), host)
queries, acc, err = svc.GetDistributedQueries(ctx)
assert.Nil(t, err)
assert.Len(t, queries, 0)
require.NoError(t, err)
require.Len(t, queries, 0)
assert.Zero(t, acc)
// Advance clock and queries should exist again
mockClock.AddTime(1*time.Hour + 1*time.Minute)
queries, acc, err = svc.GetDistributedQueries(ctx)
assert.Nil(t, err)
assert.Len(t, queries, expectedDetailQueries)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries)
assert.Zero(t, acc)
}
@ -725,7 +768,7 @@ func TestDetailQueries(t *testing.T) {
// queries)
queries, acc, err := svc.GetDistributedQueries(ctx)
require.NoError(t, err)
assert.Len(t, queries, expectedDetailQueries+1)
require.Len(t, queries, expectedDetailQueries+1)
assert.NotZero(t, acc)
resultJSON := `
@ -849,7 +892,7 @@ func TestDetailQueries(t *testing.T) {
var results fleet.OsqueryDistributedQueryResults
err = json.Unmarshal([]byte(resultJSON), &results)
require.Nil(t, err)
require.NoError(t, err)
var gotHost *fleet.Host
ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error {
@ -921,16 +964,16 @@ func TestDetailQueries(t *testing.T) {
// Now no detail queries should be required
ctx = hostctx.NewContext(ctx, host)
queries, acc, err = svc.GetDistributedQueries(ctx)
assert.Nil(t, err)
assert.Len(t, queries, 0)
require.NoError(t, err)
require.Len(t, queries, 0)
assert.Zero(t, acc)
// Advance clock and queries should exist again
mockClock.AddTime(1*time.Hour + 1*time.Minute)
queries, acc, err = svc.GetDistributedQueries(ctx)
assert.Nil(t, err)
assert.Len(t, queries, expectedDetailQueries+1)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries+1)
assert.Zero(t, acc)
}
@ -990,7 +1033,7 @@ func TestNewDistributedQueryCampaign(t *testing.T) {
return nil
}
campaign, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}})
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, gotQuery.ID, gotCampaign.QueryID)
assert.True(t, ds.NewActivityFuncInvoked)
assert.Equal(t, []*fleet.DistributedQueryCampaignTarget{
@ -1043,8 +1086,8 @@ func TestDistributedQueryResults(t *testing.T) {
// Now we should get the active distributed query
queries, acc, err := svc.GetDistributedQueries(hostCtx)
require.Nil(t, err)
assert.Len(t, queries, expectedDetailQueries+1)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries+1)
queryKey := fmt.Sprintf("%s%d", hostDistributedQueryPrefix, campaign.ID)
assert.Equal(t, "select * from time", queries[queryKey])
assert.NotZero(t, acc)
@ -1065,7 +1108,7 @@ func TestDistributedQueryResults(t *testing.T) {
// TODO use service method
readChan, err := rs.ReadChannel(context.Background(), *campaign)
require.Nil(t, err)
require.NoError(t, err)
// We need to listen for the result in a separate thread to prevent the
// write to the result channel from failing
@ -1100,7 +1143,7 @@ func TestDistributedQueryResults(t *testing.T) {
time.Sleep(10 * time.Millisecond)
err = svc.SubmitDistributedQueryResults(hostCtx, results, map[string]fleet.OsqueryStatus{}, map[string]string{})
require.Nil(t, err)
require.NoError(t, err)
}
func TestIngestDistributedQueryParseIdError(t *testing.T) {
@ -1490,7 +1533,7 @@ func TestUpdateHostIntervals(t *testing.T) {
}
_, err := svc.GetClientConfig(ctx)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, tt.saveHostCalled, saveHostCalled)
})
}
@ -1655,7 +1698,7 @@ func TestDistributedQueriesLogsManyErrors(t *testing.T) {
map[string]fleet.OsqueryStatus{},
map[string]string{},
)
assert.Nil(t, err)
require.NoError(t, err)
lCtx.Log(ctx, logger)
@ -1663,7 +1706,8 @@ func TestDistributedQueriesLogsManyErrors(t *testing.T) {
parts := strings.Split(strings.TrimSpace(logs), "\n")
require.Len(t, parts, 1)
logData := make(map[string]json.RawMessage)
require.NoError(t, json.Unmarshal([]byte(parts[0]), &logData))
err = json.Unmarshal([]byte(parts[0]), &logData)
require.NoError(t, err)
assert.Equal(t, json.RawMessage(`"something went wrong"`), logData["err"])
assert.Equal(t, json.RawMessage(`"Missing authorization check"`), logData["internal"])
}
@ -1697,7 +1741,7 @@ func TestDistributedQueriesReloadsHostIfDetailsAreIn(t *testing.T) {
map[string]fleet.OsqueryStatus{},
map[string]string{},
)
assert.Nil(t, err)
require.NoError(t, err)
assert.True(t, ds.HostFuncInvoked)
}
@ -1809,9 +1853,9 @@ func TestTeamMaintainerCanRunNewDistributedCampaigns(t *testing.T) {
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error {
return nil
}
//var gotQuery *fleet.Query
// var gotQuery *fleet.Query
ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) {
//gotQuery = query
// gotQuery = query
query.ID = 42
return query, nil
}
@ -1874,21 +1918,25 @@ func TestPolicyQueries(t *testing.T) {
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries+2)
hasPolicy1, hasPolicy2 := false, false
for name := range queries {
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
if name[len(hostPolicyQueryPrefix):] == "1" {
hasPolicy1 = true
}
if name[len(hostPolicyQueryPrefix):] == "2" {
hasPolicy2 = true
checkPolicyResults := func(queries map[string]string) {
hasPolicy1, hasPolicy2 := false, false
for name := range queries {
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
if name[len(hostPolicyQueryPrefix):] == "1" {
hasPolicy1 = true
}
if name[len(hostPolicyQueryPrefix):] == "2" {
hasPolicy2 = true
}
}
}
assert.True(t, hasPolicy1)
assert.True(t, hasPolicy2)
}
assert.True(t, hasPolicy1)
assert.True(t, hasPolicy2)
// Record a query execution
checkPolicyResults(queries)
// Record a query execution.
err = svc.SubmitDistributedQueryResults(
ctx,
map[string][]map[string]string{
@ -1904,39 +1952,86 @@ func TestPolicyQueries(t *testing.T) {
require.Equal(t, true, *recordedResults[1])
require.Nil(t, recordedResults[2])
noPolicyResults := func(queries map[string]string) {
hasAnyPolicy := false
for name := range queries {
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
hasAnyPolicy = true
break
}
}
assert.False(t, hasAnyPolicy)
}
// After the first time we get policies and update the host, then there shouldn't be any policies.
ctx = hostctx.NewContext(context.Background(), *host)
queries, _, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries)
noPolicyResults(queries)
// After the first time we get policies and update the host, then there shouldn't be any policies
hasAnyPolicy := false
for name := range queries {
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
hasAnyPolicy = true
break
}
}
assert.False(t, hasAnyPolicy)
// Let's move time forward, there should be policies now
// Let's move time forward, there should be policies now.
mockClock.AddTime(2 * time.Hour)
queries, _, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries+2)
checkPolicyResults(queries)
hasPolicy1, hasPolicy2 = false, false
for name := range queries {
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
if name[len(hostPolicyQueryPrefix):] == "1" {
hasPolicy1 = true
}
if name[len(hostPolicyQueryPrefix):] == "2" {
hasPolicy2 = true
}
}
}
assert.True(t, hasPolicy1)
assert.True(t, hasPolicy2)
// Record another query execution.
err = svc.SubmitDistributedQueryResults(
ctx,
map[string][]map[string]string{
hostPolicyQueryPrefix + "1": {{"col1": "val1"}},
},
map[string]fleet.OsqueryStatus{
hostPolicyQueryPrefix + "2": 1,
},
map[string]string{},
)
require.NoError(t, err)
require.NotNil(t, recordedResults[1])
require.Equal(t, true, *recordedResults[1])
require.Nil(t, recordedResults[2])
// There shouldn't be any policies now.
ctx = hostctx.NewContext(context.Background(), *host)
queries, _, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries)
noPolicyResults(queries)
// With refetch requested policy queries should be returned.
host.RefetchRequested = true
ctx = hostctx.NewContext(context.Background(), *host)
queries, _, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries+2)
checkPolicyResults(queries)
// Record another query execution.
err = svc.SubmitDistributedQueryResults(
ctx,
map[string][]map[string]string{
hostPolicyQueryPrefix + "1": {{"col1": "val1"}},
},
map[string]fleet.OsqueryStatus{
hostPolicyQueryPrefix + "2": 1,
},
map[string]string{},
)
require.NoError(t, err)
require.NotNil(t, recordedResults[1])
require.Equal(t, true, *recordedResults[1])
require.Nil(t, recordedResults[2])
// SubmitDistributedQueryResults will set RefetchRequested to false.
require.False(t, host.RefetchRequested)
// There shouldn't be any policies now.
ctx = hostctx.NewContext(context.Background(), *host)
queries, _, err = svc.GetDistributedQueries(ctx)
require.NoError(t, err)
require.Len(t, queries, expectedDetailQueries)
noPolicyResults(queries)
}