diff --git a/changes/issue-1798-rest-api-for-live-queries b/changes/issue-1798-rest-api-for-live-queries new file mode 100644 index 000000000..87be8ad25 --- /dev/null +++ b/changes/issue-1798-rest-api-for-live-queries @@ -0,0 +1 @@ +* Add live query REST API diff --git a/docs/01-Using-Fleet/03-REST-API.md b/docs/01-Using-Fleet/03-REST-API.md index 3e6765388..89808ad89 100644 --- a/docs/01-Using-Fleet/03-REST-API.md +++ b/docs/01-Using-Fleet/03-REST-API.md @@ -3148,6 +3148,105 @@ o ] ``` +### Run live queries and retrieve 90 seconds of results using only REST + +The same endpoint used to run queries and gather the results over websockets can also be used to collect the results +over a fixed period of 90 seconds and get the results in the REST response. + +WARNING: this endpoint collects responses in memory and the elapsed time is capped at 90 seconds, regardless of whether +all results possible have been gathered. This can cause an autoscaling event, depending on the configuration, or the +fleet instance crashing. + +`GET /api/v1/fleet/queries/run` + +#### Parameters + + +| Name | Type | In | Description | +| --------- | ------ | ---- | --------------------------------------------- | +| query_ids | array | body | **Required**. The IDs of the queries to run as live queries. | +| host_ids | array | body | **Required**. The IDs of the hosts to run the live queries against. | + +#### Example + +`GET /api/v1/fleet/queries/run` + +##### Request body + +```json +{ + "query_ids": [ 1, 2 ], + "host_ids": [ 1, 4, 34, 27 ] +} +``` + +##### Default response + +```json +{ + "summary": { + "targeted_host_count": 4, + "responded_host_count": 2 + }, + "live_query_results": [ + { + "query_id": 2, + "results": [ + { + "host_id": 1, + "rows": [ + { + "build_distro": "10.12", + "build_platform": "darwin", + "config_hash": "7bb99fa2c8a998c9459ec71da3a84d66c592d6d3", + "config_valid": "1", + "extensions": "active", + "instance_id": "9a2ec7bf-4946-46ea-93bf-455e0bcbd068", + "pid": "23413", + "platform_mask": "21", + "start_time": "1635194306", + "uuid": "4C182AC7-75F7-5AF4-A74B-1E165ED35742", + "version": "4.9.0", + "watcher": "23412" + } + ], + "error": null + }, + { + "host_id": 2, + "rows": [], + "error": "no such table: os_version" + } + ] + } + ] +} +``` + +##### Default response + +`Status: 200` + +```json +{ + "campaign": { + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z", + "Metrics": { + "TotalHosts": 1, + "OnlineHosts": 0, + "OfflineHosts": 1, + "MissingInActionHosts": 0, + "NewHosts": 1 + }, + "id": 1, + "query_id": 3, + "status": 0, + "user_id": 1 + } +} +``` + --- ## Schedule diff --git a/server/authz/authz.go b/server/authz/authz.go index 91e8895e4..7e82c8974 100644 --- a/server/authz/authz.go +++ b/server/authz/authz.go @@ -67,7 +67,7 @@ func (a *Authorizer) SkipAuthorization(ctx context.Context) { // Mark the authorization context as checked (otherwise middleware will // error). if authctx, ok := authz_ctx.FromContext(ctx); ok { - authctx.Checked = true + authctx.SetChecked() } } @@ -81,7 +81,7 @@ func (a *Authorizer) Authorize(ctx context.Context, object, action interface{}) // Mark the authorization context as checked (otherwise middleware will // error). if authctx, ok := authz_ctx.FromContext(ctx); ok { - authctx.Checked = true + authctx.SetChecked() } subject := UserFromContext(ctx) diff --git a/server/contexts/authz/authz.go b/server/contexts/authz/authz.go index 5d089f086..83f04d20b 100644 --- a/server/contexts/authz/authz.go +++ b/server/contexts/authz/authz.go @@ -2,7 +2,10 @@ // request has had an authorization check performed before returning results. package authz -import "context" +import ( + "context" + "sync" +) type key int @@ -22,6 +25,19 @@ func FromContext(ctx context.Context) (*AuthorizationContext, bool) { // AuthorizationContext contains the context information used for the // authorization check. type AuthorizationContext struct { - // Checked indicates whether a call was made to check authorization for the request. - Checked bool + l sync.Mutex + // checked indicates whether a call was made to check authorization for the request. + checked bool +} + +func (a *AuthorizationContext) Checked() bool { + a.l.Lock() + defer a.l.Unlock() + return a.checked +} + +func (a *AuthorizationContext) SetChecked() { + a.l.Lock() + defer a.l.Unlock() + a.checked = true } diff --git a/server/contexts/logging/logging.go b/server/contexts/logging/logging.go index 7f4d2a33e..a47665070 100644 --- a/server/contexts/logging/logging.go +++ b/server/contexts/logging/logging.go @@ -2,6 +2,7 @@ package logging import ( "context" + "sync" "time" "github.com/fleetdm/fleet/v4/server/contexts/viewer" @@ -29,7 +30,7 @@ func FromContext(ctx context.Context) (*LoggingContext, bool) { // WithStartTime returns a context with logging.StartTime marked as the current time func WithStartTime(ctx context.Context) context.Context { if logCtx, ok := FromContext(ctx); ok { - logCtx.StartTime = time.Now() + logCtx.SetStartTime() } return ctx } @@ -37,7 +38,7 @@ func WithStartTime(ctx context.Context) context.Context { // WithErr returns a context with logging.Err set as the error provided func WithErr(ctx context.Context, err ...error) context.Context { if logCtx, ok := FromContext(ctx); ok { - logCtx.Errs = append(logCtx.Errs, err...) + logCtx.SetErrs(err...) } return ctx } @@ -45,7 +46,7 @@ func WithErr(ctx context.Context, err ...error) context.Context { // WithNoUser returns a context with logging.SkipUser set to true so user won't be logged func WithNoUser(ctx context.Context) context.Context { if logCtx, ok := FromContext(ctx); ok { - logCtx.SkipUser = true + logCtx.SetSkipUser() } return ctx } @@ -53,20 +54,22 @@ func WithNoUser(ctx context.Context) context.Context { // WithExtras returns a context with logging.Extras set as the values provided func WithExtras(ctx context.Context, extras ...interface{}) context.Context { if logCtx, ok := FromContext(ctx); ok { - logCtx.Extras = append(logCtx.Extras, extras...) + logCtx.SetExtras(extras...) } return ctx } func WithLevel(ctx context.Context, level func(kitlog.Logger) kitlog.Logger) context.Context { if logCtx, ok := FromContext(ctx); ok { - logCtx.ForceLevel = level + logCtx.SetForceLevel(level) } return ctx } // LoggingContext contains the context information for logging the current request type LoggingContext struct { + l sync.Mutex + StartTime time.Time Errs []error Extras []interface{} @@ -74,8 +77,37 @@ type LoggingContext struct { ForceLevel func(kitlog.Logger) kitlog.Logger } +func (l *LoggingContext) SetForceLevel(level func(kitlog.Logger) kitlog.Logger) { + l.l.Lock() + defer l.l.Unlock() + l.ForceLevel = level +} +func (l *LoggingContext) SetExtras(extras ...interface{}) { + l.l.Lock() + defer l.l.Unlock() + l.Extras = append(l.Extras, extras...) +} +func (l *LoggingContext) SetSkipUser() { + l.l.Lock() + defer l.l.Unlock() + l.SkipUser = true +} +func (l *LoggingContext) SetStartTime() { + l.l.Lock() + defer l.l.Unlock() + l.StartTime = time.Now() +} +func (l *LoggingContext) SetErrs(err ...error) { + l.l.Lock() + defer l.l.Unlock() + l.Errs = append(l.Errs, err...) +} + // Log logs the data within the context func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) { + l.l.Lock() + defer l.l.Unlock() + if l.ForceLevel != nil { logger = l.ForceLevel(logger) } else if l.Errs != nil || len(l.Extras) > 0 { diff --git a/server/datastore/mysql/campaigns.go b/server/datastore/mysql/campaigns.go index 8e43fe5a5..dcd46c196 100644 --- a/server/datastore/mysql/campaigns.go +++ b/server/datastore/mysql/campaigns.go @@ -64,6 +64,15 @@ func (d *Datastore) SaveDistributedQueryCampaign(ctx context.Context, camp *flee return nil } +func (d *Datastore) DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error) { + var campaigns []*fleet.DistributedQueryCampaign + err := sqlx.SelectContext(ctx, d.reader, &campaigns, `SELECT * FROM distributed_query_campaigns WHERE query_id=?`, queryID) + if err != nil { + return nil, errors.Wrap(err, "getting campaigns for query") + } + return campaigns, nil +} + func (d *Datastore) DistributedQueryCampaignTargetIDs(ctx context.Context, id uint) (*fleet.HostTargets, error) { sqlStatement := ` SELECT * FROM distributed_query_campaign_targets WHERE distributed_query_campaign_id = ? diff --git a/server/fleet/campaigns.go b/server/fleet/campaigns.go index 421449bfe..da22231ca 100644 --- a/server/fleet/campaigns.go +++ b/server/fleet/campaigns.go @@ -31,17 +31,6 @@ type DistributedQueryCampaignTarget struct { TargetID uint `db:"target_id"` } -// DistributedQueryExecutionStatus is the status of a distributed query -// execution on a single host. -type DistributedQueryExecutionStatus int - -const ( - ExecutionWaiting DistributedQueryExecutionStatus = iota - ExecutionRequested - ExecutionSucceeded - ExecutionFailed -) - // DistributedQueryResult is the result returned from the execution of a // distributed query on a single host. type DistributedQueryResult struct { diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 43a8fdb8a..04a0642a9 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -86,6 +86,8 @@ type Datastore interface { // easier to test. The return values indicate how many campaigns were expired and any error. CleanupDistributedQueryCampaigns(ctx context.Context, now time.Time) (expired uint, err error) + DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*DistributedQueryCampaign, error) + /////////////////////////////////////////////////////////////////////////////// // PackStore is the datastore interface for managing query packs. diff --git a/server/fleet/service.go b/server/fleet/service.go index ebc3600e7..9698a6401 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -221,6 +221,9 @@ type Service interface { // go-kit RPC style. StreamCampaignResults(ctx context.Context, conn *websocket.Conn, campaignID uint) + GetCampaignReader(ctx context.Context, campaign *DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) + CompleteCampaign(ctx context.Context, campaign *DistributedQueryCampaign) error + /////////////////////////////////////////////////////////////////////////////// // AgentOptionsService diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 8ab166fb9..4e5fdfece 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -75,6 +75,8 @@ type NewDistributedQueryCampaignTargetFunc func(ctx context.Context, target *fle type CleanupDistributedQueryCampaignsFunc func(ctx context.Context, now time.Time) (expired uint, err error) +type DistributedQueryCampaignsForQueryFunc func(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error) + type ApplyPackSpecsFunc func(ctx context.Context, specs []*fleet.PackSpec) error type GetPackSpecsFunc func(ctx context.Context) ([]*fleet.PackSpec, error) @@ -406,6 +408,9 @@ type DataStore struct { CleanupDistributedQueryCampaignsFunc CleanupDistributedQueryCampaignsFunc CleanupDistributedQueryCampaignsFuncInvoked bool + DistributedQueryCampaignsForQueryFunc DistributedQueryCampaignsForQueryFunc + DistributedQueryCampaignsForQueryFuncInvoked bool + ApplyPackSpecsFunc ApplyPackSpecsFunc ApplyPackSpecsFuncInvoked bool @@ -918,6 +923,11 @@ func (s *DataStore) CleanupDistributedQueryCampaigns(ctx context.Context, now ti return s.CleanupDistributedQueryCampaignsFunc(ctx, now) } +func (s *DataStore) DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error) { + s.DistributedQueryCampaignsForQueryFuncInvoked = true + return s.DistributedQueryCampaignsForQueryFunc(ctx, queryID) +} + func (s *DataStore) ApplyPackSpecs(ctx context.Context, specs []*fleet.PackSpec) error { s.ApplyPackSpecsFuncInvoked = true return s.ApplyPackSpecsFunc(ctx, specs) diff --git a/server/service/handler.go b/server/service/handler.go index 467ca3f15..693bd1729 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -570,15 +570,15 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { r.Handle("/api/v1/fleet/users", h.ListUsers).Methods("GET").Name("list_users") r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite") r.Handle("/api/v1/fleet/users/admin", h.CreateUser).Methods("POST").Name("create_user") - r.Handle("/api/v1/fleet/users/{id}", h.GetUser).Methods("GET").Name("get_user") - r.Handle("/api/v1/fleet/users/{id}", h.ModifyUser).Methods("PATCH").Name("modify_user") - r.Handle("/api/v1/fleet/users/{id}", h.DeleteUser).Methods("DELETE").Name("delete_user") - r.Handle("/api/v1/fleet/users/{id}/require_password_reset", h.RequirePasswordReset).Methods("POST").Name("require_password_reset") - r.Handle("/api/v1/fleet/users/{id}/sessions", h.GetSessionsForUserInfo).Methods("GET").Name("get_session_for_user") - r.Handle("/api/v1/fleet/users/{id}/sessions", h.DeleteSessionsForUser).Methods("DELETE").Name("delete_session_for_user") + r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.GetUser).Methods("GET").Name("get_user") + r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.ModifyUser).Methods("PATCH").Name("modify_user") + r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.DeleteUser).Methods("DELETE").Name("delete_user") + r.Handle("/api/v1/fleet/users/{id:[0-9]+}/require_password_reset", h.RequirePasswordReset).Methods("POST").Name("require_password_reset") + r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.GetSessionsForUserInfo).Methods("GET").Name("get_session_for_user") + r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.DeleteSessionsForUser).Methods("DELETE").Name("delete_session_for_user") - r.Handle("/api/v1/fleet/sessions/{id}", h.GetSessionInfo).Methods("GET").Name("get_session_info") - r.Handle("/api/v1/fleet/sessions/{id}", h.DeleteSession).Methods("DELETE").Name("delete_session") + r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.GetSessionInfo).Methods("GET").Name("get_session_info") + r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.DeleteSession).Methods("DELETE").Name("delete_session") r.Handle("/api/v1/fleet/config/certificate", h.GetCertificate).Methods("GET").Name("get_certificate") r.Handle("/api/v1/fleet/config", h.GetAppConfig).Methods("GET").Name("get_app_config") @@ -587,17 +587,17 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { r.Handle("/api/v1/fleet/spec/enroll_secret", h.GetEnrollSecretSpec).Methods("GET").Name("get_enroll_secret_spec") r.Handle("/api/v1/fleet/invites", h.CreateInvite).Methods("POST").Name("create_invite") r.Handle("/api/v1/fleet/invites", h.ListInvites).Methods("GET").Name("list_invites") - r.Handle("/api/v1/fleet/invites/{id}", h.DeleteInvite).Methods("DELETE").Name("delete_invite") + r.Handle("/api/v1/fleet/invites/{id:[0-9]+}", h.DeleteInvite).Methods("DELETE").Name("delete_invite") r.Handle("/api/v1/fleet/invites/{token}", h.VerifyInvite).Methods("GET").Name("verify_invite") r.Handle("/api/v1/fleet/email/change/{token}", h.ChangeEmail).Methods("GET").Name("change_email") - r.Handle("/api/v1/fleet/queries/{id}", h.GetQuery).Methods("GET").Name("get_query") + r.Handle("/api/v1/fleet/queries/{id:[0-9]+}", h.GetQuery).Methods("GET").Name("get_query") r.Handle("/api/v1/fleet/queries", h.ListQueries).Methods("GET").Name("list_queries") r.Handle("/api/v1/fleet/queries", h.CreateQuery).Methods("POST").Name("create_query") - r.Handle("/api/v1/fleet/queries/{id}", h.ModifyQuery).Methods("PATCH").Name("modify_query") + r.Handle("/api/v1/fleet/queries/{id:[0-9]+}", h.ModifyQuery).Methods("PATCH").Name("modify_query") r.Handle("/api/v1/fleet/queries/{name}", h.DeleteQuery).Methods("DELETE").Name("delete_query") - r.Handle("/api/v1/fleet/queries/id/{id}", h.DeleteQueryByID).Methods("DELETE").Name("delete_query_by_id") + r.Handle("/api/v1/fleet/queries/id/{id:[0-9]+}", h.DeleteQueryByID).Methods("DELETE").Name("delete_query_by_id") r.Handle("/api/v1/fleet/queries/delete", h.DeleteQueries).Methods("POST").Name("delete_queries") r.Handle("/api/v1/fleet/spec/queries", h.ApplyQuerySpecs).Methods("POST").Name("apply_query_specs") r.Handle("/api/v1/fleet/spec/queries", h.GetQuerySpecs).Methods("GET").Name("get_query_specs") @@ -606,41 +606,41 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { r.Handle("/api/v1/fleet/queries/run_by_names", h.CreateDistributedQueryCampaignByNames).Methods("POST").Name("create_distributed_query_campaign_by_names") r.Handle("/api/v1/fleet/packs", h.CreatePack).Methods("POST").Name("create_pack") - r.Handle("/api/v1/fleet/packs/{id}", h.ModifyPack).Methods("PATCH").Name("modify_pack") + r.Handle("/api/v1/fleet/packs/{id:[0-9]+}", h.ModifyPack).Methods("PATCH").Name("modify_pack") r.Handle("/api/v1/fleet/packs", h.ListPacks).Methods("GET").Name("list_packs") r.Handle("/api/v1/fleet/packs/{name}", h.DeletePack).Methods("DELETE").Name("delete_pack") - r.Handle("/api/v1/fleet/packs/id/{id}", h.DeletePackByID).Methods("DELETE").Name("delete_pack_by_id") - r.Handle("/api/v1/fleet/packs/{id}/scheduled", h.GetScheduledQueriesInPack).Methods("GET").Name("get_scheduled_queries_in_pack") + r.Handle("/api/v1/fleet/packs/id/{id:[0-9]+}", h.DeletePackByID).Methods("DELETE").Name("delete_pack_by_id") + r.Handle("/api/v1/fleet/packs/{id:[0-9]+}/scheduled", h.GetScheduledQueriesInPack).Methods("GET").Name("get_scheduled_queries_in_pack") r.Handle("/api/v1/fleet/schedule", h.ScheduleQuery).Methods("POST").Name("schedule_query") - r.Handle("/api/v1/fleet/schedule/{id}", h.GetScheduledQuery).Methods("GET").Name("get_scheduled_query") - r.Handle("/api/v1/fleet/schedule/{id}", h.ModifyScheduledQuery).Methods("PATCH").Name("modify_scheduled_query") - r.Handle("/api/v1/fleet/schedule/{id}", h.DeleteScheduledQuery).Methods("DELETE").Name("delete_scheduled_query") + r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.GetScheduledQuery).Methods("GET").Name("get_scheduled_query") + r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.ModifyScheduledQuery).Methods("PATCH").Name("modify_scheduled_query") + r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.DeleteScheduledQuery).Methods("DELETE").Name("delete_scheduled_query") r.Handle("/api/v1/fleet/spec/packs", h.ApplyPackSpecs).Methods("POST").Name("apply_pack_specs") r.Handle("/api/v1/fleet/spec/packs", h.GetPackSpecs).Methods("GET").Name("get_pack_specs") r.Handle("/api/v1/fleet/spec/packs/{name}", h.GetPackSpec).Methods("GET").Name("get_pack_spec") r.Handle("/api/v1/fleet/global/schedule", h.GetGlobalSchedule).Methods("GET").Name("set_global_schedule") r.Handle("/api/v1/fleet/global/schedule", h.GlobalScheduleQuery).Methods("POST").Name("add_to_global_schedule") - r.Handle("/api/v1/fleet/global/schedule/{id}", h.ModifyGlobalSchedule).Methods("PATCH").Name("modify_global_schedule") - r.Handle("/api/v1/fleet/global/schedule/{id}", h.DeleteGlobalSchedule).Methods("DELETE").Name("delete_global_schedule") + r.Handle("/api/v1/fleet/global/schedule/{id:[0-9]+}", h.ModifyGlobalSchedule).Methods("PATCH").Name("modify_global_schedule") + r.Handle("/api/v1/fleet/global/schedule/{id:[0-9]+}", h.DeleteGlobalSchedule).Methods("DELETE").Name("delete_global_schedule") r.Handle("/api/v1/fleet/labels", h.CreateLabel).Methods("POST").Name("create_label") - r.Handle("/api/v1/fleet/labels/{id}", h.ModifyLabel).Methods("PATCH").Name("modify_label") - r.Handle("/api/v1/fleet/labels/{id}", h.GetLabel).Methods("GET").Name("get_label") + r.Handle("/api/v1/fleet/labels/{id:[0-9]+}", h.ModifyLabel).Methods("PATCH").Name("modify_label") + r.Handle("/api/v1/fleet/labels/{id:[0-9]+}", h.GetLabel).Methods("GET").Name("get_label") r.Handle("/api/v1/fleet/labels", h.ListLabels).Methods("GET").Name("list_labels") - r.Handle("/api/v1/fleet/labels/{id}/hosts", h.ListHostsInLabel).Methods("GET").Name("list_hosts_in_label") + r.Handle("/api/v1/fleet/labels/{id:[0-9]+}/hosts", h.ListHostsInLabel).Methods("GET").Name("list_hosts_in_label") r.Handle("/api/v1/fleet/labels/{name}", h.DeleteLabel).Methods("DELETE").Name("delete_label") - r.Handle("/api/v1/fleet/labels/id/{id}", h.DeleteLabelByID).Methods("DELETE").Name("delete_label_by_id") + r.Handle("/api/v1/fleet/labels/id/{id:[0-9]+}", h.DeleteLabelByID).Methods("DELETE").Name("delete_label_by_id") r.Handle("/api/v1/fleet/spec/labels", h.ApplyLabelSpecs).Methods("POST").Name("apply_label_specs") r.Handle("/api/v1/fleet/spec/labels", h.GetLabelSpecs).Methods("GET").Name("get_label_specs") r.Handle("/api/v1/fleet/spec/labels/{name}", h.GetLabelSpec).Methods("GET").Name("get_label_spec") r.Handle("/api/v1/fleet/host_summary", h.GetHostSummary).Methods("GET").Name("get_host_summary") r.Handle("/api/v1/fleet/hosts/identifier/{identifier}", h.HostByIdentifier).Methods("GET").Name("host_by_identifier") - r.Handle("/api/v1/fleet/hosts/{id}", h.DeleteHost).Methods("DELETE").Name("delete_host") + r.Handle("/api/v1/fleet/hosts/{id:[0-9]+}", h.DeleteHost).Methods("DELETE").Name("delete_host") r.Handle("/api/v1/fleet/hosts/transfer", h.AddHostsToTeam).Methods("POST").Name("add_hosts_to_team") r.Handle("/api/v1/fleet/hosts/transfer/filter", h.AddHostsToTeamByFilter).Methods("POST").Name("add_hosts_to_team_by_filter") - r.Handle("/api/v1/fleet/hosts/{id}/refetch", h.RefetchHost).Methods("POST").Name("refetch_host") + r.Handle("/api/v1/fleet/hosts/{id:[0-9]+}/refetch", h.RefetchHost).Methods("POST").Name("refetch_host") r.Handle("/api/v1/fleet/targets", h.SearchTargets).Methods("POST").Name("search_targets") @@ -650,18 +650,18 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { r.Handle("/api/v1/fleet/status/live_query", h.StatusLiveQuery).Methods("GET").Name("status_live_query") r.Handle("/api/v1/fleet/carves", h.ListCarves).Methods("GET").Name("list_carves") - r.Handle("/api/v1/fleet/carves/{id}", h.GetCarve).Methods("GET").Name("get_carve") - r.Handle("/api/v1/fleet/carves/{id}/block/{block_id}", h.GetCarveBlock).Methods("GET").Name("get_carve_block") + r.Handle("/api/v1/fleet/carves/{id:[0-9]+}", h.GetCarve).Methods("GET").Name("get_carve") + r.Handle("/api/v1/fleet/carves/{id:[0-9]+}/block/{block_id}", h.GetCarveBlock).Methods("GET").Name("get_carve_block") r.Handle("/api/v1/fleet/teams", h.CreateTeam).Methods("POST").Name("create_team") r.Handle("/api/v1/fleet/teams", h.ListTeams).Methods("GET").Name("list_teams") - r.Handle("/api/v1/fleet/teams/{id}", h.ModifyTeam).Methods("PATCH").Name("modify_team") - r.Handle("/api/v1/fleet/teams/{id}", h.DeleteTeam).Methods("DELETE").Name("delete_team") - r.Handle("/api/v1/fleet/teams/{id}/agent_options", h.ModifyTeamAgentOptions).Methods("POST").Name("modify_team_agent_options") - r.Handle("/api/v1/fleet/teams/{id}/users", h.ListTeamUsers).Methods("GET").Name("team_users") - r.Handle("/api/v1/fleet/teams/{id}/users", h.AddTeamUsers).Methods("PATCH").Name("add_team_users") - r.Handle("/api/v1/fleet/teams/{id}/users", h.DeleteTeamUsers).Methods("DELETE").Name("delete_team_users") - r.Handle("/api/v1/fleet/teams/{id}/secrets", h.TeamEnrollSecrets).Methods("GET").Name("get_team_enroll_secrets") + r.Handle("/api/v1/fleet/teams/{id:[0-9]+}", h.ModifyTeam).Methods("PATCH").Name("modify_team") + r.Handle("/api/v1/fleet/teams/{id:[0-9]+}", h.DeleteTeam).Methods("DELETE").Name("delete_team") + r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/agent_options", h.ModifyTeamAgentOptions).Methods("POST").Name("modify_team_agent_options") + r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/users", h.ListTeamUsers).Methods("GET").Name("team_users") + r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/users", h.AddTeamUsers).Methods("PATCH").Name("add_team_users") + r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/users", h.DeleteTeamUsers).Methods("DELETE").Name("delete_team_users") + r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/secrets", h.TeamEnrollSecrets).Methods("GET").Name("get_team_enroll_secrets") r.Handle("/api/v1/osquery/enroll", h.EnrollAgent).Methods("POST").Name("enroll_agent") r.Handle("/api/v1/osquery/config", h.GetClientConfig).Methods("POST").Name("get_client_config") @@ -719,6 +719,8 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht e.POST("/api/v1/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{}) e.GET("/api/v1/fleet/hosts/{id:[0-9]+}", getHostEndpoint, getHostRequest{}) e.GET("/api/v1/fleet/hosts/count", countHostsEndpoint, countHostsRequest{}) + + e.GET("/api/v1/fleet/queries/run", runLiveQueryEndpoint, runLiveQueryRequest{}) } // TODO: this duplicates the one in makeKitHandler diff --git a/server/service/integration_live_queries_test.go b/server/service/integration_live_queries_test.go new file mode 100644 index 000000000..ee407cac9 --- /dev/null +++ b/server/service/integration_live_queries_test.go @@ -0,0 +1,398 @@ +package service + +import ( + "context" + "fmt" + "net/http" + "os" + "sort" + "sync" + "testing" + "time" + + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/live_query" + "github.com/fleetdm/fleet/v4/server/pubsub" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +func TestIntegrationLiveQueriesTestSuite(t *testing.T) { + testingSuite := new(liveQueriesTestSuite) + testingSuite.s = &testingSuite.Suite + suite.Run(t, testingSuite) +} + +type liveQueriesTestSuite struct { + withServer + suite.Suite + + lq *live_query.MockLiveQuery + hosts []*fleet.Host +} + +func (s *liveQueriesTestSuite) SetupSuite() { + require.NoError(s.T(), os.Setenv("FLEET_LIVE_QUERY_REST_PERIOD", "5s")) + + s.withDS.SetupSuite("liveQueriesTestSuite") + + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + s.lq = lq + + users, server := RunServerForTestsWithDS(s.T(), s.ds, TestServerOpts{Lq: lq, Rs: rs}) + s.server = server + s.users = users + s.token = getTestAdminToken(s.T(), s.server) + + t := s.T() + for i := 0; i < 3; i++ { + host, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now().Add(-time.Duration(i) * time.Minute), + OsqueryHostID: fmt.Sprintf("%s%d", t.Name(), i), + NodeKey: fmt.Sprintf("%s%d", t.Name(), i), + UUID: fmt.Sprintf("%s%d", t.Name(), i), + Hostname: fmt.Sprintf("%sfoo.local%d", t.Name(), i), + }) + require.NoError(s.T(), err) + s.hosts = append(s.hosts, host) + } +} + +func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { + t := s.T() + + host := s.hosts[0] + + q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"}) + require.NoError(t, err) + + s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil) + s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil) + s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil) + s.lq.On("StopQuery", mock.Anything).Return(nil) + + liveQueryRequest := runLiveQueryRequest{ + QueryIDs: []uint{q1.ID}, + HostIDs: []uint{host.ID}, + } + liveQueryResp := runLiveQueryResponse{} + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) + }() + + // Give the above call a couple of seconds to create the campaign + time.Sleep(2 * time.Second) + + cid := getCIDForQ(s, q1) + + distributedReq := submitDistributedQueryResultsRequest{ + NodeKey: host.NodeKey, + Results: map[string][]map[string]string{ + hostDistributedQueryPrefix + cid: {{"col1": "a", "col2": "b"}}, + }, + Statuses: map[string]fleet.OsqueryStatus{ + hostDistributedQueryPrefix + cid: 0, + }, + Messages: map[string]string{ + hostDistributedQueryPrefix + cid: "some msg", + }, + } + distributedResp := submitDistributedQueryResultsResponse{} + s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp) + + wg.Wait() + + require.Len(t, liveQueryResp.Results, 1) + assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount) + assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID) + require.Len(t, liveQueryResp.Results[0].Results[0].Rows, 1) + assert.Equal(t, "a", liveQueryResp.Results[0].Results[0].Rows[0]["col1"]) + assert.Equal(t, "b", liveQueryResp.Results[0].Results[0].Rows[0]["col2"]) +} + +func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() { + t := s.T() + + host := s.hosts[0] + + q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"}) + require.NoError(t, err) + + q2, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 2 from osquery;", Description: "desc2", Name: t.Name() + "query2"}) + require.NoError(t, err) + + s.lq.On("QueriesForHost", host.ID).Return(map[string]string{ + fmt.Sprint(q1.ID): "select 1 from osquery;", + fmt.Sprint(q2.ID): "select 2 from osquery;", + }, nil) + s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil) + s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil) + s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{host.ID}).Return(nil) + s.lq.On("StopQuery", mock.Anything).Return(nil) + + liveQueryRequest := runLiveQueryRequest{ + QueryIDs: []uint{q1.ID, q2.ID}, + HostIDs: []uint{host.ID}, + } + liveQueryResp := runLiveQueryResponse{} + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) + }() + + // Give the above call a couple of seconds to create the campaign + time.Sleep(2 * time.Second) + + cid1 := getCIDForQ(s, q1) + cid2 := getCIDForQ(s, q2) + + distributedReq := submitDistributedQueryResultsRequest{ + NodeKey: host.NodeKey, + Results: map[string][]map[string]string{ + hostDistributedQueryPrefix + cid1: {{"col1": "a", "col2": "b"}}, + hostDistributedQueryPrefix + cid2: {{"col3": "c", "col4": "d"}, {"col3": "e", "col4": "f"}}, + }, + Statuses: map[string]fleet.OsqueryStatus{ + hostDistributedQueryPrefix + cid1: 0, + hostDistributedQueryPrefix + cid2: 0, + }, + Messages: map[string]string{ + hostDistributedQueryPrefix + cid1: "some msg", + hostDistributedQueryPrefix + cid2: "some other msg", + }, + } + distributedResp := submitDistributedQueryResultsResponse{} + s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp) + + wg.Wait() + + require.Len(t, liveQueryResp.Results, 2) + assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount) + + sort.Slice(liveQueryResp.Results, func(i, j int) bool { + return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID + }) + + require.True(t, q1.ID < q2.ID) + + assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID) + require.Len(t, liveQueryResp.Results[0].Results, 1) + q1Results := liveQueryResp.Results[0].Results[0] + require.Len(t, q1Results.Rows, 1) + assert.Equal(t, "a", q1Results.Rows[0]["col1"]) + assert.Equal(t, "b", q1Results.Rows[0]["col2"]) + + assert.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID) + require.Len(t, liveQueryResp.Results[1].Results, 1) + q2Results := liveQueryResp.Results[1].Results[0] + require.Len(t, q2Results.Rows, 2) + assert.Equal(t, "c", q2Results.Rows[0]["col3"]) + assert.Equal(t, "d", q2Results.Rows[0]["col4"]) + assert.Equal(t, "e", q2Results.Rows[1]["col3"]) + assert.Equal(t, "f", q2Results.Rows[1]["col4"]) +} + +func getCIDForQ(s *liveQueriesTestSuite, q1 *fleet.Query) string { + t := s.T() + campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID) + require.NoError(t, err) + require.Len(t, campaigns, 1) + cid1 := fmt.Sprint(campaigns[0].ID) + return cid1 +} + +func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() { + t := s.T() + + h1 := s.hosts[0] + h2 := s.hosts[1] + + q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"}) + require.NoError(t, err) + + q2, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 2 from osquery;", Description: "desc2", Name: t.Name() + "query2"}) + require.NoError(t, err) + + s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{ + fmt.Sprint(q1.ID): "select 1 from osquery;", + fmt.Sprint(q2.ID): "select 2 from osquery;", + }, nil) + s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{ + fmt.Sprint(q1.ID): "select 1 from osquery;", + fmt.Sprint(q2.ID): "select 2 from osquery;", + }, nil) + s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil) + s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil) + s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{h1.ID, h2.ID}).Return(nil) + s.lq.On("StopQuery", mock.Anything).Return(nil) + + liveQueryRequest := runLiveQueryRequest{ + QueryIDs: []uint{q1.ID, q2.ID}, + HostIDs: []uint{h1.ID, h2.ID}, + } + liveQueryResp := runLiveQueryResponse{} + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) + }() + + // Give the above call a couple of seconds to create the campaign + time.Sleep(2 * time.Second) + cid1 := getCIDForQ(s, q1) + cid2 := getCIDForQ(s, q2) + for i, h := range []*fleet.Host{h1, h2} { + distributedReq := submitDistributedQueryResultsRequest{ + NodeKey: h.NodeKey, + Results: map[string][]map[string]string{ + hostDistributedQueryPrefix + cid1: {{"col1": fmt.Sprintf("a%d", i), "col2": fmt.Sprintf("b%d", i)}}, + hostDistributedQueryPrefix + cid2: {{"col3": fmt.Sprintf("c%d", i), "col4": fmt.Sprintf("d%d", i)}, {"col3": fmt.Sprintf("e%d", i), "col4": fmt.Sprintf("f%d", i)}}, + }, + Statuses: map[string]fleet.OsqueryStatus{ + hostDistributedQueryPrefix + cid1: 0, + hostDistributedQueryPrefix + cid2: 0, + }, + Messages: map[string]string{ + hostDistributedQueryPrefix + cid1: "some msg", + hostDistributedQueryPrefix + cid2: "some other msg", + }, + } + distributedResp := submitDistributedQueryResultsResponse{} + s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp) + } + + wg.Wait() + + require.Len(t, liveQueryResp.Results, 2) // 2 queries + assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount) + + sort.Slice(liveQueryResp.Results, func(i, j int) bool { + return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID + }) + + require.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID) + require.Len(t, liveQueryResp.Results[0].Results, 2) + for i, r := range liveQueryResp.Results[0].Results { + require.Len(t, r.Rows, 1) + assert.Equal(t, fmt.Sprintf("a%d", i), r.Rows[0]["col1"]) + assert.Equal(t, fmt.Sprintf("b%d", i), r.Rows[0]["col2"]) + } + + require.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID) + require.Len(t, liveQueryResp.Results[1].Results, 2) + for i, r := range liveQueryResp.Results[1].Results { + require.Len(t, r.Rows, 2) + assert.Equal(t, fmt.Sprintf("c%d", i), r.Rows[0]["col3"]) + assert.Equal(t, fmt.Sprintf("d%d", i), r.Rows[0]["col4"]) + assert.Equal(t, fmt.Sprintf("e%d", i), r.Rows[1]["col3"]) + assert.Equal(t, fmt.Sprintf("f%d", i), r.Rows[1]["col4"]) + } +} + +func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() { + t := s.T() + + liveQueryRequest := runLiveQueryRequest{ + QueryIDs: []uint{999}, + HostIDs: []uint{888}, + } + liveQueryResp := runLiveQueryResponse{} + + s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) + + require.Len(t, liveQueryResp.Results, 1) + assert.Equal(t, 0, liveQueryResp.Summary.RespondedHostCount) + require.NotNil(t, liveQueryResp.Results[0].Error) + assert.Equal(t, "selecting query: sql: no rows in result set", *liveQueryResp.Results[0].Error) +} + +func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() { + t := s.T() + + h1 := s.hosts[0] + h2 := s.hosts[1] + + q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"}) + require.NoError(t, err) + + s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil) + s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil) + s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil) + s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil) + s.lq.On("StopQuery", mock.Anything).Return(nil) + + liveQueryRequest := runLiveQueryRequest{ + QueryIDs: []uint{q1.ID}, + HostIDs: []uint{h1.ID, h2.ID}, + } + liveQueryResp := runLiveQueryResponse{} + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) + }() + + // Give the above call a couple of seconds to create the campaign + time.Sleep(2 * time.Second) + cid1 := getCIDForQ(s, q1) + distributedReq := submitDistributedQueryResultsRequest{ + NodeKey: h1.NodeKey, + Results: map[string][]map[string]string{ + hostDistributedQueryPrefix + cid1: {{"col1": "a", "col2": "b"}}, + }, + Statuses: map[string]fleet.OsqueryStatus{ + hostDistributedQueryPrefix + cid1: 0, + }, + Messages: map[string]string{ + hostDistributedQueryPrefix + cid1: "some msg", + }, + } + distributedResp := submitDistributedQueryResultsResponse{} + s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp) + + distributedReq = submitDistributedQueryResultsRequest{ + NodeKey: h2.NodeKey, + Results: map[string][]map[string]string{ + hostDistributedQueryPrefix + cid1: {}, + }, + Statuses: map[string]fleet.OsqueryStatus{ + hostDistributedQueryPrefix + cid1: 123, + }, + Messages: map[string]string{ + hostDistributedQueryPrefix + cid1: "some error!", + }, + } + distributedResp = submitDistributedQueryResultsResponse{} + s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp) + + wg.Wait() + + require.Len(t, liveQueryResp.Results, 1) + assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount) + + result := liveQueryResp.Results[0] + require.Len(t, result.Results, 2) + require.Len(t, result.Results[0].Rows, 1) + assert.Equal(t, "a", result.Results[0].Rows[0]["col1"]) + assert.Equal(t, "b", result.Results[0].Rows[0]["col2"]) + require.Len(t, result.Results[1].Rows, 0) + require.NotNil(t, result.Results[1].Error) + assert.Equal(t, "some error!", *result.Results[1].Error) +} diff --git a/server/service/live_queries.go b/server/service/live_queries.go new file mode 100644 index 000000000..f4b5f9ffc --- /dev/null +++ b/server/service/live_queries.go @@ -0,0 +1,168 @@ +package service + +import ( + "context" + "fmt" + "os" + "strconv" + "sync" + "time" + + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/pkg/errors" +) + +type runLiveQueryRequest struct { + QueryIDs []uint `json:"query_ids"` + HostIDs []uint `json:"host_ids"` +} + +type summaryPayload struct { + TargetedHostCount int `json:"targeted_host_count"` + RespondedHostCount int `json:"responded_host_count"` +} + +type runLiveQueryResponse struct { + Summary summaryPayload `json:"summary"` + Err error `json:"error,omitempty"` + + Results []queryCampaignResult `json:"live_query_results"` +} + +func (r runLiveQueryResponse) error() error { return r.Err } + +type queryResult struct { + HostID uint `json:"host_id"` + Rows []map[string]string `json:"rows"` + Error *string `json:"error"` +} + +type queryCampaignResult struct { + QueryID uint `json:"query_id"` + Error *string `json:"error,omitempty"` + Results []queryResult `json:"results"` +} + +func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*runLiveQueryRequest) + wg := sync.WaitGroup{} + + resultsCh := make(chan queryCampaignResult) + + counterMutex := sync.Mutex{} + counter := make(map[uint]struct{}) + + period := os.Getenv("FLEET_LIVE_QUERY_REST_PERIOD") + if period == "" { + period = "90s" + } + duration, err := time.ParseDuration(period) + if err != nil { + duration = 90 * time.Second + } + + for _, queryID := range req.QueryIDs { + queryID := queryID + wg.Add(1) + go func() { + defer wg.Done() + campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: req.HostIDs}) + if err != nil { + resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())} + return + } + + readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign) + if err != nil { + resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())} + return + } + defer cancelFunc() + + defer func() { + err := svc.CompleteCampaign(ctx, campaign) + if err != nil { + resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())} + } + }() + + ticker := time.NewTicker(duration) + defer ticker.Stop() + + var results []queryResult + loop: + for { + select { + case res := <-readChan: + // Receive a result and push it over the websocket + switch res := res.(type) { + case fleet.DistributedQueryResult: + results = append(results, queryResult{HostID: res.Host.ID, Rows: res.Rows, Error: res.Error}) + counterMutex.Lock() + counter[res.Host.ID] = struct{}{} + counterMutex.Unlock() + } + case <-ticker.C: + break loop + } + } + resultsCh <- queryCampaignResult{QueryID: queryID, Results: results} + }() + } + + go func() { + wg.Wait() + close(resultsCh) + }() + + res := runLiveQueryResponse{ + Summary: summaryPayload{ + TargetedHostCount: len(req.HostIDs), + RespondedHostCount: 0, + }, + } + + for result := range resultsCh { + res.Results = append(res.Results, result) + } + + res.Summary.RespondedHostCount = len(counter) + + return res, nil +} + +func (svc *Service) GetCampaignReader(ctx context.Context, campaign *fleet.DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) { + // Open the channel from which we will receive incoming query results + // (probably from the redis pubsub implementation) + cancelCtx, cancelFunc := context.WithCancel(ctx) + + readChan, err := svc.resultStore.ReadChannel(cancelCtx, *campaign) + if err != nil { + cancelFunc() + return nil, nil, fmt.Errorf("cannot open read channel for campaign %d ", campaign.ID) + } + + // Setting status to running will cause the query to be returned to the + // targets when they check in for their queries + campaign.Status = fleet.QueryRunning + if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil { + cancelFunc() + return nil, nil, errors.Wrap(err, "error saving campaign state") + } + + return readChan, cancelFunc, nil +} + +func (svc *Service) CompleteCampaign(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error { + campaign.Status = fleet.QueryComplete + err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign) + if err != nil { + return errors.Wrap(err, "saving distributed campaign after complete") + } + err = svc.liveQueryStore.StopQuery(strconv.Itoa(int(campaign.ID))) + if err != nil { + return errors.Wrap(err, "stopping query after after complete") + } + return nil +} diff --git a/server/service/middleware/authzcheck/authzcheck.go b/server/service/middleware/authzcheck/authzcheck.go index d8e79000b..44db75b40 100644 --- a/server/service/middleware/authzcheck/authzcheck.go +++ b/server/service/middleware/authzcheck/authzcheck.go @@ -44,7 +44,7 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware { // If authorization was not checked, return a response that will // marshal to a generic error and log that the check was missed. - if !authzctx.Checked { + if !authzctx.Checked() { return nil, authz.CheckMissingWithResponse(response) } diff --git a/server/service/middleware/authzcheck/authzcheck_test.go b/server/service/middleware/authzcheck/authzcheck_test.go index f7aaec1b6..d95aafdc8 100644 --- a/server/service/middleware/authzcheck/authzcheck_test.go +++ b/server/service/middleware/authzcheck/authzcheck_test.go @@ -18,7 +18,7 @@ func TestAuthzCheck(t *testing.T) { check := func(ctx context.Context, req interface{}) (interface{}, error) { authCtx, ok := authz.FromContext(ctx) require.True(t, ok) - authCtx.Checked = true + authCtx.SetChecked() return struct{}{}, nil } check = checker.AuthzCheck()(check) diff --git a/server/service/service_campaigns.go b/server/service/service_campaigns.go index ec1cd90e1..15dd1be6f 100644 --- a/server/service/service_campaigns.go +++ b/server/service/service_campaigns.go @@ -222,31 +222,17 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co // Open the channel from which we will receive incoming query results // (probably from the redis pubsub implementation) - cancelCtx, cancelFunc := context.WithCancel(ctx) - defer cancelFunc() - - readChan, err := svc.resultStore.ReadChannel(cancelCtx, *campaign) + readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign) if err != nil { - conn.WriteJSONError(fmt.Sprintf("cannot open read channel for campaign %d ", campaignID)) - return - } - - // Setting status to running will cause the query to be returned to the - // targets when they check in for their queries - campaign.Status = fleet.QueryRunning - if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil { - conn.WriteJSONError("error saving campaign state") + conn.WriteJSONError("error getting campaign reader: " + err.Error()) return } + defer cancelFunc() // Setting the status to completed stops the query from being sent to // targets. If this fails, there is a background job that will clean up // this campaign. - defer func() { - campaign.Status = fleet.QueryComplete - _ = svc.ds.SaveDistributedQueryCampaign(ctx, campaign) - _ = svc.liveQueryStore.StopQuery(strconv.Itoa(int(campaign.ID))) - }() + defer svc.CompleteCampaign(ctx, campaign) status := campaignStatus{ Status: campaignStatusPending,