mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Live queries post merge fixes (#2684)
* Refactor based on PR review comments
* Use deadline instead of duration
* Address review comments
* Only set the timeout once
This commit is contained in:
parent
66faac25c9
commit
623a38aa9d
@ -3223,30 +3223,6 @@ fleet instance crashing.
|
||||
}
|
||||
```
|
||||
|
||||
##### Default response
|
||||
|
||||
`Status: 200`
|
||||
|
||||
```json
|
||||
{
|
||||
"campaign": {
|
||||
"created_at": "0001-01-01T00:00:00Z",
|
||||
"updated_at": "0001-01-01T00:00:00Z",
|
||||
"Metrics": {
|
||||
"TotalHosts": 1,
|
||||
"OnlineHosts": 0,
|
||||
"OfflineHosts": 1,
|
||||
"MissingInActionHosts": 0,
|
||||
"NewHosts": 1
|
||||
},
|
||||
"id": 1,
|
||||
"query_id": 3,
|
||||
"status": 0,
|
||||
"user_id": 1
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Schedule
|
||||
|
@ -43,3 +43,15 @@ type DistributedQueryResult struct {
|
||||
// implementing that interface may not (un)marshal properly
|
||||
Error *string `json:"error"`
|
||||
}
|
||||
|
||||
// QueryResult is the outcome of a live query on a single host, as serialized
// in the live query REST response.
type QueryResult struct {
	// HostID is the ID of the host that produced this result.
	HostID uint `json:"host_id"`
	// Rows are the rows returned by the host, each one keyed by column name.
	Rows []map[string]string `json:"rows"`
	// Error is the error reported for this host, or nil when there is none.
	// It is always present in the JSON output (null when nil).
	Error *string `json:"error"`
}
|
||||
|
||||
type QueryCampaignResult struct {
|
||||
QueryID uint `json:"query_id"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
Results []QueryResult `json:"results"`
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package fleet
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/websocket"
|
||||
"github.com/kolide/kit/version"
|
||||
@ -223,6 +224,7 @@ type Service interface {
|
||||
|
||||
GetCampaignReader(ctx context.Context, campaign *DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error)
|
||||
CompleteCampaign(ctx context.Context, campaign *DistributedQueryCampaign) error
|
||||
RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]QueryCampaignResult, int)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// AgentOptionsService
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/pkg/errors"
|
||||
@ -27,31 +28,13 @@ type runLiveQueryResponse struct {
|
||||
Summary summaryPayload `json:"summary"`
|
||||
Err error `json:"error,omitempty"`
|
||||
|
||||
Results []queryCampaignResult `json:"live_query_results"`
|
||||
Results []fleet.QueryCampaignResult `json:"live_query_results"`
|
||||
}
|
||||
|
||||
// error returns the error carried by the response in its Err field, which is
// nil when no error was recorded.
func (r runLiveQueryResponse) error() error {
	return r.Err
}
|
||||
|
||||
type queryResult struct {
|
||||
HostID uint `json:"host_id"`
|
||||
Rows []map[string]string `json:"rows"`
|
||||
Error *string `json:"error"`
|
||||
}
|
||||
|
||||
type queryCampaignResult struct {
|
||||
QueryID uint `json:"query_id"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
Results []queryResult `json:"results"`
|
||||
}
|
||||
|
||||
func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*runLiveQueryRequest)
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
resultsCh := make(chan queryCampaignResult)
|
||||
|
||||
counterMutex := sync.Mutex{}
|
||||
counter := make(map[uint]struct{})
|
||||
|
||||
period := os.Getenv("FLEET_LIVE_QUERY_REST_PERIOD")
|
||||
if period == "" {
|
||||
@ -60,62 +43,9 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
|
||||
duration, err := time.ParseDuration(period)
|
||||
if err != nil {
|
||||
duration = 90 * time.Second
|
||||
logging.WithExtras(ctx, "live_query_rest_period_err", err)
|
||||
}
|
||||
|
||||
for _, queryID := range req.QueryIDs {
|
||||
queryID := queryID
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: req.HostIDs})
|
||||
if err != nil {
|
||||
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||
return
|
||||
}
|
||||
|
||||
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
|
||||
if err != nil {
|
||||
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||
return
|
||||
}
|
||||
defer cancelFunc()
|
||||
|
||||
defer func() {
|
||||
err := svc.CompleteCampaign(ctx, campaign)
|
||||
if err != nil {
|
||||
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(duration)
|
||||
defer ticker.Stop()
|
||||
|
||||
var results []queryResult
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case res := <-readChan:
|
||||
// Receive a result and push it over the websocket
|
||||
switch res := res.(type) {
|
||||
case fleet.DistributedQueryResult:
|
||||
results = append(results, queryResult{HostID: res.Host.ID, Rows: res.Rows, Error: res.Error})
|
||||
counterMutex.Lock()
|
||||
counter[res.Host.ID] = struct{}{}
|
||||
counterMutex.Unlock()
|
||||
}
|
||||
case <-ticker.C:
|
||||
break loop
|
||||
}
|
||||
}
|
||||
resultsCh <- queryCampaignResult{QueryID: queryID, Results: results}
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultsCh)
|
||||
}()
|
||||
|
||||
res := runLiveQueryResponse{
|
||||
Summary: summaryPayload{
|
||||
TargetedHostCount: len(req.HostIDs),
|
||||
@ -123,15 +53,86 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
|
||||
},
|
||||
}
|
||||
|
||||
for result := range resultsCh {
|
||||
res.Results = append(res.Results, result)
|
||||
}
|
||||
|
||||
res.Summary.RespondedHostCount = len(counter)
|
||||
queryResults, respondedHostCount := svc.RunLiveQueryDeadline(ctx, req.QueryIDs, req.HostIDs, duration)
|
||||
res.Results = queryResults
|
||||
res.Summary.RespondedHostCount = respondedHostCount
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (svc *Service) RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]fleet.QueryCampaignResult, int) {
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
resultsCh := make(chan fleet.QueryCampaignResult)
|
||||
|
||||
counterMutex := sync.Mutex{}
|
||||
respondedHostIDs := make(map[uint]struct{})
|
||||
|
||||
for _, queryID := range queryIDs {
|
||||
queryID := queryID
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: hostIDs})
|
||||
if err != nil {
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||
return
|
||||
}
|
||||
|
||||
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
|
||||
if err != nil {
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||
return
|
||||
}
|
||||
defer cancelFunc()
|
||||
|
||||
defer func() {
|
||||
err := svc.CompleteCampaign(ctx, campaign)
|
||||
if err != nil {
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||
}
|
||||
}()
|
||||
|
||||
var results []fleet.QueryResult
|
||||
timeout := time.After(deadline)
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case res := <-readChan:
|
||||
switch res := res.(type) {
|
||||
case fleet.DistributedQueryResult:
|
||||
results = append(results, fleet.QueryResult{HostID: res.Host.ID, Rows: res.Rows, Error: res.Error})
|
||||
counterMutex.Lock()
|
||||
respondedHostIDs[res.Host.ID] = struct{}{}
|
||||
counterMutex.Unlock()
|
||||
case error:
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(res.Error())}
|
||||
return
|
||||
}
|
||||
case <-timeout:
|
||||
break loop
|
||||
case <-ctx.Done():
|
||||
break loop
|
||||
}
|
||||
}
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Results: results}
|
||||
}()
|
||||
}
|
||||
|
||||
// Iterate collecting results until all the goroutines have returned
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultsCh)
|
||||
}()
|
||||
|
||||
var results []fleet.QueryCampaignResult
|
||||
for result := range resultsCh {
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
return results, len(respondedHostIDs)
|
||||
}
|
||||
|
||||
func (svc *Service) GetCampaignReader(ctx context.Context, campaign *fleet.DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) {
|
||||
// Open the channel from which we will receive incoming query results
|
||||
// (probably from the redis pubsub implementation)
|
||||
@ -143,8 +144,6 @@ func (svc *Service) GetCampaignReader(ctx context.Context, campaign *fleet.Distr
|
||||
return nil, nil, fmt.Errorf("cannot open read channel for campaign %d ", campaign.ID)
|
||||
}
|
||||
|
||||
// Setting status to running will cause the query to be returned to the
|
||||
// targets when they check in for their queries
|
||||
campaign.Status = fleet.QueryRunning
|
||||
if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil {
|
||||
cancelFunc()
|
||||
|
Loading…
Reference in New Issue
Block a user