Issue 1798 rest api for live queries (#2675)

* wip

* wip

* Make tests not flaky

* Add changes file

* Make logging and authz thread safe

* Move cancel defer below err check
Tomas Touceda 2021-10-26 11:33:31 -03:00 committed by GitHub
parent 58b26ae1ba
commit 79f2c7c391
16 changed files with 791 additions and 76 deletions

View File

@@ -0,0 +1 @@
* Add live query REST API

View File

@@ -3148,6 +3148,105 @@ o
]
```
### Run live queries and retrieve 90 seconds of results using only REST
The same endpoint used to run queries and gather results over websockets can also be used to collect results
over a fixed period of 90 seconds and return them in the REST response.
WARNING: this endpoint buffers responses in memory, and the elapsed time is capped at 90 seconds regardless of
whether all possible results have been gathered. Depending on the configuration, this can trigger an autoscaling
event or crash the Fleet instance.
`GET /api/v1/fleet/queries/run`
#### Parameters
| Name | Type | In | Description |
| --------- | ------ | ---- | --------------------------------------------- |
| query_ids | array | body | **Required**. The IDs of the queries to run as live queries. |
| host_ids | array | body | **Required**. The IDs of the hosts to run the live queries against. |
#### Example
`GET /api/v1/fleet/queries/run`
##### Request body
```json
{
"query_ids": [ 1, 2 ],
"host_ids": [ 1, 4, 34, 27 ]
}
```
##### Default response
`Status: 200`
```json
{
"summary": {
"targeted_host_count": 4,
"responded_host_count": 2
},
"live_query_results": [
{
"query_id": 2,
"results": [
{
"host_id": 1,
"rows": [
{
"build_distro": "10.12",
"build_platform": "darwin",
"config_hash": "7bb99fa2c8a998c9459ec71da3a84d66c592d6d3",
"config_valid": "1",
"extensions": "active",
"instance_id": "9a2ec7bf-4946-46ea-93bf-455e0bcbd068",
"pid": "23413",
"platform_mask": "21",
"start_time": "1635194306",
"uuid": "4C182AC7-75F7-5AF4-A74B-1E165ED35742",
"version": "4.9.0",
"watcher": "23412"
}
],
"error": null
},
{
"host_id": 2,
"rows": [],
"error": "no such table: os_version"
}
]
}
]
}
```
##### Default response
`Status: 200`
```json
{
"campaign": {
"created_at": "0001-01-01T00:00:00Z",
"updated_at": "0001-01-01T00:00:00Z",
"Metrics": {
"TotalHosts": 1,
"OnlineHosts": 0,
"OfflineHosts": 1,
"MissingInActionHosts": 0,
"NewHosts": 1
},
"id": 1,
"query_id": 3,
"status": 0,
"user_id": 1
}
}
```
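For illustration, here is a minimal Go client sketch for `GET /api/v1/fleet/queries/run`. The base URL and token are placeholders, and note that this endpoint takes a JSON body on a GET request; the call blocks until the collection period elapses.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Placeholder server URL and API token (assumptions for illustration).
	const baseURL = "https://fleet.example.com"
	const token = "<api-token>"

	body, err := json.Marshal(map[string][]uint{
		"query_ids": {1, 2},
		"host_ids":  {1, 4, 34, 27},
	})
	if err != nil {
		panic(err)
	}

	// This endpoint takes a JSON body on a GET request.
	req, err := http.NewRequest("GET", baseURL+"/api/v1/fleet/queries/run", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")

	// Blocks until the collection period (90 seconds by default) elapses.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out struct {
		Summary struct {
			TargetedHostCount  int `json:"targeted_host_count"`
			RespondedHostCount int `json:"responded_host_count"`
		} `json:"summary"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Printf("%d of %d hosts responded\n", out.Summary.RespondedHostCount, out.Summary.TargetedHostCount)
}
```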
---
## Schedule

View File

@@ -67,7 +67,7 @@ func (a *Authorizer) SkipAuthorization(ctx context.Context) {
// Mark the authorization context as checked (otherwise middleware will
// error).
if authctx, ok := authz_ctx.FromContext(ctx); ok {
authctx.Checked = true
authctx.SetChecked()
}
}
@@ -81,7 +81,7 @@ func (a *Authorizer) Authorize(ctx context.Context, object, action interface{})
// Mark the authorization context as checked (otherwise middleware will
// error).
if authctx, ok := authz_ctx.FromContext(ctx); ok {
authctx.Checked = true
authctx.SetChecked()
}
subject := UserFromContext(ctx)

View File

@@ -2,7 +2,10 @@
// request has had an authorization check performed before returning results.
package authz
import "context"
import (
"context"
"sync"
)
type key int
@@ -22,6 +25,19 @@ func FromContext(ctx context.Context) (*AuthorizationContext, bool) {
// AuthorizationContext contains the context information used for the
// authorization check.
type AuthorizationContext struct {
// Checked indicates whether a call was made to check authorization for the request.
Checked bool
l sync.Mutex
// checked indicates whether a call was made to check authorization for the request.
checked bool
}
func (a *AuthorizationContext) Checked() bool {
a.l.Lock()
defer a.l.Unlock()
return a.checked
}
func (a *AuthorizationContext) SetChecked() {
a.l.Lock()
defer a.l.Unlock()
a.checked = true
}
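The mutex matters because the new REST endpoint fans out one goroutine per query over a single request context, so the checked flag can now be written concurrently. A self-contained sketch of the guarantee (the type is copied locally for illustration):

```go
package main

import (
	"fmt"
	"sync"
)

// Local copy of the AuthorizationContext shape from the diff above.
type AuthorizationContext struct {
	l       sync.Mutex
	checked bool
}

func (a *AuthorizationContext) Checked() bool {
	a.l.Lock()
	defer a.l.Unlock()
	return a.checked
}

func (a *AuthorizationContext) SetChecked() {
	a.l.Lock()
	defer a.l.Unlock()
	a.checked = true
}

func main() {
	// Several goroutines share one context, as the per-query goroutines
	// in the live query endpoint share one request context.
	authctx := &AuthorizationContext{}
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			authctx.SetChecked() // safe: guarded by the internal mutex
		}()
	}
	wg.Wait()
	fmt.Println(authctx.Checked()) // true
}
```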

View File

@@ -2,6 +2,7 @@ package logging
import (
"context"
"sync"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
@@ -29,7 +30,7 @@ func FromContext(ctx context.Context) (*LoggingContext, bool) {
// WithStartTime returns a context with logging.StartTime marked as the current time
func WithStartTime(ctx context.Context) context.Context {
if logCtx, ok := FromContext(ctx); ok {
logCtx.StartTime = time.Now()
logCtx.SetStartTime()
}
return ctx
}
@@ -37,7 +38,7 @@ func WithStartTime(ctx context.Context) context.Context {
// WithErr returns a context with logging.Err set as the error provided
func WithErr(ctx context.Context, err ...error) context.Context {
if logCtx, ok := FromContext(ctx); ok {
logCtx.Errs = append(logCtx.Errs, err...)
logCtx.SetErrs(err...)
}
return ctx
}
@@ -45,7 +46,7 @@ func WithErr(ctx context.Context, err ...error) context.Context {
// WithNoUser returns a context with logging.SkipUser set to true so user won't be logged
func WithNoUser(ctx context.Context) context.Context {
if logCtx, ok := FromContext(ctx); ok {
logCtx.SkipUser = true
logCtx.SetSkipUser()
}
return ctx
}
@@ -53,20 +54,22 @@ func WithNoUser(ctx context.Context) context.Context {
// WithExtras returns a context with logging.Extras set as the values provided
func WithExtras(ctx context.Context, extras ...interface{}) context.Context {
if logCtx, ok := FromContext(ctx); ok {
logCtx.Extras = append(logCtx.Extras, extras...)
logCtx.SetExtras(extras...)
}
return ctx
}
func WithLevel(ctx context.Context, level func(kitlog.Logger) kitlog.Logger) context.Context {
if logCtx, ok := FromContext(ctx); ok {
logCtx.ForceLevel = level
logCtx.SetForceLevel(level)
}
return ctx
}
// LoggingContext contains the context information for logging the current request
type LoggingContext struct {
l sync.Mutex
StartTime time.Time
Errs []error
Extras []interface{}
@@ -74,8 +77,37 @@ type LoggingContext struct {
ForceLevel func(kitlog.Logger) kitlog.Logger
}
func (l *LoggingContext) SetForceLevel(level func(kitlog.Logger) kitlog.Logger) {
l.l.Lock()
defer l.l.Unlock()
l.ForceLevel = level
}
func (l *LoggingContext) SetExtras(extras ...interface{}) {
l.l.Lock()
defer l.l.Unlock()
l.Extras = append(l.Extras, extras...)
}
func (l *LoggingContext) SetSkipUser() {
l.l.Lock()
defer l.l.Unlock()
l.SkipUser = true
}
func (l *LoggingContext) SetStartTime() {
l.l.Lock()
defer l.l.Unlock()
l.StartTime = time.Now()
}
func (l *LoggingContext) SetErrs(err ...error) {
l.l.Lock()
defer l.l.Unlock()
l.Errs = append(l.Errs, err...)
}
// Log logs the data within the context
func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) {
l.l.Lock()
defer l.l.Unlock()
if l.ForceLevel != nil {
logger = l.ForceLevel(logger)
} else if l.Errs != nil || len(l.Extras) > 0 {

View File

@@ -64,6 +64,15 @@ func (d *Datastore) SaveDistributedQueryCampaign(ctx context.Context, camp *flee
return nil
}
func (d *Datastore) DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error) {
var campaigns []*fleet.DistributedQueryCampaign
err := sqlx.SelectContext(ctx, d.reader, &campaigns, `SELECT * FROM distributed_query_campaigns WHERE query_id=?`, queryID)
if err != nil {
return nil, errors.Wrap(err, "getting campaigns for query")
}
return campaigns, nil
}
func (d *Datastore) DistributedQueryCampaignTargetIDs(ctx context.Context, id uint) (*fleet.HostTargets, error) {
sqlStatement := `
SELECT * FROM distributed_query_campaign_targets WHERE distributed_query_campaign_id = ?

View File

@@ -31,17 +31,6 @@ type DistributedQueryCampaignTarget struct {
TargetID uint `db:"target_id"`
}
// DistributedQueryExecutionStatus is the status of a distributed query
// execution on a single host.
type DistributedQueryExecutionStatus int
const (
ExecutionWaiting DistributedQueryExecutionStatus = iota
ExecutionRequested
ExecutionSucceeded
ExecutionFailed
)
// DistributedQueryResult is the result returned from the execution of a
// distributed query on a single host.
type DistributedQueryResult struct {

View File

@@ -86,6 +86,8 @@ type Datastore interface {
// easier to test. The return values indicate how many campaigns were expired and any error.
CleanupDistributedQueryCampaigns(ctx context.Context, now time.Time) (expired uint, err error)
DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*DistributedQueryCampaign, error)
///////////////////////////////////////////////////////////////////////////////
// PackStore is the datastore interface for managing query packs.

View File

@@ -221,6 +221,9 @@ type Service interface {
// go-kit RPC style.
StreamCampaignResults(ctx context.Context, conn *websocket.Conn, campaignID uint)
GetCampaignReader(ctx context.Context, campaign *DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error)
CompleteCampaign(ctx context.Context, campaign *DistributedQueryCampaign) error
///////////////////////////////////////////////////////////////////////////////
// AgentOptionsService

View File

@@ -75,6 +75,8 @@ type NewDistributedQueryCampaignTargetFunc func(ctx context.Context, target *fle
type CleanupDistributedQueryCampaignsFunc func(ctx context.Context, now time.Time) (expired uint, err error)
type DistributedQueryCampaignsForQueryFunc func(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error)
type ApplyPackSpecsFunc func(ctx context.Context, specs []*fleet.PackSpec) error
type GetPackSpecsFunc func(ctx context.Context) ([]*fleet.PackSpec, error)
@@ -406,6 +408,9 @@ type DataStore struct {
CleanupDistributedQueryCampaignsFunc CleanupDistributedQueryCampaignsFunc
CleanupDistributedQueryCampaignsFuncInvoked bool
DistributedQueryCampaignsForQueryFunc DistributedQueryCampaignsForQueryFunc
DistributedQueryCampaignsForQueryFuncInvoked bool
ApplyPackSpecsFunc ApplyPackSpecsFunc
ApplyPackSpecsFuncInvoked bool
@@ -918,6 +923,11 @@ func (s *DataStore) CleanupDistributedQueryCampaigns(ctx context.Context, now ti
return s.CleanupDistributedQueryCampaignsFunc(ctx, now)
}
func (s *DataStore) DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error) {
s.DistributedQueryCampaignsForQueryFuncInvoked = true
return s.DistributedQueryCampaignsForQueryFunc(ctx, queryID)
}
func (s *DataStore) ApplyPackSpecs(ctx context.Context, specs []*fleet.PackSpec) error {
s.ApplyPackSpecsFuncInvoked = true
return s.ApplyPackSpecsFunc(ctx, specs)

View File

@@ -570,15 +570,15 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
r.Handle("/api/v1/fleet/users", h.ListUsers).Methods("GET").Name("list_users")
r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite")
r.Handle("/api/v1/fleet/users/admin", h.CreateUser).Methods("POST").Name("create_user")
r.Handle("/api/v1/fleet/users/{id}", h.GetUser).Methods("GET").Name("get_user")
r.Handle("/api/v1/fleet/users/{id}", h.ModifyUser).Methods("PATCH").Name("modify_user")
r.Handle("/api/v1/fleet/users/{id}", h.DeleteUser).Methods("DELETE").Name("delete_user")
r.Handle("/api/v1/fleet/users/{id}/require_password_reset", h.RequirePasswordReset).Methods("POST").Name("require_password_reset")
r.Handle("/api/v1/fleet/users/{id}/sessions", h.GetSessionsForUserInfo).Methods("GET").Name("get_session_for_user")
r.Handle("/api/v1/fleet/users/{id}/sessions", h.DeleteSessionsForUser).Methods("DELETE").Name("delete_session_for_user")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.GetUser).Methods("GET").Name("get_user")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.ModifyUser).Methods("PATCH").Name("modify_user")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.DeleteUser).Methods("DELETE").Name("delete_user")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/require_password_reset", h.RequirePasswordReset).Methods("POST").Name("require_password_reset")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.GetSessionsForUserInfo).Methods("GET").Name("get_session_for_user")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.DeleteSessionsForUser).Methods("DELETE").Name("delete_session_for_user")
r.Handle("/api/v1/fleet/sessions/{id}", h.GetSessionInfo).Methods("GET").Name("get_session_info")
r.Handle("/api/v1/fleet/sessions/{id}", h.DeleteSession).Methods("DELETE").Name("delete_session")
r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.GetSessionInfo).Methods("GET").Name("get_session_info")
r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.DeleteSession).Methods("DELETE").Name("delete_session")
r.Handle("/api/v1/fleet/config/certificate", h.GetCertificate).Methods("GET").Name("get_certificate")
r.Handle("/api/v1/fleet/config", h.GetAppConfig).Methods("GET").Name("get_app_config")
@@ -587,17 +587,17 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
r.Handle("/api/v1/fleet/spec/enroll_secret", h.GetEnrollSecretSpec).Methods("GET").Name("get_enroll_secret_spec")
r.Handle("/api/v1/fleet/invites", h.CreateInvite).Methods("POST").Name("create_invite")
r.Handle("/api/v1/fleet/invites", h.ListInvites).Methods("GET").Name("list_invites")
r.Handle("/api/v1/fleet/invites/{id}", h.DeleteInvite).Methods("DELETE").Name("delete_invite")
r.Handle("/api/v1/fleet/invites/{id:[0-9]+}", h.DeleteInvite).Methods("DELETE").Name("delete_invite")
r.Handle("/api/v1/fleet/invites/{token}", h.VerifyInvite).Methods("GET").Name("verify_invite")
r.Handle("/api/v1/fleet/email/change/{token}", h.ChangeEmail).Methods("GET").Name("change_email")
r.Handle("/api/v1/fleet/queries/{id}", h.GetQuery).Methods("GET").Name("get_query")
r.Handle("/api/v1/fleet/queries/{id:[0-9]+}", h.GetQuery).Methods("GET").Name("get_query")
r.Handle("/api/v1/fleet/queries", h.ListQueries).Methods("GET").Name("list_queries")
r.Handle("/api/v1/fleet/queries", h.CreateQuery).Methods("POST").Name("create_query")
r.Handle("/api/v1/fleet/queries/{id}", h.ModifyQuery).Methods("PATCH").Name("modify_query")
r.Handle("/api/v1/fleet/queries/{id:[0-9]+}", h.ModifyQuery).Methods("PATCH").Name("modify_query")
r.Handle("/api/v1/fleet/queries/{name}", h.DeleteQuery).Methods("DELETE").Name("delete_query")
r.Handle("/api/v1/fleet/queries/id/{id}", h.DeleteQueryByID).Methods("DELETE").Name("delete_query_by_id")
r.Handle("/api/v1/fleet/queries/id/{id:[0-9]+}", h.DeleteQueryByID).Methods("DELETE").Name("delete_query_by_id")
r.Handle("/api/v1/fleet/queries/delete", h.DeleteQueries).Methods("POST").Name("delete_queries")
r.Handle("/api/v1/fleet/spec/queries", h.ApplyQuerySpecs).Methods("POST").Name("apply_query_specs")
r.Handle("/api/v1/fleet/spec/queries", h.GetQuerySpecs).Methods("GET").Name("get_query_specs")
@@ -606,41 +606,41 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
r.Handle("/api/v1/fleet/queries/run_by_names", h.CreateDistributedQueryCampaignByNames).Methods("POST").Name("create_distributed_query_campaign_by_names")
r.Handle("/api/v1/fleet/packs", h.CreatePack).Methods("POST").Name("create_pack")
r.Handle("/api/v1/fleet/packs/{id}", h.ModifyPack).Methods("PATCH").Name("modify_pack")
r.Handle("/api/v1/fleet/packs/{id:[0-9]+}", h.ModifyPack).Methods("PATCH").Name("modify_pack")
r.Handle("/api/v1/fleet/packs", h.ListPacks).Methods("GET").Name("list_packs")
r.Handle("/api/v1/fleet/packs/{name}", h.DeletePack).Methods("DELETE").Name("delete_pack")
r.Handle("/api/v1/fleet/packs/id/{id}", h.DeletePackByID).Methods("DELETE").Name("delete_pack_by_id")
r.Handle("/api/v1/fleet/packs/{id}/scheduled", h.GetScheduledQueriesInPack).Methods("GET").Name("get_scheduled_queries_in_pack")
r.Handle("/api/v1/fleet/packs/id/{id:[0-9]+}", h.DeletePackByID).Methods("DELETE").Name("delete_pack_by_id")
r.Handle("/api/v1/fleet/packs/{id:[0-9]+}/scheduled", h.GetScheduledQueriesInPack).Methods("GET").Name("get_scheduled_queries_in_pack")
r.Handle("/api/v1/fleet/schedule", h.ScheduleQuery).Methods("POST").Name("schedule_query")
r.Handle("/api/v1/fleet/schedule/{id}", h.GetScheduledQuery).Methods("GET").Name("get_scheduled_query")
r.Handle("/api/v1/fleet/schedule/{id}", h.ModifyScheduledQuery).Methods("PATCH").Name("modify_scheduled_query")
r.Handle("/api/v1/fleet/schedule/{id}", h.DeleteScheduledQuery).Methods("DELETE").Name("delete_scheduled_query")
r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.GetScheduledQuery).Methods("GET").Name("get_scheduled_query")
r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.ModifyScheduledQuery).Methods("PATCH").Name("modify_scheduled_query")
r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.DeleteScheduledQuery).Methods("DELETE").Name("delete_scheduled_query")
r.Handle("/api/v1/fleet/spec/packs", h.ApplyPackSpecs).Methods("POST").Name("apply_pack_specs")
r.Handle("/api/v1/fleet/spec/packs", h.GetPackSpecs).Methods("GET").Name("get_pack_specs")
r.Handle("/api/v1/fleet/spec/packs/{name}", h.GetPackSpec).Methods("GET").Name("get_pack_spec")
r.Handle("/api/v1/fleet/global/schedule", h.GetGlobalSchedule).Methods("GET").Name("set_global_schedule")
r.Handle("/api/v1/fleet/global/schedule", h.GlobalScheduleQuery).Methods("POST").Name("add_to_global_schedule")
r.Handle("/api/v1/fleet/global/schedule/{id}", h.ModifyGlobalSchedule).Methods("PATCH").Name("modify_global_schedule")
r.Handle("/api/v1/fleet/global/schedule/{id}", h.DeleteGlobalSchedule).Methods("DELETE").Name("delete_global_schedule")
r.Handle("/api/v1/fleet/global/schedule/{id:[0-9]+}", h.ModifyGlobalSchedule).Methods("PATCH").Name("modify_global_schedule")
r.Handle("/api/v1/fleet/global/schedule/{id:[0-9]+}", h.DeleteGlobalSchedule).Methods("DELETE").Name("delete_global_schedule")
r.Handle("/api/v1/fleet/labels", h.CreateLabel).Methods("POST").Name("create_label")
r.Handle("/api/v1/fleet/labels/{id}", h.ModifyLabel).Methods("PATCH").Name("modify_label")
r.Handle("/api/v1/fleet/labels/{id}", h.GetLabel).Methods("GET").Name("get_label")
r.Handle("/api/v1/fleet/labels/{id:[0-9]+}", h.ModifyLabel).Methods("PATCH").Name("modify_label")
r.Handle("/api/v1/fleet/labels/{id:[0-9]+}", h.GetLabel).Methods("GET").Name("get_label")
r.Handle("/api/v1/fleet/labels", h.ListLabels).Methods("GET").Name("list_labels")
r.Handle("/api/v1/fleet/labels/{id}/hosts", h.ListHostsInLabel).Methods("GET").Name("list_hosts_in_label")
r.Handle("/api/v1/fleet/labels/{id:[0-9]+}/hosts", h.ListHostsInLabel).Methods("GET").Name("list_hosts_in_label")
r.Handle("/api/v1/fleet/labels/{name}", h.DeleteLabel).Methods("DELETE").Name("delete_label")
r.Handle("/api/v1/fleet/labels/id/{id}", h.DeleteLabelByID).Methods("DELETE").Name("delete_label_by_id")
r.Handle("/api/v1/fleet/labels/id/{id:[0-9]+}", h.DeleteLabelByID).Methods("DELETE").Name("delete_label_by_id")
r.Handle("/api/v1/fleet/spec/labels", h.ApplyLabelSpecs).Methods("POST").Name("apply_label_specs")
r.Handle("/api/v1/fleet/spec/labels", h.GetLabelSpecs).Methods("GET").Name("get_label_specs")
r.Handle("/api/v1/fleet/spec/labels/{name}", h.GetLabelSpec).Methods("GET").Name("get_label_spec")
r.Handle("/api/v1/fleet/host_summary", h.GetHostSummary).Methods("GET").Name("get_host_summary")
r.Handle("/api/v1/fleet/hosts/identifier/{identifier}", h.HostByIdentifier).Methods("GET").Name("host_by_identifier")
r.Handle("/api/v1/fleet/hosts/{id}", h.DeleteHost).Methods("DELETE").Name("delete_host")
r.Handle("/api/v1/fleet/hosts/{id:[0-9]+}", h.DeleteHost).Methods("DELETE").Name("delete_host")
r.Handle("/api/v1/fleet/hosts/transfer", h.AddHostsToTeam).Methods("POST").Name("add_hosts_to_team")
r.Handle("/api/v1/fleet/hosts/transfer/filter", h.AddHostsToTeamByFilter).Methods("POST").Name("add_hosts_to_team_by_filter")
r.Handle("/api/v1/fleet/hosts/{id}/refetch", h.RefetchHost).Methods("POST").Name("refetch_host")
r.Handle("/api/v1/fleet/hosts/{id:[0-9]+}/refetch", h.RefetchHost).Methods("POST").Name("refetch_host")
r.Handle("/api/v1/fleet/targets", h.SearchTargets).Methods("POST").Name("search_targets")
@@ -650,18 +650,18 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
r.Handle("/api/v1/fleet/status/live_query", h.StatusLiveQuery).Methods("GET").Name("status_live_query")
r.Handle("/api/v1/fleet/carves", h.ListCarves).Methods("GET").Name("list_carves")
r.Handle("/api/v1/fleet/carves/{id}", h.GetCarve).Methods("GET").Name("get_carve")
r.Handle("/api/v1/fleet/carves/{id}/block/{block_id}", h.GetCarveBlock).Methods("GET").Name("get_carve_block")
r.Handle("/api/v1/fleet/carves/{id:[0-9]+}", h.GetCarve).Methods("GET").Name("get_carve")
r.Handle("/api/v1/fleet/carves/{id:[0-9]+}/block/{block_id}", h.GetCarveBlock).Methods("GET").Name("get_carve_block")
r.Handle("/api/v1/fleet/teams", h.CreateTeam).Methods("POST").Name("create_team")
r.Handle("/api/v1/fleet/teams", h.ListTeams).Methods("GET").Name("list_teams")
r.Handle("/api/v1/fleet/teams/{id}", h.ModifyTeam).Methods("PATCH").Name("modify_team")
r.Handle("/api/v1/fleet/teams/{id}", h.DeleteTeam).Methods("DELETE").Name("delete_team")
r.Handle("/api/v1/fleet/teams/{id}/agent_options", h.ModifyTeamAgentOptions).Methods("POST").Name("modify_team_agent_options")
r.Handle("/api/v1/fleet/teams/{id}/users", h.ListTeamUsers).Methods("GET").Name("team_users")
r.Handle("/api/v1/fleet/teams/{id}/users", h.AddTeamUsers).Methods("PATCH").Name("add_team_users")
r.Handle("/api/v1/fleet/teams/{id}/users", h.DeleteTeamUsers).Methods("DELETE").Name("delete_team_users")
r.Handle("/api/v1/fleet/teams/{id}/secrets", h.TeamEnrollSecrets).Methods("GET").Name("get_team_enroll_secrets")
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}", h.ModifyTeam).Methods("PATCH").Name("modify_team")
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}", h.DeleteTeam).Methods("DELETE").Name("delete_team")
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/agent_options", h.ModifyTeamAgentOptions).Methods("POST").Name("modify_team_agent_options")
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/users", h.ListTeamUsers).Methods("GET").Name("team_users")
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/users", h.AddTeamUsers).Methods("PATCH").Name("add_team_users")
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/users", h.DeleteTeamUsers).Methods("DELETE").Name("delete_team_users")
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/secrets", h.TeamEnrollSecrets).Methods("GET").Name("get_team_enroll_secrets")
r.Handle("/api/v1/osquery/enroll", h.EnrollAgent).Methods("POST").Name("enroll_agent")
r.Handle("/api/v1/osquery/config", h.GetClientConfig).Methods("POST").Name("get_client_config")
@@ -719,6 +719,8 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht
e.POST("/api/v1/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{})
e.GET("/api/v1/fleet/hosts/{id:[0-9]+}", getHostEndpoint, getHostRequest{})
e.GET("/api/v1/fleet/hosts/count", countHostsEndpoint, countHostsRequest{})
e.GET("/api/v1/fleet/queries/run", runLiveQueryEndpoint, runLiveQueryRequest{})
}
// TODO: this duplicates the one in makeKitHandler
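A note on why the route patterns above gained `[0-9]+` constraints: gorilla/mux tries routes in registration order, so an unconstrained `/api/v1/fleet/queries/{id}` route registered before the new literal `/api/v1/fleet/queries/run` path would capture the request with `id="run"`. The numeric constraint avoids that shadowing. A minimal sketch of the difference (handlers are placeholders):

```go
package main

import (
	"fmt"
	"net/http"

	"github.com/gorilla/mux"
)

func main() {
	r := mux.NewRouter()

	// With the numeric constraint, GET /api/v1/fleet/queries/run no longer
	// matches the {id} route, so the literal route below is reachable even
	// though it is registered later.
	r.HandleFunc("/api/v1/fleet/queries/{id:[0-9]+}", func(w http.ResponseWriter, req *http.Request) {
		fmt.Fprintf(w, "get query %s\n", mux.Vars(req)["id"])
	})
	r.HandleFunc("/api/v1/fleet/queries/run", func(w http.ResponseWriter, req *http.Request) {
		fmt.Fprintln(w, "run live queries")
	})

	http.ListenAndServe(":8080", r)
}
```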

View File

@@ -0,0 +1,398 @@
package service
import (
"context"
"fmt"
"net/http"
"os"
"sort"
"sync"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/live_query"
"github.com/fleetdm/fleet/v4/server/pubsub"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
func TestIntegrationLiveQueriesTestSuite(t *testing.T) {
testingSuite := new(liveQueriesTestSuite)
testingSuite.s = &testingSuite.Suite
suite.Run(t, testingSuite)
}
type liveQueriesTestSuite struct {
withServer
suite.Suite
lq *live_query.MockLiveQuery
hosts []*fleet.Host
}
func (s *liveQueriesTestSuite) SetupSuite() {
require.NoError(s.T(), os.Setenv("FLEET_LIVE_QUERY_REST_PERIOD", "5s"))
s.withDS.SetupSuite("liveQueriesTestSuite")
rs := pubsub.NewInmemQueryResults()
lq := new(live_query.MockLiveQuery)
s.lq = lq
users, server := RunServerForTestsWithDS(s.T(), s.ds, TestServerOpts{Lq: lq, Rs: rs})
s.server = server
s.users = users
s.token = getTestAdminToken(s.T(), s.server)
t := s.T()
for i := 0; i < 3; i++ {
host, err := s.ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now().Add(-time.Duration(i) * time.Minute),
OsqueryHostID: fmt.Sprintf("%s%d", t.Name(), i),
NodeKey: fmt.Sprintf("%s%d", t.Name(), i),
UUID: fmt.Sprintf("%s%d", t.Name(), i),
Hostname: fmt.Sprintf("%sfoo.local%d", t.Name(), i),
})
require.NoError(s.T(), err)
s.hosts = append(s.hosts, host)
}
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
t := s.T()
host := s.hosts[0]
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"})
require.NoError(t, err)
s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil)
s.lq.On("StopQuery", mock.Anything).Return(nil)
liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()
// Give the above call a couple of seconds to create the campaign
time.Sleep(2 * time.Second)
cid := getCIDForQ(s, q1)
distributedReq := submitDistributedQueryResultsRequest{
NodeKey: host.NodeKey,
Results: map[string][]map[string]string{
hostDistributedQueryPrefix + cid: {{"col1": "a", "col2": "b"}},
},
Statuses: map[string]fleet.OsqueryStatus{
hostDistributedQueryPrefix + cid: 0,
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid: "some msg",
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
wg.Wait()
require.Len(t, liveQueryResp.Results, 1)
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
require.Len(t, liveQueryResp.Results[0].Results[0].Rows, 1)
assert.Equal(t, "a", liveQueryResp.Results[0].Results[0].Rows[0]["col1"])
assert.Equal(t, "b", liveQueryResp.Results[0].Results[0].Rows[0]["col2"])
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
t := s.T()
host := s.hosts[0]
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"})
require.NoError(t, err)
q2, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 2 from osquery;", Description: "desc2", Name: t.Name() + "query2"})
require.NoError(t, err)
s.lq.On("QueriesForHost", host.ID).Return(map[string]string{
fmt.Sprint(q1.ID): "select 1 from osquery;",
fmt.Sprint(q2.ID): "select 2 from osquery;",
}, nil)
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{host.ID}).Return(nil)
s.lq.On("StopQuery", mock.Anything).Return(nil)
liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID, q2.ID},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()
// Give the above call a couple of seconds to create the campaign
time.Sleep(2 * time.Second)
cid1 := getCIDForQ(s, q1)
cid2 := getCIDForQ(s, q2)
distributedReq := submitDistributedQueryResultsRequest{
NodeKey: host.NodeKey,
Results: map[string][]map[string]string{
hostDistributedQueryPrefix + cid1: {{"col1": "a", "col2": "b"}},
hostDistributedQueryPrefix + cid2: {{"col3": "c", "col4": "d"}, {"col3": "e", "col4": "f"}},
},
Statuses: map[string]fleet.OsqueryStatus{
hostDistributedQueryPrefix + cid1: 0,
hostDistributedQueryPrefix + cid2: 0,
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid1: "some msg",
hostDistributedQueryPrefix + cid2: "some other msg",
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
wg.Wait()
require.Len(t, liveQueryResp.Results, 2)
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
sort.Slice(liveQueryResp.Results, func(i, j int) bool {
return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID
})
require.True(t, q1.ID < q2.ID)
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
require.Len(t, liveQueryResp.Results[0].Results, 1)
q1Results := liveQueryResp.Results[0].Results[0]
require.Len(t, q1Results.Rows, 1)
assert.Equal(t, "a", q1Results.Rows[0]["col1"])
assert.Equal(t, "b", q1Results.Rows[0]["col2"])
assert.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID)
require.Len(t, liveQueryResp.Results[1].Results, 1)
q2Results := liveQueryResp.Results[1].Results[0]
require.Len(t, q2Results.Rows, 2)
assert.Equal(t, "c", q2Results.Rows[0]["col3"])
assert.Equal(t, "d", q2Results.Rows[0]["col4"])
assert.Equal(t, "e", q2Results.Rows[1]["col3"])
assert.Equal(t, "f", q2Results.Rows[1]["col4"])
}
func getCIDForQ(s *liveQueriesTestSuite, q1 *fleet.Query) string {
t := s.T()
campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID)
require.NoError(t, err)
require.Len(t, campaigns, 1)
cid1 := fmt.Sprint(campaigns[0].ID)
return cid1
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
t := s.T()
h1 := s.hosts[0]
h2 := s.hosts[1]
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"})
require.NoError(t, err)
q2, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 2 from osquery;", Description: "desc2", Name: t.Name() + "query2"})
require.NoError(t, err)
s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{
fmt.Sprint(q1.ID): "select 1 from osquery;",
fmt.Sprint(q2.ID): "select 2 from osquery;",
}, nil)
s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{
fmt.Sprint(q1.ID): "select 1 from osquery;",
fmt.Sprint(q2.ID): "select 2 from osquery;",
}, nil)
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
s.lq.On("StopQuery", mock.Anything).Return(nil)
liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID, q2.ID},
HostIDs: []uint{h1.ID, h2.ID},
}
liveQueryResp := runLiveQueryResponse{}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()
// Give the above call a couple of seconds to create the campaign
time.Sleep(2 * time.Second)
cid1 := getCIDForQ(s, q1)
cid2 := getCIDForQ(s, q2)
for i, h := range []*fleet.Host{h1, h2} {
distributedReq := submitDistributedQueryResultsRequest{
NodeKey: h.NodeKey,
Results: map[string][]map[string]string{
hostDistributedQueryPrefix + cid1: {{"col1": fmt.Sprintf("a%d", i), "col2": fmt.Sprintf("b%d", i)}},
hostDistributedQueryPrefix + cid2: {{"col3": fmt.Sprintf("c%d", i), "col4": fmt.Sprintf("d%d", i)}, {"col3": fmt.Sprintf("e%d", i), "col4": fmt.Sprintf("f%d", i)}},
},
Statuses: map[string]fleet.OsqueryStatus{
hostDistributedQueryPrefix + cid1: 0,
hostDistributedQueryPrefix + cid2: 0,
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid1: "some msg",
hostDistributedQueryPrefix + cid2: "some other msg",
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
}
wg.Wait()
require.Len(t, liveQueryResp.Results, 2) // 2 queries
assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount)
sort.Slice(liveQueryResp.Results, func(i, j int) bool {
return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID
})
require.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
require.Len(t, liveQueryResp.Results[0].Results, 2)
for i, r := range liveQueryResp.Results[0].Results {
require.Len(t, r.Rows, 1)
assert.Equal(t, fmt.Sprintf("a%d", i), r.Rows[0]["col1"])
assert.Equal(t, fmt.Sprintf("b%d", i), r.Rows[0]["col2"])
}
require.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID)
require.Len(t, liveQueryResp.Results[1].Results, 2)
for i, r := range liveQueryResp.Results[1].Results {
require.Len(t, r.Rows, 2)
assert.Equal(t, fmt.Sprintf("c%d", i), r.Rows[0]["col3"])
assert.Equal(t, fmt.Sprintf("d%d", i), r.Rows[0]["col4"])
assert.Equal(t, fmt.Sprintf("e%d", i), r.Rows[1]["col3"])
assert.Equal(t, fmt.Sprintf("f%d", i), r.Rows[1]["col4"])
}
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() {
t := s.T()
liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{999},
HostIDs: []uint{888},
}
liveQueryResp := runLiveQueryResponse{}
s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
require.Len(t, liveQueryResp.Results, 1)
assert.Equal(t, 0, liveQueryResp.Summary.RespondedHostCount)
require.NotNil(t, liveQueryResp.Results[0].Error)
assert.Equal(t, "selecting query: sql: no rows in result set", *liveQueryResp.Results[0].Error)
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() {
t := s.T()
h1 := s.hosts[0]
h2 := s.hosts[1]
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"})
require.NoError(t, err)
s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
s.lq.On("StopQuery", mock.Anything).Return(nil)
liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: []uint{h1.ID, h2.ID},
}
liveQueryResp := runLiveQueryResponse{}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()
// Give the above call a couple of seconds to create the campaign
time.Sleep(2 * time.Second)
cid1 := getCIDForQ(s, q1)
distributedReq := submitDistributedQueryResultsRequest{
NodeKey: h1.NodeKey,
Results: map[string][]map[string]string{
hostDistributedQueryPrefix + cid1: {{"col1": "a", "col2": "b"}},
},
Statuses: map[string]fleet.OsqueryStatus{
hostDistributedQueryPrefix + cid1: 0,
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid1: "some msg",
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
distributedReq = submitDistributedQueryResultsRequest{
NodeKey: h2.NodeKey,
Results: map[string][]map[string]string{
hostDistributedQueryPrefix + cid1: {},
},
Statuses: map[string]fleet.OsqueryStatus{
hostDistributedQueryPrefix + cid1: 123,
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid1: "some error!",
},
}
distributedResp = submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
wg.Wait()
require.Len(t, liveQueryResp.Results, 1)
assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount)
result := liveQueryResp.Results[0]
require.Len(t, result.Results, 2)
require.Len(t, result.Results[0].Rows, 1)
assert.Equal(t, "a", result.Results[0].Rows[0]["col1"])
assert.Equal(t, "b", result.Results[0].Rows[0]["col2"])
require.Len(t, result.Results[1].Rows, 0)
require.NotNil(t, result.Results[1].Error)
assert.Equal(t, "some error!", *result.Results[1].Error)
}

View File

@@ -0,0 +1,168 @@
package service
import (
"context"
"fmt"
"os"
"strconv"
"sync"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/pkg/errors"
)
type runLiveQueryRequest struct {
QueryIDs []uint `json:"query_ids"`
HostIDs []uint `json:"host_ids"`
}
type summaryPayload struct {
TargetedHostCount int `json:"targeted_host_count"`
RespondedHostCount int `json:"responded_host_count"`
}
type runLiveQueryResponse struct {
Summary summaryPayload `json:"summary"`
Err error `json:"error,omitempty"`
Results []queryCampaignResult `json:"live_query_results"`
}
func (r runLiveQueryResponse) error() error { return r.Err }
type queryResult struct {
HostID uint `json:"host_id"`
Rows []map[string]string `json:"rows"`
Error *string `json:"error"`
}
type queryCampaignResult struct {
QueryID uint `json:"query_id"`
Error *string `json:"error,omitempty"`
Results []queryResult `json:"results"`
}
func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*runLiveQueryRequest)
wg := sync.WaitGroup{}
resultsCh := make(chan queryCampaignResult)
counterMutex := sync.Mutex{}
counter := make(map[uint]struct{})
period := os.Getenv("FLEET_LIVE_QUERY_REST_PERIOD")
if period == "" {
period = "90s"
}
duration, err := time.ParseDuration(period)
if err != nil {
duration = 90 * time.Second
}
for _, queryID := range req.QueryIDs {
queryID := queryID
wg.Add(1)
go func() {
defer wg.Done()
campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: req.HostIDs})
if err != nil {
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
return
}
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
if err != nil {
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
return
}
defer cancelFunc()
defer func() {
err := svc.CompleteCampaign(ctx, campaign)
if err != nil {
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
}
}()
ticker := time.NewTicker(duration)
defer ticker.Stop()
var results []queryResult
loop:
for {
select {
case res := <-readChan:
// Receive a result and buffer it for the REST response
switch res := res.(type) {
case fleet.DistributedQueryResult:
results = append(results, queryResult{HostID: res.Host.ID, Rows: res.Rows, Error: res.Error})
counterMutex.Lock()
counter[res.Host.ID] = struct{}{}
counterMutex.Unlock()
}
case <-ticker.C:
break loop
}
}
resultsCh <- queryCampaignResult{QueryID: queryID, Results: results}
}()
}
go func() {
wg.Wait()
close(resultsCh)
}()
res := runLiveQueryResponse{
Summary: summaryPayload{
TargetedHostCount: len(req.HostIDs),
RespondedHostCount: 0,
},
}
for result := range resultsCh {
res.Results = append(res.Results, result)
}
res.Summary.RespondedHostCount = len(counter)
return res, nil
}
func (svc *Service) GetCampaignReader(ctx context.Context, campaign *fleet.DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) {
// Open the channel from which we will receive incoming query results
// (probably from the redis pubsub implementation)
cancelCtx, cancelFunc := context.WithCancel(ctx)
readChan, err := svc.resultStore.ReadChannel(cancelCtx, *campaign)
if err != nil {
cancelFunc()
return nil, nil, fmt.Errorf("cannot open read channel for campaign %d ", campaign.ID)
}
// Setting status to running will cause the query to be returned to the
// targets when they check in for their queries
campaign.Status = fleet.QueryRunning
if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil {
cancelFunc()
return nil, nil, errors.Wrap(err, "error saving campaign state")
}
return readChan, cancelFunc, nil
}
func (svc *Service) CompleteCampaign(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error {
campaign.Status = fleet.QueryComplete
err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign)
if err != nil {
return errors.Wrap(err, "saving distributed campaign after complete")
}
err = svc.liveQueryStore.StopQuery(strconv.Itoa(int(campaign.ID)))
if err != nil {
return errors.Wrap(err, "stopping query after after complete")
}
return nil
}
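`runLiveQueryEndpoint` above follows a common fan-out/fan-in shape: one goroutine per query ID writes onto a shared results channel, a closer goroutine waits on the WaitGroup, and the request goroutine drains the channel until it closes. Distilled to a runnable skeleton (names and timings are illustrative):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	queryIDs := []uint{1, 2}
	results := make(chan string)
	var wg sync.WaitGroup

	for _, id := range queryIDs {
		id := id // per-iteration copy, as the endpoint does
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Stand-in for collecting rows until the ticker fires.
			time.Sleep(10 * time.Millisecond)
			results <- fmt.Sprintf("results for query %d", id)
		}()
	}

	// Close the channel once all collectors finish so the loop below ends.
	go func() {
		wg.Wait()
		close(results)
	}()

	for r := range results {
		fmt.Println(r)
	}
}
```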

View File

@@ -44,7 +44,7 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware {
// If authorization was not checked, return a response that will
// marshal to a generic error and log that the check was missed.
if !authzctx.Checked {
if !authzctx.Checked() {
return nil, authz.CheckMissingWithResponse(response)
}

View File

@@ -18,7 +18,7 @@ func TestAuthzCheck(t *testing.T) {
check := func(ctx context.Context, req interface{}) (interface{}, error) {
authCtx, ok := authz.FromContext(ctx)
require.True(t, ok)
authCtx.Checked = true
authCtx.SetChecked()
return struct{}{}, nil
}
check = checker.AuthzCheck()(check)

View File

@@ -222,31 +222,17 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co
// Open the channel from which we will receive incoming query results
// (probably from the redis pubsub implementation)
cancelCtx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
readChan, err := svc.resultStore.ReadChannel(cancelCtx, *campaign)
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
if err != nil {
conn.WriteJSONError(fmt.Sprintf("cannot open read channel for campaign %d ", campaignID))
return
}
// Setting status to running will cause the query to be returned to the
// targets when they check in for their queries
campaign.Status = fleet.QueryRunning
if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil {
conn.WriteJSONError("error saving campaign state")
conn.WriteJSONError("error getting campaign reader: " + err.Error())
return
}
defer cancelFunc()
// Setting the status to completed stops the query from being sent to
// targets. If this fails, there is a background job that will clean up
// this campaign.
defer func() {
campaign.Status = fleet.QueryComplete
_ = svc.ds.SaveDistributedQueryCampaign(ctx, campaign)
_ = svc.liveQueryStore.StopQuery(strconv.Itoa(int(campaign.ID)))
}()
defer svc.CompleteCampaign(ctx, campaign)
status := campaignStatus{
Status: campaignStatusPending,