package service

import (
	"context"
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/fleetdm/fleet/v4/server"
	"github.com/fleetdm/fleet/v4/server/authz"
	"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
	"github.com/fleetdm/fleet/v4/server/contexts/logging"
	"github.com/fleetdm/fleet/v4/server/fleet"
	"github.com/fleetdm/fleet/v4/server/ptr"
	"github.com/go-kit/log/level"
)

type runLiveQueryRequest struct {
	QueryIDs []uint `json:"query_ids"`
	HostIDs  []uint `json:"host_ids"`
}

type runOneLiveQueryRequest struct {
	QueryID uint   `url:"id"`
	HostIDs []uint `json:"host_ids"`
}

type runLiveQueryOnHostRequest struct {
	Identifier string `url:"identifier"`
	Query      string `json:"query"`
}

type runLiveQueryOnHostByIDRequest struct {
	HostID uint   `url:"id"`
	Query  string `json:"query"`
}

type summaryPayload struct {
	TargetedHostCount  int `json:"targeted_host_count"`
	RespondedHostCount int `json:"responded_host_count"`
}

type runLiveQueryResponse struct {
	Summary summaryPayload `json:"summary"`
	Err     error          `json:"error,omitempty"`

	Results []fleet.QueryCampaignResult `json:"live_query_results"`
}

func (r runLiveQueryResponse) error() error { return r.Err }

type runOneLiveQueryResponse struct {
	QueryID            uint                `json:"query_id"`
	TargetedHostCount  int                 `json:"targeted_host_count"`
	RespondedHostCount int                 `json:"responded_host_count"`
	Results            []fleet.QueryResult `json:"results"`
	Err                error               `json:"error,omitempty"`
}

func (r runOneLiveQueryResponse) error() error { return r.Err }

type runLiveQueryOnHostResponse struct {
	HostID uint                `json:"host_id"`
	Rows   []map[string]string `json:"rows"`
	Query  string              `json:"query"`
	Status fleet.HostStatus    `json:"status"`
	Error  string              `json:"error,omitempty"`
}

func (r runLiveQueryOnHostResponse) error() error { return nil }
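
// runOneLiveQueryEndpoint runs a single saved query (identified by the URL id) against
// the given hosts and waits for their results.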
func runOneLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
	req := request.(*runOneLiveQueryRequest)

	// Only allow a host to be specified once in HostIDs
	hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs)

	campaignResults, respondedHostCount, err := runLiveQuery(ctx, svc, []uint{req.QueryID}, "", hostIDs)
	if err != nil {
		return nil, err
	}
	//goland:noinspection GoPreferNilSlice -- use an empty slice here so that API returns an empty array if there are no results
	queryResults := []fleet.QueryResult{}
	if len(campaignResults) > 0 {
		if campaignResults[0].Err != nil {
			return nil, campaignResults[0].Err
		}
		if campaignResults[0].Results != nil {
			queryResults = campaignResults[0].Results
		}
	}

	res := runOneLiveQueryResponse{
		QueryID:            req.QueryID,
		TargetedHostCount:  len(hostIDs),
		RespondedHostCount: respondedHostCount,
		Results:            queryResults,
	}
	return res, nil
}
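
// runLiveQueryEndpoint runs one or more saved queries against the given hosts and
// returns a summary along with the per-campaign results.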
func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
	req := request.(*runLiveQueryRequest)

	// Only allow a query to be specified once
	queryIDs := server.RemoveDuplicatesFromSlice(req.QueryIDs)
	// Only allow a host to be specified once in HostIDs
	hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs)

	queryResults, respondedHostCount, err := runLiveQuery(ctx, svc, queryIDs, "", hostIDs)
	if err != nil {
		return nil, err
	}

	res := runLiveQueryResponse{
		Summary: summaryPayload{
			TargetedHostCount:  len(hostIDs),
			RespondedHostCount: respondedHostCount,
		},
		Results: queryResults,
	}
	return res, nil
}
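
// runLiveQueryOnHostEndpoint runs an ad-hoc query on a single host looked up by its
// identifier.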
func runLiveQueryOnHostEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
	req := request.(*runLiveQueryOnHostRequest)

	host, err := svc.HostLiteByIdentifier(ctx, req.Identifier)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %s: %s", req.Identifier, err.Error())))
	}

	return runLiveQueryOnHost(svc, ctx, host, req.Query)
}
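
// runLiveQueryOnHostByIDEndpoint runs an ad-hoc query on a single host looked up by ID.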
func runLiveQueryOnHostByIDEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
	req := request.(*runLiveQueryOnHostByIDRequest)

	host, err := svc.HostLiteByID(ctx, req.HostID)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %d: %s", req.HostID, err.Error())))
	}

	return runLiveQueryOnHost(svc, ctx, host, req.Query)
}
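
// runLiveQueryOnHost rejects empty queries, reports offline hosts without running
// anything, and otherwise runs the ad-hoc query (query ID 0) against the single host,
// copying the rows or any error into the response.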
func runLiveQueryOnHost(svc fleet.Service, ctx context.Context, host *fleet.HostLite, query string) (errorer, error) {
	query = strings.TrimSpace(query)
	if query == "" {
		return nil, ctxerr.Wrap(ctx, badRequest("query is required"))
	}

	res := runLiveQueryOnHostResponse{
		HostID: host.ID,
		Query:  query,
	}

	status := (&fleet.Host{
		DistributedInterval: host.DistributedInterval,
		ConfigTLSRefresh:    host.ConfigTLSRefresh,
		SeenTime:            host.SeenTime,
	}).Status(time.Now())
	switch status {
	case fleet.StatusOnline, fleet.StatusNew:
		res.Status = fleet.StatusOnline
	case fleet.StatusOffline, fleet.StatusMIA, fleet.StatusMissing:
		res.Status = fleet.StatusOffline
		return res, nil
	default:
		return nil, fmt.Errorf("unknown host status: %s", status)
	}

	queryResults, _, err := runLiveQuery(ctx, svc, []uint{0}, query, []uint{host.ID})
	if err != nil {
		return nil, err
	}

	if len(queryResults) > 0 {
		var err error
		if queryResults[0].Err != nil {
			err = queryResults[0].Err
		} else if len(queryResults[0].Results) > 0 {
			queryResult := queryResults[0].Results[0]
			if queryResult.Error != nil {
				err = errors.New(*queryResult.Error)
			}
			res.Rows = queryResult.Rows
			res.HostID = queryResult.HostID
		} else {
			err = errors.New("timeout waiting for results")
		}
		if err != nil {
			res.Error = err.Error()
		}
	}
	return res, nil
}
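
// runLiveQuery resolves the collection period (FLEET_LIVE_QUERY_REST_PERIOD,
// defaulting to 25s), delegates to RunLiveQueryDeadline, and converts a result set
// where every campaign was forbidden into an authorization error.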
func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, query string, hostIDs []uint) (
	[]fleet.QueryCampaignResult, int, error,
) {
	// The period used here should always be less than the request timeout for any load
	// balancer/proxy between Fleet and the API client.
	period := os.Getenv("FLEET_LIVE_QUERY_REST_PERIOD")
	if period == "" {
		period = "25s"
	}
	duration, err := time.ParseDuration(period)
	if err != nil {
		duration = 25 * time.Second
		logging.WithExtras(ctx, "live_query_rest_period_err", err)
	}

	queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, queryIDs, query, hostIDs, duration)
	if err != nil {
		return nil, 0, err
	}
	// Check if all query results were forbidden due to lack of authorization.
	allResultsForbidden := len(queryResults) > 0 && respondedHostCount == 0
	if allResultsForbidden {
		for _, r := range queryResults {
			if r.Error == nil || *r.Error != authz.ForbiddenErrorMessage {
				allResultsForbidden = false
				break
			}
		}
	}
	if allResultsForbidden {
		return nil, 0, authz.ForbiddenWithInternal(
			"All Live Query results were forbidden.", authz.UserFromContext(ctx), nil, nil,
		)
	}
	return queryResults, respondedHostCount, nil
}
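
// RunLiveQueryDeadline creates a distributed query campaign per query ID (ID 0 runs
// the raw SQL in query) against the given hosts, collects results until every
// targeted host responds, the deadline elapses, or the context is canceled, and
// returns the per-campaign results plus the number of distinct hosts that responded.
// Campaign cleanup runs on a context detached from the caller so campaigns are not
// orphaned when the client disconnects.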
func (svc *Service) RunLiveQueryDeadline(
	ctx context.Context, queryIDs []uint, query string, hostIDs []uint, deadline time.Duration,
) ([]fleet.QueryCampaignResult, int, error) {
	if len(queryIDs) == 0 || len(hostIDs) == 0 {
		svc.authz.SkipAuthorization(ctx)
		return nil, 0, ctxerr.Wrap(ctx, badRequest("query_ids and host_ids are required"))
	}
	wg := sync.WaitGroup{}

	resultsCh := make(chan fleet.QueryCampaignResult)

	counterMutex := sync.Mutex{}
	respondedHostIDs := make(map[uint]struct{})

	for _, queryID := range queryIDs {
		queryID := queryID
		wg.Add(1)
		go func() {
			defer wg.Done()
			queryIDPtr := &queryID
			queryString := ""
			// 0 is a special ID that indicates we should use raw SQL query instead
			if queryID == 0 {
				queryIDPtr = nil
				queryString = query
			}

			campaign, err := svc.NewDistributedQueryCampaign(ctx, queryString, queryIDPtr, fleet.HostTargets{HostIDs: hostIDs})
			if err != nil {
				level.Error(svc.logger).Log(
					"msg", "new distributed query campaign",
					"queryString", queryString,
					"queryID", queryID,
					"err", err,
				)
				resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err}
				return
			}
			queryID = campaign.QueryID

			// We do not want to use the outer `ctx` directly because we want to cleanup the campaign
			// even if the outer `ctx` is canceled (e.g. a client terminating the connection).
			// Also, we make sure stats and activity DB operations don't get killed after we return results.
			ctxWithoutCancel := context.WithoutCancel(ctx)
			defer func() {
				err := svc.CompleteCampaign(ctxWithoutCancel, campaign)
				if err != nil {
					level.Error(svc.logger).Log(
						"msg", "completing campaign (sync)", "query.id", campaign.QueryID, "campaign.id", campaign.ID, "err", err,
					)
					resultsCh <- fleet.QueryCampaignResult{
						QueryID: queryID,
						Error:   ptr.String(err.Error()),
						Err:     err,
					}
				}
			}()

			readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
			if err != nil {
				level.Error(svc.logger).Log(
					"msg", "get campaign reader", "query.id", campaign.QueryID, "campaign.id", campaign.ID, "err", err,
				)
				resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err}
				return
			}
			defer cancelFunc()

			var results []fleet.QueryResult
			timeout := time.After(deadline)

			// We process stats along with results as they are sent back to the user.
			// We do a batch update of the stats.
			// We update aggregated stats once online hosts have reported.
			const statsBatchSize = 1000
			perfStatsTracker := statsTracker{}
			perfStatsTracker.saveStats, err = svc.ds.IsSavedQuery(ctx, campaign.QueryID)
			if err != nil {
				level.Error(svc.logger).Log("msg", "error checking saved query", "query.id", campaign.QueryID, "err", err)
				perfStatsTracker.saveStats = false
			}
			totalHosts := campaign.Metrics.TotalHosts
			// We update aggregated stats and activity at the end asynchronously.
			defer func() {
				go func() {
					svc.updateStats(ctxWithoutCancel, queryID, svc.logger, &perfStatsTracker, true)
					svc.addLiveQueryActivity(ctxWithoutCancel, totalHosts, queryID, svc.logger)
				}()
			}()
		loop:
			for {
				select {
				case res := <-readChan:
					switch res := res.(type) {
					case fleet.DistributedQueryResult:
						results = append(results, fleet.QueryResult{HostID: res.Host.ID, Rows: res.Rows, Error: res.Error})
						counterMutex.Lock()
						respondedHostIDs[res.Host.ID] = struct{}{}
						counterMutex.Unlock()
						if perfStatsTracker.saveStats && res.Stats != nil {
							perfStatsTracker.stats = append(
								perfStatsTracker.stats,
								statsToSave{
									hostID: res.Host.ID, Stats: res.Stats, outputSize: calculateOutputSize(&perfStatsTracker, &res),
								},
							)
							if len(perfStatsTracker.stats) >= statsBatchSize {
								svc.updateStats(ctx, campaign.QueryID, svc.logger, &perfStatsTracker, false)
							}
						}
						if len(results) == len(hostIDs) {
							break loop
						}
					case error:
						resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(res.Error()), Err: res}
						return
					}
				case <-timeout:
					break loop
				case <-ctx.Done():
					break loop
				}
			}
			resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Results: results}
		}()
	}

	// Iterate collecting results until all the goroutines have returned
	go func() {
		wg.Wait()
		close(resultsCh)
	}()

	var results []fleet.QueryCampaignResult
	for result := range resultsCh {
		results = append(results, result)
	}

	return results, len(respondedHostIDs), nil
}
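
// GetCampaignReader opens the channel of incoming results for the campaign and marks
// the campaign as running. The returned cancel function must be called to release the
// reader.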
func (svc *Service) GetCampaignReader(ctx context.Context, campaign *fleet.DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) {
	// Open the channel from which we will receive incoming query results
	// (probably from the redis pubsub implementation)
	cancelCtx, cancelFunc := context.WithCancel(ctx)

	readChan, err := svc.resultStore.ReadChannel(cancelCtx, *campaign)
	if err != nil {
		cancelFunc()
		return nil, nil, fmt.Errorf("cannot open read channel for campaign %d: %w", campaign.ID, err)
	}

	campaign.Status = fleet.QueryRunning
	if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil {
		cancelFunc()
		return nil, nil, ctxerr.Wrap(ctx, err, "error saving campaign state")
	}

	return readChan, cancelFunc, nil
}
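
// CompleteCampaign marks the campaign as complete and stops the underlying live query.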
func (svc *Service) CompleteCampaign(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error {
	campaign.Status = fleet.QueryComplete
	err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "saving distributed campaign after complete")
	}
	err = svc.liveQueryStore.StopQuery(strconv.Itoa(int(campaign.ID)))
	if err != nil {
		return ctxerr.Wrap(ctx, err, "stopping query after complete")
	}
	return nil
}