Live queries post merge fixes (#2684)

* Refactor based on PR review comments

* Use deadline instead of duration

* Address review comments

* Only set the timeout once
Tomas Touceda 2021-10-26 16:12:07 -03:00 committed by GitHub
parent 66faac25c9
commit 623a38aa9d
4 changed files with 93 additions and 104 deletions
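
A note on the last bullet: in Go, calling `time.After` inside a `select` loop arms a fresh timer on every iteration, so a steady stream of results can postpone the timeout indefinitely. The fix hoists the call out of the loop (`timeout := time.After(deadline)` in the diff below). A minimal, self-contained sketch of the corrected shape (not Fleet code; the channel, values, and durations are made up):

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	readChan := make(chan int)
	go func() {
		for i := 0; ; i++ {
			time.Sleep(50 * time.Millisecond)
			readChan <- i // a steady stream of "results"
		}
	}()

	// Armed exactly once, before the loop. Writing `case <-time.After(...)`
	// inside the select would re-arm the timer after every received result,
	// and this loop would never hit the deadline.
	timeout := time.After(200 * time.Millisecond)
	for {
		select {
		case res := <-readChan:
			fmt.Println("result:", res)
		case <-timeout:
			fmt.Println("deadline reached")
			return
		}
	}
}
```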


@@ -3223,30 +3223,6 @@ fleet instance crashing.
}
```
##### Default response
`Status: 200`
```json
{
"campaign": {
"created_at": "0001-01-01T00:00:00Z",
"updated_at": "0001-01-01T00:00:00Z",
"Metrics": {
"TotalHosts": 1,
"OnlineHosts": 0,
"OfflineHosts": 1,
"MissingInActionHosts": 0,
"NewHosts": 1
},
"id": 1,
"query_id": 3,
"status": 0,
"user_id": 1
}
}
```
---
## Schedule


@@ -43,3 +43,15 @@ type DistributedQueryResult struct {
// implementing that interface may not (un)marshal properly
Error *string `json:"error"`
}
type QueryResult struct {
HostID uint `json:"host_id"`
Rows []map[string]string `json:"rows"`
Error *string `json:"error"`
}
type QueryCampaignResult struct {
QueryID uint `json:"query_id"`
Error *string `json:"error,omitempty"`
Results []QueryResult `json:"results"`
}
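
For reference, the wire shape of the two exported types: a minimal sketch using local copies of the structs above, with made-up values. One deliberate asymmetry worth noting: `QueryCampaignResult.Error` carries `omitempty`, while `QueryResult.Error` does not.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the types added above, so this sketch is self-contained.
type QueryResult struct {
	HostID uint                `json:"host_id"`
	Rows   []map[string]string `json:"rows"`
	Error  *string             `json:"error"`
}

type QueryCampaignResult struct {
	QueryID uint          `json:"query_id"`
	Error   *string       `json:"error,omitempty"`
	Results []QueryResult `json:"results"`
}

func main() {
	r := QueryCampaignResult{
		QueryID: 3,
		Results: []QueryResult{
			{HostID: 1, Rows: []map[string]string{{"hostname": "host-1"}}},
		},
	}
	out, _ := json.MarshalIndent(r, "", "  ")
	// The campaign-level "error" is omitted when nil (omitempty); the
	// per-host "error" is always emitted, as "error": null when absent.
	fmt.Println(string(out))
}
```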


@@ -3,6 +3,7 @@ package fleet
import (
"context"
"encoding/json"
"time"
"github.com/fleetdm/fleet/v4/server/websocket"
"github.com/kolide/kit/version"
@@ -223,6 +224,7 @@ type Service interface {
GetCampaignReader(ctx context.Context, campaign *DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error)
CompleteCampaign(ctx context.Context, campaign *DistributedQueryCampaign) error
RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]QueryCampaignResult, int)
///////////////////////////////////////////////////////////////////////////////
// AgentOptionsService
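
A hypothetical caller of the new `Service` method, showing how the two return values are meant to be consumed (the query ID, host IDs, and logging are made up for illustration):

```go
package example

import (
	"context"
	"log"
	"time"

	"github.com/fleetdm/fleet/v4/server/fleet"
)

// runQueries fans query 3 out to hosts 1 and 2 with a 90s deadline; per-query
// failures come back inline in each QueryCampaignResult rather than as a
// single error return.
func runQueries(ctx context.Context, svc fleet.Service) {
	results, responded := svc.RunLiveQueryDeadline(ctx, []uint{3}, []uint{1, 2}, 90*time.Second)
	for _, r := range results {
		if r.Error != nil {
			log.Printf("query %d failed: %s", r.QueryID, *r.Error)
			continue
		}
		log.Printf("query %d returned rows from %d host(s)", r.QueryID, len(r.Results))
	}
	log.Printf("%d host(s) responded before the deadline", responded)
}
```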


@@ -8,6 +8,7 @@ import (
"sync"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/pkg/errors"
@@ -27,31 +28,13 @@ type runLiveQueryResponse struct {
Summary summaryPayload `json:"summary"`
Err error `json:"error,omitempty"`
Results []queryCampaignResult `json:"live_query_results"`
Results []fleet.QueryCampaignResult `json:"live_query_results"`
}
func (r runLiveQueryResponse) error() error { return r.Err }
type queryResult struct {
HostID uint `json:"host_id"`
Rows []map[string]string `json:"rows"`
Error *string `json:"error"`
}
type queryCampaignResult struct {
QueryID uint `json:"query_id"`
Error *string `json:"error,omitempty"`
Results []queryResult `json:"results"`
}
func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*runLiveQueryRequest)
wg := sync.WaitGroup{}
resultsCh := make(chan queryCampaignResult)
counterMutex := sync.Mutex{}
counter := make(map[uint]struct{})
period := os.Getenv("FLEET_LIVE_QUERY_REST_PERIOD")
if period == "" {
@@ -60,62 +43,9 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
duration, err := time.ParseDuration(period)
if err != nil {
duration = 90 * time.Second
logging.WithExtras(ctx, "live_query_rest_period_err", err)
}
for _, queryID := range req.QueryIDs {
queryID := queryID
wg.Add(1)
go func() {
defer wg.Done()
campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: req.HostIDs})
if err != nil {
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
return
}
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
if err != nil {
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
return
}
defer cancelFunc()
defer func() {
err := svc.CompleteCampaign(ctx, campaign)
if err != nil {
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
}
}()
ticker := time.NewTicker(duration)
defer ticker.Stop()
var results []queryResult
loop:
for {
select {
case res := <-readChan:
// Receive a result and push it over the websocket
switch res := res.(type) {
case fleet.DistributedQueryResult:
results = append(results, queryResult{HostID: res.Host.ID, Rows: res.Rows, Error: res.Error})
counterMutex.Lock()
counter[res.Host.ID] = struct{}{}
counterMutex.Unlock()
}
case <-ticker.C:
break loop
}
}
resultsCh <- queryCampaignResult{QueryID: queryID, Results: results}
}()
}
go func() {
wg.Wait()
close(resultsCh)
}()
res := runLiveQueryResponse{
Summary: summaryPayload{
TargetedHostCount: len(req.HostIDs),
@@ -123,15 +53,86 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
},
}
for result := range resultsCh {
res.Results = append(res.Results, result)
}
res.Summary.RespondedHostCount = len(counter)
queryResults, respondedHostCount := svc.RunLiveQueryDeadline(ctx, req.QueryIDs, req.HostIDs, duration)
res.Results = queryResults
res.Summary.RespondedHostCount = respondedHostCount
return res, nil
}
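
The endpoint still derives its deadline from `FLEET_LIVE_QUERY_REST_PERIOD`; only the collection loop moved into the service layer. A condensed, self-contained sketch of the fallback behavior, assuming the elided empty-string branch above also lands on the 90-second default (the real handler additionally logs the parse error under `live_query_rest_period_err`):

```go
package main

import (
	"fmt"
	"os"
	"time"
)

// restPeriod collapses the handler's two fallback branches into one: an unset
// FLEET_LIVE_QUERY_REST_PERIOD fails time.ParseDuration exactly like a
// malformed value, so both cases end up at the default.
func restPeriod() time.Duration {
	duration, err := time.ParseDuration(os.Getenv("FLEET_LIVE_QUERY_REST_PERIOD"))
	if err != nil {
		return 90 * time.Second
	}
	return duration
}

func main() {
	fmt.Println(restPeriod()) // "1m30s" while the variable is unset
	os.Setenv("FLEET_LIVE_QUERY_REST_PERIOD", "25s")
	fmt.Println(restPeriod()) // "25s"
}
```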
func (svc *Service) RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]fleet.QueryCampaignResult, int) {
wg := sync.WaitGroup{}
resultsCh := make(chan fleet.QueryCampaignResult)
counterMutex := sync.Mutex{}
respondedHostIDs := make(map[uint]struct{})
for _, queryID := range queryIDs {
queryID := queryID
wg.Add(1)
go func() {
defer wg.Done()
campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: hostIDs})
if err != nil {
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
return
}
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
if err != nil {
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
return
}
defer cancelFunc()
defer func() {
err := svc.CompleteCampaign(ctx, campaign)
if err != nil {
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
}
}()
var results []fleet.QueryResult
timeout := time.After(deadline)
loop:
for {
select {
case res := <-readChan:
switch res := res.(type) {
case fleet.DistributedQueryResult:
results = append(results, fleet.QueryResult{HostID: res.Host.ID, Rows: res.Rows, Error: res.Error})
counterMutex.Lock()
respondedHostIDs[res.Host.ID] = struct{}{}
counterMutex.Unlock()
case error:
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(res.Error())}
return
}
case <-timeout:
break loop
case <-ctx.Done():
break loop
}
}
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Results: results}
}()
}
// Iterate collecting results until all the goroutines have returned
go func() {
wg.Wait()
close(resultsCh)
}()
var results []fleet.QueryCampaignResult
for result := range resultsCh {
results = append(results, result)
}
return results, len(respondedHostIDs)
}
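
`RunLiveQueryDeadline` uses a standard Go fan-in shape: one goroutine per query writes into a shared unbuffered channel, a separate goroutine closes that channel once the `WaitGroup` drains, and the caller ranges until close. A minimal, self-contained sketch of just that pattern, detached from the Fleet types:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	inputs := []int{1, 2, 3}
	resultsCh := make(chan int)
	var wg sync.WaitGroup
	for _, n := range inputs {
		n := n // capture the loop variable, as `queryID := queryID` does above
		wg.Add(1)
		go func() {
			defer wg.Done()
			resultsCh <- n * n // each worker reports exactly one result
		}()
	}
	// Close the channel only after every worker has finished, so the
	// collecting loop below terminates instead of blocking forever.
	go func() {
		wg.Wait()
		close(resultsCh)
	}()
	for r := range resultsCh {
		fmt.Println(r)
	}
}
```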
func (svc *Service) GetCampaignReader(ctx context.Context, campaign *fleet.DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) {
// Open the channel from which we will receive incoming query results
// (probably from the redis pubsub implementation)
@@ -143,8 +144,6 @@ func (svc *Service) GetCampaignReader(ctx context.Context, campaign *fleet.DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) {
return nil, nil, fmt.Errorf("cannot open read channel for campaign %d ", campaign.ID)
}
// Setting status to running will cause the query to be returned to the
// targets when they check in for their queries
campaign.Status = fleet.QueryRunning
if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil {
cancelFunc()