mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Make /refetch run labels and policies for the host (#2651)
* Make /refetch run labels and policies for the host * PR rework changes
This commit is contained in:
parent
cab1099890
commit
fb1e78d039
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@ -3,7 +3,7 @@
|
||||
If some of the following don't apply, delete the relevant line.
|
||||
|
||||
- [ ] Changes file added (for user-visible changes)
|
||||
- [ ] Documented any API changes
|
||||
- [ ] Documented any API changes (docs/01-Using-Fleet/03-REST-API.md)
|
||||
- [ ] Documented any permissions changes
|
||||
- [ ] Added/updated tests
|
||||
- [ ] Manual QA for all new/changed functionality
|
||||
|
1
changes/issue-1458-refetch-labels-and-policies
Normal file
1
changes/issue-1458-refetch-labels-and-policies
Normal file
@ -0,0 +1 @@
|
||||
* Refetch will now make a host run policy and label queries next time it checks in for distributed queries.
|
@ -893,7 +893,7 @@ Deletes the specified host from Fleet. Note that a deleted host will fail authen
|
||||
|
||||
### Refetch host
|
||||
|
||||
Flags the host details to be refetched the next time the host checks in for live queries. Note that we cannot be certain when the host will actually check in and update these details. Further requests to the host APIs will indicate that the refetch has been requested through the `refetch_requested` field on the host object.
|
||||
Flags the host details, labels and policies to be refetched the next time the host checks in for distributed queries. Note that we cannot be certain when the host will actually check in and update the query results. Further requests to the host APIs will indicate that the refetch has been requested through the `refetch_requested` field on the host object.
|
||||
|
||||
`POST /api/v1/fleet/hosts/{id}/refetch`
|
||||
|
||||
|
@ -14,8 +14,15 @@ type OsqueryService interface {
|
||||
) (nodeKey string, err error)
|
||||
AuthenticateHost(ctx context.Context, nodeKey string) (host *Host, debug bool, err error)
|
||||
GetClientConfig(ctx context.Context) (config map[string]interface{}, err error)
|
||||
// GetDistributedQueries retrieves the distributed queries to run for the host in the provided context. These may be
|
||||
// detail queries, label queries, or user-initiated distributed queries. A map from query name to query is returned.
|
||||
// GetDistributedQueries retrieves the distributed queries to run for the host in
|
||||
// the provided context. These may be (depending on update intervals):
|
||||
// - detail queries (including additional queries, if any),
|
||||
// - label queries,
|
||||
// - user-initiated distributed queries (aka live queries),
|
||||
// - policy queries.
|
||||
//
|
||||
// A map from query name to query is returned.
|
||||
//
|
||||
// To enable the osquery "accelerated checkins" feature, a positive integer (number of seconds to activate for)
|
||||
// should be returned. Returning 0 for this will not activate the feature.
|
||||
GetDistributedQueries(ctx context.Context) (queries map[string]string, accelerate uint, err error)
|
||||
|
@ -402,19 +402,19 @@ const hostPolicyQueryPrefix = "fleet_policy_query_"
|
||||
// run from a distributed query campaign
|
||||
const hostDistributedQueryPrefix = "fleet_distributed_query_"
|
||||
|
||||
// hostDetailQueries returns the map of queries that should be executed by
|
||||
// osqueryd to fill in the host details
|
||||
func (svc *Service) hostDetailQueries(ctx context.Context, host fleet.Host) (map[string]string, error) {
|
||||
queries := make(map[string]string)
|
||||
// detailQueriesForHost returns the map of detail+additional queries that should be executed by
|
||||
// osqueryd to fill in the host details.
|
||||
func (svc *Service) detailQueriesForHost(ctx context.Context, host fleet.Host) (map[string]string, error) {
|
||||
if !svc.shouldUpdate(host.DetailUpdatedAt, svc.config.Osquery.DetailUpdateInterval) && !host.RefetchRequested {
|
||||
// No need to update already fresh details
|
||||
return queries, nil
|
||||
}
|
||||
config, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, osqueryError{message: "get additional queries: " + err.Error()}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
config, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read app config")
|
||||
}
|
||||
|
||||
queries := make(map[string]string)
|
||||
detailQueries := osquery_utils.GetDetailQueries(config)
|
||||
for name, query := range detailQueries {
|
||||
if query.RunsForPlatform(host.Platform) {
|
||||
@ -422,7 +422,6 @@ func (svc *Service) hostDetailQueries(ctx context.Context, host fleet.Host) (map
|
||||
}
|
||||
}
|
||||
|
||||
// Get additional queries
|
||||
if config.HostSettings.AdditionalQueries == nil {
|
||||
// No additional queries set
|
||||
return queries, nil
|
||||
@ -430,7 +429,7 @@ func (svc *Service) hostDetailQueries(ctx context.Context, host fleet.Host) (map
|
||||
|
||||
var additionalQueries map[string]string
|
||||
if err := json.Unmarshal(*config.HostSettings.AdditionalQueries, &additionalQueries); err != nil {
|
||||
return nil, osqueryError{message: "unmarshal additional queries: " + err.Error()}
|
||||
return nil, errors.Wrap(err, "unmarshal additional queries")
|
||||
}
|
||||
|
||||
for name, query := range additionalQueries {
|
||||
@ -453,6 +452,28 @@ func (svc *Service) shouldUpdate(lastUpdated time.Time, interval time.Duration)
|
||||
return lastUpdated.Before(cutoff)
|
||||
}
|
||||
|
||||
func (svc *Service) labelQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) {
|
||||
if !svc.shouldUpdate(host.LabelUpdatedAt, svc.config.Osquery.LabelUpdateInterval) && !host.RefetchRequested {
|
||||
return nil, nil
|
||||
}
|
||||
labelQueries, err := svc.ds.LabelQueriesForHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "retrieve label queries")
|
||||
}
|
||||
return labelQueries, nil
|
||||
}
|
||||
|
||||
func (svc *Service) policyQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) {
|
||||
if !svc.shouldUpdate(host.PolicyUpdatedAt, svc.config.Osquery.PolicyUpdateInterval) && !host.RefetchRequested {
|
||||
return nil, nil
|
||||
}
|
||||
policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "retrieve policy queries")
|
||||
}
|
||||
return policyQueries, nil
|
||||
}
|
||||
|
||||
func (svc *Service) GetDistributedQueries(ctx context.Context) (map[string]string, uint, error) {
|
||||
// skipauth: Authorization is currently for user endpoints only.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
@ -464,41 +485,38 @@ func (svc *Service) GetDistributedQueries(ctx context.Context) (map[string]strin
|
||||
return nil, 0, osqueryError{message: "internal error: missing host from request context"}
|
||||
}
|
||||
|
||||
queries, err := svc.hostDetailQueries(ctx, host)
|
||||
queries := make(map[string]string)
|
||||
|
||||
detailQueries, err := svc.detailQueriesForHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, 0, osqueryError{message: err.Error()}
|
||||
}
|
||||
for name, query := range detailQueries {
|
||||
queries[name] = query
|
||||
}
|
||||
|
||||
// Retrieve the label queries that should be updated
|
||||
if svc.shouldUpdate(host.LabelUpdatedAt, svc.config.Osquery.LabelUpdateInterval) {
|
||||
labelQueries, err := svc.ds.LabelQueriesForHost(ctx, &host)
|
||||
if err != nil {
|
||||
return nil, 0, osqueryError{message: "retrieving label queries: " + err.Error()}
|
||||
}
|
||||
|
||||
for name, query := range labelQueries {
|
||||
queries[hostLabelQueryPrefix+name] = query
|
||||
}
|
||||
labelQueries, err := svc.labelQueriesForHost(ctx, &host)
|
||||
if err != nil {
|
||||
return nil, 0, osqueryError{message: err.Error()}
|
||||
}
|
||||
for name, query := range labelQueries {
|
||||
queries[hostLabelQueryPrefix+name] = query
|
||||
}
|
||||
|
||||
liveQueries, err := svc.liveQueryStore.QueriesForHost(host.ID)
|
||||
if err != nil {
|
||||
return nil, 0, osqueryError{message: "retrieve live queries: " + err.Error()}
|
||||
}
|
||||
|
||||
for name, query := range liveQueries {
|
||||
queries[hostDistributedQueryPrefix+name] = query
|
||||
}
|
||||
|
||||
if svc.shouldUpdate(host.PolicyUpdatedAt, svc.config.Osquery.PolicyUpdateInterval) {
|
||||
policyQueries, err := svc.ds.PolicyQueriesForHost(ctx, &host)
|
||||
if err != nil {
|
||||
return nil, 0, osqueryError{message: "retrieving policy queries: " + err.Error()}
|
||||
}
|
||||
|
||||
for name, query := range policyQueries {
|
||||
queries[hostPolicyQueryPrefix+name] = query
|
||||
}
|
||||
policyQueries, err := svc.policyQueriesForHost(ctx, &host)
|
||||
if err != nil {
|
||||
return nil, 0, osqueryError{message: err.Error()}
|
||||
}
|
||||
for name, query := range policyQueries {
|
||||
queries[hostPolicyQueryPrefix+name] = query
|
||||
}
|
||||
|
||||
accelerate := uint(0)
|
||||
@ -535,9 +553,6 @@ func (svc *Service) ingestDetailQuery(ctx context.Context, host *fleet.Host, nam
|
||||
}
|
||||
}
|
||||
|
||||
// Refetch is no longer needed after ingesting details.
|
||||
host.RefetchRequested = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -644,7 +659,6 @@ func (svc *Service) SubmitDistributedQueryResults(
|
||||
logIPs(ctx)
|
||||
|
||||
host, ok := hostctx.FromContext(ctx)
|
||||
|
||||
if !ok {
|
||||
return osqueryError{message: "internal error: missing host from request context"}
|
||||
}
|
||||
@ -733,6 +747,11 @@ func (svc *Service) SubmitDistributedQueryResults(
|
||||
|
||||
svc.maybeDebugHost(ctx, host, results, statuses, messages)
|
||||
|
||||
if host.RefetchRequested {
|
||||
host.RefetchRequested = false
|
||||
host.Modified = true
|
||||
}
|
||||
|
||||
if host.Modified {
|
||||
err = svc.ds.SaveHost(ctx, &host)
|
||||
if err != nil {
|
||||
|
@ -58,7 +58,7 @@ func TestEnrollAgent(t *testing.T) {
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
nodeKey, err := svc.EnrollAgent(context.Background(), "valid_secret", "host123", nil)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, nodeKey)
|
||||
}
|
||||
|
||||
@ -114,7 +114,7 @@ func TestEnrollAgentDetails(t *testing.T) {
|
||||
"foo": {"foo": "bar"},
|
||||
}
|
||||
nodeKey, err := svc.EnrollAgent(context.Background(), "", "host123", details)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, nodeKey)
|
||||
|
||||
assert.Equal(t, "Mac OS X 10.14.5", gotHost.OSVersion)
|
||||
@ -144,19 +144,19 @@ func TestAuthenticateHost(t *testing.T) {
|
||||
}
|
||||
|
||||
_, _, err := svc.AuthenticateHost(context.Background(), "test")
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test", gotKey)
|
||||
assert.False(t, ds.MarkHostsSeenFuncInvoked)
|
||||
|
||||
host = fleet.Host{ID: 7, Hostname: "foobar"}
|
||||
_, _, err = svc.AuthenticateHost(context.Background(), "floobar")
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "floobar", gotKey)
|
||||
assert.False(t, ds.MarkHostsSeenFuncInvoked)
|
||||
// Host checks in twice
|
||||
host = fleet.Host{ID: 7, Hostname: "foobar"}
|
||||
_, _, err = svc.AuthenticateHost(context.Background(), "floobar")
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "floobar", gotKey)
|
||||
assert.False(t, ds.MarkHostsSeenFuncInvoked)
|
||||
|
||||
@ -168,7 +168,7 @@ func TestAuthenticateHost(t *testing.T) {
|
||||
err = svc.FlushSeenHosts(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ds.MarkHostsSeenFuncInvoked)
|
||||
assert.Len(t, gotHostIDs, 0)
|
||||
require.Len(t, gotHostIDs, 0)
|
||||
}
|
||||
|
||||
func TestAuthenticateHostFailure(t *testing.T) {
|
||||
@ -210,12 +210,12 @@ func TestSubmitStatusLogs(t *testing.T) {
|
||||
|
||||
var status []json.RawMessage
|
||||
err := json.Unmarshal([]byte(logJSON), &status)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
host := fleet.Host{}
|
||||
ctx := hostctx.NewContext(context.Background(), host)
|
||||
err = serv.SubmitStatusLogs(ctx, status)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, status, testLogger.logs)
|
||||
}
|
||||
@ -242,12 +242,12 @@ func TestSubmitResultLogs(t *testing.T) {
|
||||
|
||||
var results []json.RawMessage
|
||||
err := json.Unmarshal([]byte(logJSON), &results)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
host := fleet.Host{}
|
||||
ctx := hostctx.NewContext(context.Background(), host)
|
||||
err = serv.SubmitResultLogs(ctx, results)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, results, testLogger.logs)
|
||||
}
|
||||
@ -280,23 +280,23 @@ func TestHostDetailQueries(t *testing.T) {
|
||||
|
||||
svc := &Service{clock: mockClock, config: config.TestConfig(), ds: ds}
|
||||
|
||||
queries, err := svc.hostDetailQueries(context.Background(), host)
|
||||
assert.Nil(t, err)
|
||||
queries, err := svc.detailQueriesForHost(context.Background(), host)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, queries)
|
||||
|
||||
// With refetch requested queries should be returned
|
||||
// With refetch requested detail queries should be returned
|
||||
host.RefetchRequested = true
|
||||
queries, err = svc.hostDetailQueries(context.Background(), host)
|
||||
assert.Nil(t, err)
|
||||
queries, err = svc.detailQueriesForHost(context.Background(), host)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, queries)
|
||||
host.RefetchRequested = false
|
||||
|
||||
// Advance the time
|
||||
mockClock.AddTime(1*time.Hour + 1*time.Minute)
|
||||
|
||||
queries, err = svc.hostDetailQueries(context.Background(), host)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, expectedDetailQueries+2)
|
||||
queries, err = svc.detailQueriesForHost(context.Background(), host)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries+2)
|
||||
for name := range queries {
|
||||
assert.True(t,
|
||||
strings.HasPrefix(name, hostDetailQueryPrefix) || strings.HasPrefix(name, hostAdditionalQueryPrefix),
|
||||
@ -330,7 +330,8 @@ func TestLabelQueries(t *testing.T) {
|
||||
ds.HostFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
return host, nil
|
||||
}
|
||||
ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error {
|
||||
ds.SaveHostFunc = func(ctx context.Context, gotHost *fleet.Host) error {
|
||||
host = gotHost
|
||||
return nil
|
||||
}
|
||||
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
@ -347,18 +348,18 @@ func TestLabelQueries(t *testing.T) {
|
||||
// With a new host, we should get the detail queries (and accelerate
|
||||
// should be turned on so that we can quickly fill labels)
|
||||
queries, acc, err := svc.GetDistributedQueries(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, expectedDetailQueries)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries)
|
||||
assert.NotZero(t, acc)
|
||||
|
||||
// Simulate the detail queries being added
|
||||
// Simulate the detail queries being added.
|
||||
host.DetailUpdatedAt = mockClock.Now().Add(-1 * time.Minute)
|
||||
host.Hostname = "zwass.local"
|
||||
ctx = hostctx.NewContext(ctx, *host)
|
||||
|
||||
queries, acc, err = svc.GetDistributedQueries(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, 0)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, 0)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) {
|
||||
@ -371,8 +372,8 @@ func TestLabelQueries(t *testing.T) {
|
||||
|
||||
// Now we should get the label queries
|
||||
queries, acc, err = svc.GetDistributedQueries(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, 3)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, 3)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
var gotHost *fleet.Host
|
||||
@ -394,13 +395,12 @@ func TestLabelQueries(t *testing.T) {
|
||||
map[string]fleet.OsqueryStatus{},
|
||||
map[string]string{},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
host.LabelUpdatedAt = mockClock.Now()
|
||||
assert.Equal(t, host, gotHost)
|
||||
assert.Equal(t, mockClock.Now(), gotTime)
|
||||
if assert.Len(t, gotResults, 1) {
|
||||
assert.Equal(t, true, *gotResults[1])
|
||||
}
|
||||
require.Len(t, gotResults, 1)
|
||||
assert.Equal(t, true, *gotResults[1])
|
||||
|
||||
mockClock.AddTime(1 * time.Second)
|
||||
|
||||
@ -414,14 +414,57 @@ func TestLabelQueries(t *testing.T) {
|
||||
map[string]fleet.OsqueryStatus{},
|
||||
map[string]string{},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
host.LabelUpdatedAt = mockClock.Now()
|
||||
assert.Equal(t, host, gotHost)
|
||||
assert.Equal(t, mockClock.Now(), gotTime)
|
||||
if assert.Len(t, gotResults, 2) {
|
||||
assert.Equal(t, true, *gotResults[2])
|
||||
assert.Equal(t, false, *gotResults[3])
|
||||
}
|
||||
require.Len(t, gotResults, 2)
|
||||
assert.Equal(t, true, *gotResults[2])
|
||||
assert.Equal(t, false, *gotResults[3])
|
||||
|
||||
// We should get no labels now.
|
||||
host.LabelUpdatedAt = mockClock.Now()
|
||||
ctx = hostctx.NewContext(ctx, *host)
|
||||
queries, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, 0)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
// With refetch requested details+label queries should be returned.
|
||||
host.RefetchRequested = true
|
||||
ctx = hostctx.NewContext(ctx, *host)
|
||||
queries, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries+3)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
// Record a query execution
|
||||
err = svc.SubmitDistributedQueryResults(
|
||||
ctx,
|
||||
map[string][]map[string]string{
|
||||
hostLabelQueryPrefix + "2": {{"col1": "val1"}},
|
||||
hostLabelQueryPrefix + "3": {},
|
||||
},
|
||||
map[string]fleet.OsqueryStatus{},
|
||||
map[string]string{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
host.LabelUpdatedAt = mockClock.Now()
|
||||
assert.Equal(t, host, gotHost)
|
||||
assert.Equal(t, mockClock.Now(), gotTime)
|
||||
require.Len(t, gotResults, 2)
|
||||
assert.Equal(t, true, *gotResults[2])
|
||||
assert.Equal(t, false, *gotResults[3])
|
||||
|
||||
// SubmitDistributedQueryResults will set RefetchRequested to false.
|
||||
require.False(t, host.RefetchRequested)
|
||||
|
||||
// There shouldn't be any labels now.
|
||||
ctx = hostctx.NewContext(context.Background(), *host)
|
||||
queries, acc, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, 0)
|
||||
assert.Zero(t, acc)
|
||||
}
|
||||
|
||||
func TestGetClientConfig(t *testing.T) {
|
||||
@ -469,11 +512,11 @@ func TestGetClientConfig(t *testing.T) {
|
||||
|
||||
// No packs loaded yet
|
||||
conf, err := svc.GetClientConfig(ctx1)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedConfig, conf)
|
||||
|
||||
conf, err = svc.GetClientConfig(ctx2)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedConfig, conf)
|
||||
|
||||
// Now add packs
|
||||
@ -494,7 +537,7 @@ func TestGetClientConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
conf, err = svc.GetClientConfig(ctx1)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedOptions, conf["options"])
|
||||
assert.JSONEq(t, `{
|
||||
"pack_by_other_label": {
|
||||
@ -513,7 +556,7 @@ func TestGetClientConfig(t *testing.T) {
|
||||
)
|
||||
|
||||
conf, err = svc.GetClientConfig(ctx2)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedOptions, conf["options"])
|
||||
assert.JSONEq(t, `{
|
||||
"pack_by_label": {
|
||||
@ -550,8 +593,8 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
|
||||
// With a new host, we should get the detail queries (and accelerated
|
||||
// queries)
|
||||
queries, acc, err := svc.GetDistributedQueries(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, expectedDetailQueries)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries)
|
||||
assert.NotZero(t, acc)
|
||||
|
||||
resultJSON := `
|
||||
@ -644,7 +687,7 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
|
||||
|
||||
var results fleet.OsqueryDistributedQueryResults
|
||||
err = json.Unmarshal([]byte(resultJSON), &results)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var gotHost *fleet.Host
|
||||
ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error {
|
||||
@ -687,16 +730,16 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
|
||||
// Now no detail queries should be required
|
||||
ctx = hostctx.NewContext(context.Background(), host)
|
||||
queries, acc, err = svc.GetDistributedQueries(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, 0)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, 0)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
// Advance clock and queries should exist again
|
||||
mockClock.AddTime(1*time.Hour + 1*time.Minute)
|
||||
|
||||
queries, acc, err = svc.GetDistributedQueries(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, expectedDetailQueries)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries)
|
||||
assert.Zero(t, acc)
|
||||
}
|
||||
|
||||
@ -725,7 +768,7 @@ func TestDetailQueries(t *testing.T) {
|
||||
// queries)
|
||||
queries, acc, err := svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, queries, expectedDetailQueries+1)
|
||||
require.Len(t, queries, expectedDetailQueries+1)
|
||||
assert.NotZero(t, acc)
|
||||
|
||||
resultJSON := `
|
||||
@ -849,7 +892,7 @@ func TestDetailQueries(t *testing.T) {
|
||||
|
||||
var results fleet.OsqueryDistributedQueryResults
|
||||
err = json.Unmarshal([]byte(resultJSON), &results)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var gotHost *fleet.Host
|
||||
ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error {
|
||||
@ -921,16 +964,16 @@ func TestDetailQueries(t *testing.T) {
|
||||
// Now no detail queries should be required
|
||||
ctx = hostctx.NewContext(ctx, host)
|
||||
queries, acc, err = svc.GetDistributedQueries(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, 0)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, 0)
|
||||
assert.Zero(t, acc)
|
||||
|
||||
// Advance clock and queries should exist again
|
||||
mockClock.AddTime(1*time.Hour + 1*time.Minute)
|
||||
|
||||
queries, acc, err = svc.GetDistributedQueries(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, expectedDetailQueries+1)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries+1)
|
||||
assert.Zero(t, acc)
|
||||
}
|
||||
|
||||
@ -990,7 +1033,7 @@ func TestNewDistributedQueryCampaign(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
campaign, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, gotQuery.ID, gotCampaign.QueryID)
|
||||
assert.True(t, ds.NewActivityFuncInvoked)
|
||||
assert.Equal(t, []*fleet.DistributedQueryCampaignTarget{
|
||||
@ -1043,8 +1086,8 @@ func TestDistributedQueryResults(t *testing.T) {
|
||||
|
||||
// Now we should get the active distributed query
|
||||
queries, acc, err := svc.GetDistributedQueries(hostCtx)
|
||||
require.Nil(t, err)
|
||||
assert.Len(t, queries, expectedDetailQueries+1)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries+1)
|
||||
queryKey := fmt.Sprintf("%s%d", hostDistributedQueryPrefix, campaign.ID)
|
||||
assert.Equal(t, "select * from time", queries[queryKey])
|
||||
assert.NotZero(t, acc)
|
||||
@ -1065,7 +1108,7 @@ func TestDistributedQueryResults(t *testing.T) {
|
||||
|
||||
// TODO use service method
|
||||
readChan, err := rs.ReadChannel(context.Background(), *campaign)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We need to listen for the result in a separate thread to prevent the
|
||||
// write to the result channel from failing
|
||||
@ -1100,7 +1143,7 @@ func TestDistributedQueryResults(t *testing.T) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err = svc.SubmitDistributedQueryResults(hostCtx, results, map[string]fleet.OsqueryStatus{}, map[string]string{})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestIngestDistributedQueryParseIdError(t *testing.T) {
|
||||
@ -1490,7 +1533,7 @@ func TestUpdateHostIntervals(t *testing.T) {
|
||||
}
|
||||
|
||||
_, err := svc.GetClientConfig(ctx)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.saveHostCalled, saveHostCalled)
|
||||
})
|
||||
}
|
||||
@ -1655,7 +1698,7 @@ func TestDistributedQueriesLogsManyErrors(t *testing.T) {
|
||||
map[string]fleet.OsqueryStatus{},
|
||||
map[string]string{},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
lCtx.Log(ctx, logger)
|
||||
|
||||
@ -1663,7 +1706,8 @@ func TestDistributedQueriesLogsManyErrors(t *testing.T) {
|
||||
parts := strings.Split(strings.TrimSpace(logs), "\n")
|
||||
require.Len(t, parts, 1)
|
||||
logData := make(map[string]json.RawMessage)
|
||||
require.NoError(t, json.Unmarshal([]byte(parts[0]), &logData))
|
||||
err = json.Unmarshal([]byte(parts[0]), &logData)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, json.RawMessage(`"something went wrong"`), logData["err"])
|
||||
assert.Equal(t, json.RawMessage(`"Missing authorization check"`), logData["internal"])
|
||||
}
|
||||
@ -1697,7 +1741,7 @@ func TestDistributedQueriesReloadsHostIfDetailsAreIn(t *testing.T) {
|
||||
map[string]fleet.OsqueryStatus{},
|
||||
map[string]string{},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ds.HostFuncInvoked)
|
||||
}
|
||||
|
||||
@ -1809,9 +1853,9 @@ func TestTeamMaintainerCanRunNewDistributedCampaigns(t *testing.T) {
|
||||
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
//var gotQuery *fleet.Query
|
||||
// var gotQuery *fleet.Query
|
||||
ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) {
|
||||
//gotQuery = query
|
||||
// gotQuery = query
|
||||
query.ID = 42
|
||||
return query, nil
|
||||
}
|
||||
@ -1874,21 +1918,25 @@ func TestPolicyQueries(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries+2)
|
||||
|
||||
hasPolicy1, hasPolicy2 := false, false
|
||||
for name := range queries {
|
||||
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
|
||||
if name[len(hostPolicyQueryPrefix):] == "1" {
|
||||
hasPolicy1 = true
|
||||
}
|
||||
if name[len(hostPolicyQueryPrefix):] == "2" {
|
||||
hasPolicy2 = true
|
||||
checkPolicyResults := func(queries map[string]string) {
|
||||
hasPolicy1, hasPolicy2 := false, false
|
||||
for name := range queries {
|
||||
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
|
||||
if name[len(hostPolicyQueryPrefix):] == "1" {
|
||||
hasPolicy1 = true
|
||||
}
|
||||
if name[len(hostPolicyQueryPrefix):] == "2" {
|
||||
hasPolicy2 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, hasPolicy1)
|
||||
assert.True(t, hasPolicy2)
|
||||
}
|
||||
assert.True(t, hasPolicy1)
|
||||
assert.True(t, hasPolicy2)
|
||||
|
||||
// Record a query execution
|
||||
checkPolicyResults(queries)
|
||||
|
||||
// Record a query execution.
|
||||
err = svc.SubmitDistributedQueryResults(
|
||||
ctx,
|
||||
map[string][]map[string]string{
|
||||
@ -1904,39 +1952,86 @@ func TestPolicyQueries(t *testing.T) {
|
||||
require.Equal(t, true, *recordedResults[1])
|
||||
require.Nil(t, recordedResults[2])
|
||||
|
||||
noPolicyResults := func(queries map[string]string) {
|
||||
hasAnyPolicy := false
|
||||
for name := range queries {
|
||||
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
|
||||
hasAnyPolicy = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.False(t, hasAnyPolicy)
|
||||
}
|
||||
|
||||
// After the first time we get policies and update the host, then there shouldn't be any policies.
|
||||
ctx = hostctx.NewContext(context.Background(), *host)
|
||||
queries, _, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries)
|
||||
noPolicyResults(queries)
|
||||
|
||||
// After the first time we get policies and update the host, then there shouldn't be any policies
|
||||
hasAnyPolicy := false
|
||||
for name := range queries {
|
||||
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
|
||||
hasAnyPolicy = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.False(t, hasAnyPolicy)
|
||||
|
||||
// Let's move time forward, there should be policies now
|
||||
// Let's move time forward, there should be policies now.
|
||||
mockClock.AddTime(2 * time.Hour)
|
||||
|
||||
queries, _, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries+2)
|
||||
checkPolicyResults(queries)
|
||||
|
||||
hasPolicy1, hasPolicy2 = false, false
|
||||
for name := range queries {
|
||||
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
|
||||
if name[len(hostPolicyQueryPrefix):] == "1" {
|
||||
hasPolicy1 = true
|
||||
}
|
||||
if name[len(hostPolicyQueryPrefix):] == "2" {
|
||||
hasPolicy2 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, hasPolicy1)
|
||||
assert.True(t, hasPolicy2)
|
||||
// Record another query execution.
|
||||
err = svc.SubmitDistributedQueryResults(
|
||||
ctx,
|
||||
map[string][]map[string]string{
|
||||
hostPolicyQueryPrefix + "1": {{"col1": "val1"}},
|
||||
},
|
||||
map[string]fleet.OsqueryStatus{
|
||||
hostPolicyQueryPrefix + "2": 1,
|
||||
},
|
||||
map[string]string{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, recordedResults[1])
|
||||
require.Equal(t, true, *recordedResults[1])
|
||||
require.Nil(t, recordedResults[2])
|
||||
|
||||
// There shouldn't be any policies now.
|
||||
ctx = hostctx.NewContext(context.Background(), *host)
|
||||
queries, _, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries)
|
||||
noPolicyResults(queries)
|
||||
|
||||
// With refetch requested policy queries should be returned.
|
||||
host.RefetchRequested = true
|
||||
ctx = hostctx.NewContext(context.Background(), *host)
|
||||
queries, _, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries+2)
|
||||
checkPolicyResults(queries)
|
||||
|
||||
// Record another query execution.
|
||||
err = svc.SubmitDistributedQueryResults(
|
||||
ctx,
|
||||
map[string][]map[string]string{
|
||||
hostPolicyQueryPrefix + "1": {{"col1": "val1"}},
|
||||
},
|
||||
map[string]fleet.OsqueryStatus{
|
||||
hostPolicyQueryPrefix + "2": 1,
|
||||
},
|
||||
map[string]string{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, recordedResults[1])
|
||||
require.Equal(t, true, *recordedResults[1])
|
||||
require.Nil(t, recordedResults[2])
|
||||
|
||||
// SubmitDistributedQueryResults will set RefetchRequested to false.
|
||||
require.False(t, host.RefetchRequested)
|
||||
|
||||
// There shouldn't be any policies now.
|
||||
ctx = hostctx.NewContext(context.Background(), *host)
|
||||
queries, _, err = svc.GetDistributedQueries(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queries, expectedDetailQueries)
|
||||
noPolicyResults(queries)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user