From fb1e78d039742219369758d4f3b9b51c3e18fcac Mon Sep 17 00:00:00 2001 From: Lucas Manuel Rodriguez Date: Mon, 25 Oct 2021 15:46:49 -0300 Subject: [PATCH] Make /refetch run labels and policies for the host (#2651) * Make /refetch run labels and policies for the host * PR rework changes --- .github/pull_request_template.md | 2 +- .../issue-1458-refetch-labels-and-policies | 1 + docs/01-Using-Fleet/03-REST-API.md | 2 +- server/fleet/service.go | 11 +- server/service/service_osquery.go | 95 +++--- server/service/service_osquery_test.go | 291 ++++++++++++------ 6 files changed, 262 insertions(+), 140 deletions(-) create mode 100644 changes/issue-1458-refetch-labels-and-policies diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index d8297eaaf..d9708c738 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -3,7 +3,7 @@ If some of the following don't apply, delete the relevant line. - [ ] Changes file added (for user-visible changes) -- [ ] Documented any API changes +- [ ] Documented any API changes (docs/01-Using-Fleet/03-REST-API.md) - [ ] Documented any permissions changes - [ ] Added/updated tests - [ ] Manual QA for all new/changed functionality diff --git a/changes/issue-1458-refetch-labels-and-policies b/changes/issue-1458-refetch-labels-and-policies new file mode 100644 index 000000000..2e5503653 --- /dev/null +++ b/changes/issue-1458-refetch-labels-and-policies @@ -0,0 +1 @@ +* Refetch will now make a host run policy and label queries next time it checks in for distributed queries. diff --git a/docs/01-Using-Fleet/03-REST-API.md b/docs/01-Using-Fleet/03-REST-API.md index b446ab4fd..3e6765388 100644 --- a/docs/01-Using-Fleet/03-REST-API.md +++ b/docs/01-Using-Fleet/03-REST-API.md @@ -893,7 +893,7 @@ Deletes the specified host from Fleet. Note that a deleted host will fail authen ### Refetch host -Flags the host details to be refetched the next time the host checks in for live queries. Note that we cannot be certain when the host will actually check in and update these details. Further requests to the host APIs will indicate that the refetch has been requested through the `refetch_requested` field on the host object. +Flags the host details, labels and policies to be refetched the next time the host checks in for distributed queries. Note that we cannot be certain when the host will actually check in and update the query results. Further requests to the host APIs will indicate that the refetch has been requested through the `refetch_requested` field on the host object. `POST /api/v1/fleet/hosts/{id}/refetch` diff --git a/server/fleet/service.go b/server/fleet/service.go index 62e12ff2d..ebc3600e7 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -14,8 +14,15 @@ type OsqueryService interface { ) (nodeKey string, err error) AuthenticateHost(ctx context.Context, nodeKey string) (host *Host, debug bool, err error) GetClientConfig(ctx context.Context) (config map[string]interface{}, err error) - // GetDistributedQueries retrieves the distributed queries to run for the host in the provided context. These may be - // detail queries, label queries, or user-initiated distributed queries. A map from query name to query is returned. + // GetDistributedQueries retrieves the distributed queries to run for the host in + // the provided context. These may be (depending on update intervals): + // - detail queries (including additional queries, if any), + // - label queries, + // - user-initiated distributed queries (aka live queries), + // - policy queries. + // + // A map from query name to query is returned. + // // To enable the osquery "accelerated checkins" feature, a positive integer (number of seconds to activate for) // should be returned. Returning 0 for this will not activate the feature. GetDistributedQueries(ctx context.Context) (queries map[string]string, accelerate uint, err error) diff --git a/server/service/service_osquery.go b/server/service/service_osquery.go index 71f3619c9..234eedafe 100644 --- a/server/service/service_osquery.go +++ b/server/service/service_osquery.go @@ -402,19 +402,19 @@ const hostPolicyQueryPrefix = "fleet_policy_query_" // run from a distributed query campaign const hostDistributedQueryPrefix = "fleet_distributed_query_" -// hostDetailQueries returns the map of queries that should be executed by -// osqueryd to fill in the host details -func (svc *Service) hostDetailQueries(ctx context.Context, host fleet.Host) (map[string]string, error) { - queries := make(map[string]string) +// detailQueriesForHost returns the map of detail+additional queries that should be executed by +// osqueryd to fill in the host details. +func (svc *Service) detailQueriesForHost(ctx context.Context, host fleet.Host) (map[string]string, error) { if !svc.shouldUpdate(host.DetailUpdatedAt, svc.config.Osquery.DetailUpdateInterval) && !host.RefetchRequested { - // No need to update already fresh details - return queries, nil - } - config, err := svc.ds.AppConfig(ctx) - if err != nil { - return nil, osqueryError{message: "get additional queries: " + err.Error()} + return nil, nil } + config, err := svc.ds.AppConfig(ctx) + if err != nil { + return nil, errors.Wrap(err, "read app config") + } + + queries := make(map[string]string) detailQueries := osquery_utils.GetDetailQueries(config) for name, query := range detailQueries { if query.RunsForPlatform(host.Platform) { @@ -422,7 +422,6 @@ func (svc *Service) hostDetailQueries(ctx context.Context, host fleet.Host) (map } } - // Get additional queries if config.HostSettings.AdditionalQueries == nil { // No additional queries set return queries, nil @@ -430,7 +429,7 @@ func (svc *Service) hostDetailQueries(ctx context.Context, host fleet.Host) (map var additionalQueries map[string]string if err := json.Unmarshal(*config.HostSettings.AdditionalQueries, &additionalQueries); err != nil { - return nil, osqueryError{message: "unmarshal additional queries: " + err.Error()} + return nil, errors.Wrap(err, "unmarshal additional queries") } for name, query := range additionalQueries { @@ -453,6 +452,28 @@ func (svc *Service) shouldUpdate(lastUpdated time.Time, interval time.Duration) return lastUpdated.Before(cutoff) } +func (svc *Service) labelQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) { + if !svc.shouldUpdate(host.LabelUpdatedAt, svc.config.Osquery.LabelUpdateInterval) && !host.RefetchRequested { + return nil, nil + } + labelQueries, err := svc.ds.LabelQueriesForHost(ctx, host) + if err != nil { + return nil, errors.Wrap(err, "retrieve label queries") + } + return labelQueries, nil +} + +func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) { + if !svc.shouldUpdate(host.PolicyUpdatedAt, svc.config.Osquery.PolicyUpdateInterval) && !host.RefetchRequested { + return nil, nil + } + policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, host) + if err != nil { + return nil, errors.Wrap(err, "retrieve policy queries") + } + return policyQueries, nil +} + func (svc *Service) GetDistributedQueries(ctx context.Context) (map[string]string, uint, error) { // skipauth: Authorization is currently for user endpoints only. svc.authz.SkipAuthorization(ctx) @@ -464,41 +485,38 @@ func (svc *Service) GetDistributedQueries(ctx context.Context) (map[string]strin return nil, 0, osqueryError{message: "internal error: missing host from request context"} } - queries, err := svc.hostDetailQueries(ctx, host) + queries := make(map[string]string) + + detailQueries, err := svc.detailQueriesForHost(ctx, host) if err != nil { - return nil, 0, err + return nil, 0, osqueryError{message: err.Error()} + } + for name, query := range detailQueries { + queries[name] = query } - // Retrieve the label queries that should be updated - if svc.shouldUpdate(host.LabelUpdatedAt, svc.config.Osquery.LabelUpdateInterval) { - labelQueries, err := svc.ds.LabelQueriesForHost(ctx, &host) - if err != nil { - return nil, 0, osqueryError{message: "retrieving label queries: " + err.Error()} - } - - for name, query := range labelQueries { - queries[hostLabelQueryPrefix+name] = query - } + labelQueries, err := svc.labelQueriesForHost(ctx, &host) + if err != nil { + return nil, 0, osqueryError{message: err.Error()} + } + for name, query := range labelQueries { + queries[hostLabelQueryPrefix+name] = query } liveQueries, err := svc.liveQueryStore.QueriesForHost(host.ID) if err != nil { return nil, 0, osqueryError{message: "retrieve live queries: " + err.Error()} } - for name, query := range liveQueries { queries[hostDistributedQueryPrefix+name] = query } - if svc.shouldUpdate(host.PolicyUpdatedAt, svc.config.Osquery.PolicyUpdateInterval) { - policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, &host) - if err != nil { - return nil, 0, osqueryError{message: "retrieving policy queries: " + err.Error()} - } - - for name, query := range policyQueries { - queries[hostPolicyQueryPrefix+name] = query - } + policyQueries, err := svc.policyQueriesForHost(ctx, &host) + if err != nil { + return nil, 0, osqueryError{message: err.Error()} + } + for name, query := range policyQueries { + queries[hostPolicyQueryPrefix+name] = query } accelerate := uint(0) @@ -535,9 +553,6 @@ func (svc *Service) ingestDetailQuery(ctx context.Context, host *fleet.Host, nam } } - // Refetch is no longer needed after ingesting details. - host.RefetchRequested = false - return nil } @@ -644,7 +659,6 @@ func (svc *Service) SubmitDistributedQueryResults( logIPs(ctx) host, ok := hostctx.FromContext(ctx) - if !ok { return osqueryError{message: "internal error: missing host from request context"} } @@ -733,6 +747,11 @@ func (svc *Service) SubmitDistributedQueryResults( svc.maybeDebugHost(ctx, host, results, statuses, messages) + if host.RefetchRequested { + host.RefetchRequested = false + host.Modified = true + } + if host.Modified { err = svc.ds.SaveHost(ctx, &host) if err != nil { diff --git a/server/service/service_osquery_test.go b/server/service/service_osquery_test.go index b78fad43c..9aafebc3d 100644 --- a/server/service/service_osquery_test.go +++ b/server/service/service_osquery_test.go @@ -58,7 +58,7 @@ func TestEnrollAgent(t *testing.T) { svc := newTestService(ds, nil, nil) nodeKey, err := svc.EnrollAgent(context.Background(), "valid_secret", "host123", nil) - require.Nil(t, err) + require.NoError(t, err) assert.NotEmpty(t, nodeKey) } @@ -114,7 +114,7 @@ func TestEnrollAgentDetails(t *testing.T) { "foo": {"foo": "bar"}, } nodeKey, err := svc.EnrollAgent(context.Background(), "", "host123", details) - require.Nil(t, err) + require.NoError(t, err) assert.NotEmpty(t, nodeKey) assert.Equal(t, "Mac OS X 10.14.5", gotHost.OSVersion) @@ -144,19 +144,19 @@ func TestAuthenticateHost(t *testing.T) { } _, _, err := svc.AuthenticateHost(context.Background(), "test") - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "test", gotKey) assert.False(t, ds.MarkHostsSeenFuncInvoked) host = fleet.Host{ID: 7, Hostname: "foobar"} _, _, err = svc.AuthenticateHost(context.Background(), "floobar") - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "floobar", gotKey) assert.False(t, ds.MarkHostsSeenFuncInvoked) // Host checks in twice host = fleet.Host{ID: 7, Hostname: "foobar"} _, _, err = svc.AuthenticateHost(context.Background(), "floobar") - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "floobar", gotKey) assert.False(t, ds.MarkHostsSeenFuncInvoked) @@ -168,7 +168,7 @@ func TestAuthenticateHost(t *testing.T) { err = svc.FlushSeenHosts(context.Background()) require.NoError(t, err) assert.True(t, ds.MarkHostsSeenFuncInvoked) - assert.Len(t, gotHostIDs, 0) + require.Len(t, gotHostIDs, 0) } func TestAuthenticateHostFailure(t *testing.T) { @@ -210,12 +210,12 @@ func TestSubmitStatusLogs(t *testing.T) { var status []json.RawMessage err := json.Unmarshal([]byte(logJSON), &status) - require.Nil(t, err) + require.NoError(t, err) host := fleet.Host{} ctx := hostctx.NewContext(context.Background(), host) err = serv.SubmitStatusLogs(ctx, status) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, status, testLogger.logs) } @@ -242,12 +242,12 @@ func TestSubmitResultLogs(t *testing.T) { var results []json.RawMessage err := json.Unmarshal([]byte(logJSON), &results) - require.Nil(t, err) + require.NoError(t, err) host := fleet.Host{} ctx := hostctx.NewContext(context.Background(), host) err = serv.SubmitResultLogs(ctx, results) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, results, testLogger.logs) } @@ -280,23 +280,23 @@ func TestHostDetailQueries(t *testing.T) { svc := &Service{clock: mockClock, config: config.TestConfig(), ds: ds} - queries, err := svc.hostDetailQueries(context.Background(), host) - assert.Nil(t, err) + queries, err := svc.detailQueriesForHost(context.Background(), host) + require.NoError(t, err) assert.Empty(t, queries) - // With refetch requested queries should be returned + // With refetch requested detail queries should be returned host.RefetchRequested = true - queries, err = svc.hostDetailQueries(context.Background(), host) - assert.Nil(t, err) + queries, err = svc.detailQueriesForHost(context.Background(), host) + require.NoError(t, err) assert.NotEmpty(t, queries) host.RefetchRequested = false // Advance the time mockClock.AddTime(1*time.Hour + 1*time.Minute) - queries, err = svc.hostDetailQueries(context.Background(), host) - assert.Nil(t, err) - assert.Len(t, queries, expectedDetailQueries+2) + queries, err = svc.detailQueriesForHost(context.Background(), host) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+2) for name := range queries { assert.True(t, strings.HasPrefix(name, hostDetailQueryPrefix) || strings.HasPrefix(name, hostAdditionalQueryPrefix), @@ -330,7 +330,8 @@ func TestLabelQueries(t *testing.T) { ds.HostFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { return host, nil } - ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error { + ds.SaveHostFunc = func(ctx context.Context, gotHost *fleet.Host) error { + host = gotHost return nil } ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { @@ -347,18 +348,18 @@ func TestLabelQueries(t *testing.T) { // With a new host, we should get the detail queries (and accelerate // should be turned on so that we can quickly fill labels) queries, acc, err := svc.GetDistributedQueries(ctx) - assert.Nil(t, err) - assert.Len(t, queries, expectedDetailQueries) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) assert.NotZero(t, acc) - // Simulate the detail queries being added + // Simulate the detail queries being added. host.DetailUpdatedAt = mockClock.Now().Add(-1 * time.Minute) host.Hostname = "zwass.local" ctx = hostctx.NewContext(ctx, *host) queries, acc, err = svc.GetDistributedQueries(ctx) - assert.Nil(t, err) - assert.Len(t, queries, 0) + require.NoError(t, err) + require.Len(t, queries, 0) assert.Zero(t, acc) ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { @@ -371,8 +372,8 @@ func TestLabelQueries(t *testing.T) { // Now we should get the label queries queries, acc, err = svc.GetDistributedQueries(ctx) - assert.Nil(t, err) - assert.Len(t, queries, 3) + require.NoError(t, err) + require.Len(t, queries, 3) assert.Zero(t, acc) var gotHost *fleet.Host @@ -394,13 +395,12 @@ func TestLabelQueries(t *testing.T) { map[string]fleet.OsqueryStatus{}, map[string]string{}, ) - assert.Nil(t, err) + require.NoError(t, err) host.LabelUpdatedAt = mockClock.Now() assert.Equal(t, host, gotHost) assert.Equal(t, mockClock.Now(), gotTime) - if assert.Len(t, gotResults, 1) { - assert.Equal(t, true, *gotResults[1]) - } + require.Len(t, gotResults, 1) + assert.Equal(t, true, *gotResults[1]) mockClock.AddTime(1 * time.Second) @@ -414,14 +414,57 @@ func TestLabelQueries(t *testing.T) { map[string]fleet.OsqueryStatus{}, map[string]string{}, ) - assert.Nil(t, err) + require.NoError(t, err) host.LabelUpdatedAt = mockClock.Now() assert.Equal(t, host, gotHost) assert.Equal(t, mockClock.Now(), gotTime) - if assert.Len(t, gotResults, 2) { - assert.Equal(t, true, *gotResults[2]) - assert.Equal(t, false, *gotResults[3]) - } + require.Len(t, gotResults, 2) + assert.Equal(t, true, *gotResults[2]) + assert.Equal(t, false, *gotResults[3]) + + // We should get no labels now. + host.LabelUpdatedAt = mockClock.Now() + ctx = hostctx.NewContext(ctx, *host) + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, 0) + assert.Zero(t, acc) + + // With refetch requested details+label queries should be returned. + host.RefetchRequested = true + ctx = hostctx.NewContext(ctx, *host) + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+3) + assert.Zero(t, acc) + + // Record a query execution + err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostLabelQueryPrefix + "2": {{"col1": "val1"}}, + hostLabelQueryPrefix + "3": {}, + }, + map[string]fleet.OsqueryStatus{}, + map[string]string{}, + ) + require.NoError(t, err) + host.LabelUpdatedAt = mockClock.Now() + assert.Equal(t, host, gotHost) + assert.Equal(t, mockClock.Now(), gotTime) + require.Len(t, gotResults, 2) + assert.Equal(t, true, *gotResults[2]) + assert.Equal(t, false, *gotResults[3]) + + // SubmitDistributedQueryResults will set RefetchRequested to false. + require.False(t, host.RefetchRequested) + + // There shouldn't be any labels now. + ctx = hostctx.NewContext(context.Background(), *host) + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, 0) + assert.Zero(t, acc) } func TestGetClientConfig(t *testing.T) { @@ -469,11 +512,11 @@ func TestGetClientConfig(t *testing.T) { // No packs loaded yet conf, err := svc.GetClientConfig(ctx1) - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, expectedConfig, conf) conf, err = svc.GetClientConfig(ctx2) - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, expectedConfig, conf) // Now add packs @@ -494,7 +537,7 @@ func TestGetClientConfig(t *testing.T) { } conf, err = svc.GetClientConfig(ctx1) - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, expectedOptions, conf["options"]) assert.JSONEq(t, `{ "pack_by_other_label": { @@ -513,7 +556,7 @@ func TestGetClientConfig(t *testing.T) { ) conf, err = svc.GetClientConfig(ctx2) - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, expectedOptions, conf["options"]) assert.JSONEq(t, `{ "pack_by_label": { @@ -550,8 +593,8 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) { // With a new host, we should get the detail queries (and accelerated // queries) queries, acc, err := svc.GetDistributedQueries(ctx) - assert.Nil(t, err) - assert.Len(t, queries, expectedDetailQueries) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) assert.NotZero(t, acc) resultJSON := ` @@ -644,7 +687,7 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) { var results fleet.OsqueryDistributedQueryResults err = json.Unmarshal([]byte(resultJSON), &results) - require.Nil(t, err) + require.NoError(t, err) var gotHost *fleet.Host ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error { @@ -687,16 +730,16 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) { // Now no detail queries should be required ctx = hostctx.NewContext(context.Background(), host) queries, acc, err = svc.GetDistributedQueries(ctx) - assert.Nil(t, err) - assert.Len(t, queries, 0) + require.NoError(t, err) + require.Len(t, queries, 0) assert.Zero(t, acc) // Advance clock and queries should exist again mockClock.AddTime(1*time.Hour + 1*time.Minute) queries, acc, err = svc.GetDistributedQueries(ctx) - assert.Nil(t, err) - assert.Len(t, queries, expectedDetailQueries) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) assert.Zero(t, acc) } @@ -725,7 +768,7 @@ func TestDetailQueries(t *testing.T) { // queries) queries, acc, err := svc.GetDistributedQueries(ctx) require.NoError(t, err) - assert.Len(t, queries, expectedDetailQueries+1) + require.Len(t, queries, expectedDetailQueries+1) assert.NotZero(t, acc) resultJSON := ` @@ -849,7 +892,7 @@ func TestDetailQueries(t *testing.T) { var results fleet.OsqueryDistributedQueryResults err = json.Unmarshal([]byte(resultJSON), &results) - require.Nil(t, err) + require.NoError(t, err) var gotHost *fleet.Host ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error { @@ -921,16 +964,16 @@ func TestDetailQueries(t *testing.T) { // Now no detail queries should be required ctx = hostctx.NewContext(ctx, host) queries, acc, err = svc.GetDistributedQueries(ctx) - assert.Nil(t, err) - assert.Len(t, queries, 0) + require.NoError(t, err) + require.Len(t, queries, 0) assert.Zero(t, acc) // Advance clock and queries should exist again mockClock.AddTime(1*time.Hour + 1*time.Minute) queries, acc, err = svc.GetDistributedQueries(ctx) - assert.Nil(t, err) - assert.Len(t, queries, expectedDetailQueries+1) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+1) assert.Zero(t, acc) } @@ -990,7 +1033,7 @@ func TestNewDistributedQueryCampaign(t *testing.T) { return nil } campaign, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, gotQuery.ID, gotCampaign.QueryID) assert.True(t, ds.NewActivityFuncInvoked) assert.Equal(t, []*fleet.DistributedQueryCampaignTarget{ @@ -1043,8 +1086,8 @@ func TestDistributedQueryResults(t *testing.T) { // Now we should get the active distributed query queries, acc, err := svc.GetDistributedQueries(hostCtx) - require.Nil(t, err) - assert.Len(t, queries, expectedDetailQueries+1) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+1) queryKey := fmt.Sprintf("%s%d", hostDistributedQueryPrefix, campaign.ID) assert.Equal(t, "select * from time", queries[queryKey]) assert.NotZero(t, acc) @@ -1065,7 +1108,7 @@ func TestDistributedQueryResults(t *testing.T) { // TODO use service method readChan, err := rs.ReadChannel(context.Background(), *campaign) - require.Nil(t, err) + require.NoError(t, err) // We need to listen for the result in a separate thread to prevent the // write to the result channel from failing @@ -1100,7 +1143,7 @@ func TestDistributedQueryResults(t *testing.T) { time.Sleep(10 * time.Millisecond) err = svc.SubmitDistributedQueryResults(hostCtx, results, map[string]fleet.OsqueryStatus{}, map[string]string{}) - require.Nil(t, err) + require.NoError(t, err) } func TestIngestDistributedQueryParseIdError(t *testing.T) { @@ -1490,7 +1533,7 @@ func TestUpdateHostIntervals(t *testing.T) { } _, err := svc.GetClientConfig(ctx) - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, tt.saveHostCalled, saveHostCalled) }) } @@ -1655,7 +1698,7 @@ func TestDistributedQueriesLogsManyErrors(t *testing.T) { map[string]fleet.OsqueryStatus{}, map[string]string{}, ) - assert.Nil(t, err) + require.NoError(t, err) lCtx.Log(ctx, logger) @@ -1663,7 +1706,8 @@ func TestDistributedQueriesLogsManyErrors(t *testing.T) { parts := strings.Split(strings.TrimSpace(logs), "\n") require.Len(t, parts, 1) logData := make(map[string]json.RawMessage) - require.NoError(t, json.Unmarshal([]byte(parts[0]), &logData)) + err = json.Unmarshal([]byte(parts[0]), &logData) + require.NoError(t, err) assert.Equal(t, json.RawMessage(`"something went wrong"`), logData["err"]) assert.Equal(t, json.RawMessage(`"Missing authorization check"`), logData["internal"]) } @@ -1697,7 +1741,7 @@ func TestDistributedQueriesReloadsHostIfDetailsAreIn(t *testing.T) { map[string]fleet.OsqueryStatus{}, map[string]string{}, ) - assert.Nil(t, err) + require.NoError(t, err) assert.True(t, ds.HostFuncInvoked) } @@ -1809,9 +1853,9 @@ func TestTeamMaintainerCanRunNewDistributedCampaigns(t *testing.T) { ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { return nil } - //var gotQuery *fleet.Query + // var gotQuery *fleet.Query ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) { - //gotQuery = query + // gotQuery = query query.ID = 42 return query, nil } @@ -1874,21 +1918,25 @@ func TestPolicyQueries(t *testing.T) { require.NoError(t, err) require.Len(t, queries, expectedDetailQueries+2) - hasPolicy1, hasPolicy2 := false, false - for name := range queries { - if strings.HasPrefix(name, hostPolicyQueryPrefix) { - if name[len(hostPolicyQueryPrefix):] == "1" { - hasPolicy1 = true - } - if name[len(hostPolicyQueryPrefix):] == "2" { - hasPolicy2 = true + checkPolicyResults := func(queries map[string]string) { + hasPolicy1, hasPolicy2 := false, false + for name := range queries { + if strings.HasPrefix(name, hostPolicyQueryPrefix) { + if name[len(hostPolicyQueryPrefix):] == "1" { + hasPolicy1 = true + } + if name[len(hostPolicyQueryPrefix):] == "2" { + hasPolicy2 = true + } } } + assert.True(t, hasPolicy1) + assert.True(t, hasPolicy2) } - assert.True(t, hasPolicy1) - assert.True(t, hasPolicy2) - // Record a query execution + checkPolicyResults(queries) + + // Record a query execution. err = svc.SubmitDistributedQueryResults( ctx, map[string][]map[string]string{ @@ -1904,39 +1952,86 @@ func TestPolicyQueries(t *testing.T) { require.Equal(t, true, *recordedResults[1]) require.Nil(t, recordedResults[2]) + noPolicyResults := func(queries map[string]string) { + hasAnyPolicy := false + for name := range queries { + if strings.HasPrefix(name, hostPolicyQueryPrefix) { + hasAnyPolicy = true + break + } + } + assert.False(t, hasAnyPolicy) + } + + // After the first time we get policies and update the host, then there shouldn't be any policies. ctx = hostctx.NewContext(context.Background(), *host) queries, _, err = svc.GetDistributedQueries(ctx) require.NoError(t, err) require.Len(t, queries, expectedDetailQueries) + noPolicyResults(queries) - // After the first time we get policies and update the host, then there shouldn't be any policies - hasAnyPolicy := false - for name := range queries { - if strings.HasPrefix(name, hostPolicyQueryPrefix) { - hasAnyPolicy = true - break - } - } - assert.False(t, hasAnyPolicy) - - // Let's move time forward, there should be policies now + // Let's move time forward, there should be policies now. mockClock.AddTime(2 * time.Hour) queries, _, err = svc.GetDistributedQueries(ctx) require.NoError(t, err) require.Len(t, queries, expectedDetailQueries+2) + checkPolicyResults(queries) - hasPolicy1, hasPolicy2 = false, false - for name := range queries { - if strings.HasPrefix(name, hostPolicyQueryPrefix) { - if name[len(hostPolicyQueryPrefix):] == "1" { - hasPolicy1 = true - } - if name[len(hostPolicyQueryPrefix):] == "2" { - hasPolicy2 = true - } - } - } - assert.True(t, hasPolicy1) - assert.True(t, hasPolicy2) + // Record another query execution. + err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostPolicyQueryPrefix + "1": {{"col1": "val1"}}, + }, + map[string]fleet.OsqueryStatus{ + hostPolicyQueryPrefix + "2": 1, + }, + map[string]string{}, + ) + require.NoError(t, err) + require.NotNil(t, recordedResults[1]) + require.Equal(t, true, *recordedResults[1]) + require.Nil(t, recordedResults[2]) + + // There shouldn't be any policies now. + ctx = hostctx.NewContext(context.Background(), *host) + queries, _, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) + noPolicyResults(queries) + + // With refetch requested policy queries should be returned. + host.RefetchRequested = true + ctx = hostctx.NewContext(context.Background(), *host) + queries, _, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+2) + checkPolicyResults(queries) + + // Record another query execution. + err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostPolicyQueryPrefix + "1": {{"col1": "val1"}}, + }, + map[string]fleet.OsqueryStatus{ + hostPolicyQueryPrefix + "2": 1, + }, + map[string]string{}, + ) + require.NoError(t, err) + require.NotNil(t, recordedResults[1]) + require.Equal(t, true, *recordedResults[1]) + require.Nil(t, recordedResults[2]) + + // SubmitDistributedQueryResults will set RefetchRequested to false. + require.False(t, host.RefetchRequested) + + // There shouldn't be any policies now. + ctx = hostctx.NewContext(context.Background(), *host) + queries, _, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) + noPolicyResults(queries) }