mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Issue 1798 rest api for live queries (#2675)
* wip * wip * Make tests not flaky * Add changes file * Make logging and authz thread safe * Move cancel defer below err check
This commit is contained in:
parent
58b26ae1ba
commit
79f2c7c391
1
changes/issue-1798-rest-api-for-live-queries
Normal file
1
changes/issue-1798-rest-api-for-live-queries
Normal file
@ -0,0 +1 @@
|
|||||||
|
* Add live query REST API
|
@ -3148,6 +3148,105 @@ o
|
|||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Run live queries and retrieve 90 seconds of results using only REST
|
||||||
|
|
||||||
|
The same endpoint used to run queries and gather the results over websockets can also be used to collect the results
|
||||||
|
over a fixed period of 90 seconds and get the results in the REST response.
|
||||||
|
|
||||||
|
WARNING: this endpoint collects responses in memory and the elapsed time is capped at 90 seconds, regardless of whether
|
||||||
|
all results possible have been gathered. This can cause an autoscaling event, depending on the configuration, or the
|
||||||
|
fleet instance crashing.
|
||||||
|
|
||||||
|
`GET /api/v1/fleet/queries/run`
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
|
||||||
|
|
||||||
|
| Name | Type | In | Description |
|
||||||
|
| --------- | ------ | ---- | --------------------------------------------- |
|
||||||
|
| query_ids | array | body | **Required**. The IDs of the queries to run as live queries. |
|
||||||
|
| host_ids | array | body | **Required**. The IDs of the hosts to run the live queries against. |
|
||||||
|
|
||||||
|
#### Example
|
||||||
|
|
||||||
|
`GET /api/v1/fleet/queries/run`
|
||||||
|
|
||||||
|
##### Request body
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"query_ids": [ 1, 2 ],
|
||||||
|
"host_ids": [ 1, 4, 34, 27 ]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Default response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"summary": {
|
||||||
|
"targeted_host_count": 4,
|
||||||
|
"responded_host_count": 2
|
||||||
|
},
|
||||||
|
"live_query_results": [
|
||||||
|
{
|
||||||
|
"query_id": 2,
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"host_id": 1,
|
||||||
|
"rows": [
|
||||||
|
{
|
||||||
|
"build_distro": "10.12",
|
||||||
|
"build_platform": "darwin",
|
||||||
|
"config_hash": "7bb99fa2c8a998c9459ec71da3a84d66c592d6d3",
|
||||||
|
"config_valid": "1",
|
||||||
|
"extensions": "active",
|
||||||
|
"instance_id": "9a2ec7bf-4946-46ea-93bf-455e0bcbd068",
|
||||||
|
"pid": "23413",
|
||||||
|
"platform_mask": "21",
|
||||||
|
"start_time": "1635194306",
|
||||||
|
"uuid": "4C182AC7-75F7-5AF4-A74B-1E165ED35742",
|
||||||
|
"version": "4.9.0",
|
||||||
|
"watcher": "23412"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"error": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"host_id": 2,
|
||||||
|
"rows": [],
|
||||||
|
"error": "no such table: os_version"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Default response
|
||||||
|
|
||||||
|
`Status: 200`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"campaign": {
|
||||||
|
"created_at": "0001-01-01T00:00:00Z",
|
||||||
|
"updated_at": "0001-01-01T00:00:00Z",
|
||||||
|
"Metrics": {
|
||||||
|
"TotalHosts": 1,
|
||||||
|
"OnlineHosts": 0,
|
||||||
|
"OfflineHosts": 1,
|
||||||
|
"MissingInActionHosts": 0,
|
||||||
|
"NewHosts": 1
|
||||||
|
},
|
||||||
|
"id": 1,
|
||||||
|
"query_id": 3,
|
||||||
|
"status": 0,
|
||||||
|
"user_id": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Schedule
|
## Schedule
|
||||||
|
@ -67,7 +67,7 @@ func (a *Authorizer) SkipAuthorization(ctx context.Context) {
|
|||||||
// Mark the authorization context as checked (otherwise middleware will
|
// Mark the authorization context as checked (otherwise middleware will
|
||||||
// error).
|
// error).
|
||||||
if authctx, ok := authz_ctx.FromContext(ctx); ok {
|
if authctx, ok := authz_ctx.FromContext(ctx); ok {
|
||||||
authctx.Checked = true
|
authctx.SetChecked()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ func (a *Authorizer) Authorize(ctx context.Context, object, action interface{})
|
|||||||
// Mark the authorization context as checked (otherwise middleware will
|
// Mark the authorization context as checked (otherwise middleware will
|
||||||
// error).
|
// error).
|
||||||
if authctx, ok := authz_ctx.FromContext(ctx); ok {
|
if authctx, ok := authz_ctx.FromContext(ctx); ok {
|
||||||
authctx.Checked = true
|
authctx.SetChecked()
|
||||||
}
|
}
|
||||||
|
|
||||||
subject := UserFromContext(ctx)
|
subject := UserFromContext(ctx)
|
||||||
|
@ -2,7 +2,10 @@
|
|||||||
// request has had an authorization check performed before returning results.
|
// request has had an authorization check performed before returning results.
|
||||||
package authz
|
package authz
|
||||||
|
|
||||||
import "context"
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
type key int
|
type key int
|
||||||
|
|
||||||
@ -22,6 +25,19 @@ func FromContext(ctx context.Context) (*AuthorizationContext, bool) {
|
|||||||
// AuthorizationContext contains the context information used for the
|
// AuthorizationContext contains the context information used for the
|
||||||
// authorization check.
|
// authorization check.
|
||||||
type AuthorizationContext struct {
|
type AuthorizationContext struct {
|
||||||
// Checked indicates whether a call was made to check authorization for the request.
|
l sync.Mutex
|
||||||
Checked bool
|
// checked indicates whether a call was made to check authorization for the request.
|
||||||
|
checked bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthorizationContext) Checked() bool {
|
||||||
|
a.l.Lock()
|
||||||
|
defer a.l.Unlock()
|
||||||
|
return a.checked
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthorizationContext) SetChecked() {
|
||||||
|
a.l.Lock()
|
||||||
|
defer a.l.Unlock()
|
||||||
|
a.checked = true
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package logging
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||||
@ -29,7 +30,7 @@ func FromContext(ctx context.Context) (*LoggingContext, bool) {
|
|||||||
// WithStartTime returns a context with logging.StartTime marked as the current time
|
// WithStartTime returns a context with logging.StartTime marked as the current time
|
||||||
func WithStartTime(ctx context.Context) context.Context {
|
func WithStartTime(ctx context.Context) context.Context {
|
||||||
if logCtx, ok := FromContext(ctx); ok {
|
if logCtx, ok := FromContext(ctx); ok {
|
||||||
logCtx.StartTime = time.Now()
|
logCtx.SetStartTime()
|
||||||
}
|
}
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
@ -37,7 +38,7 @@ func WithStartTime(ctx context.Context) context.Context {
|
|||||||
// WithErr returns a context with logging.Err set as the error provided
|
// WithErr returns a context with logging.Err set as the error provided
|
||||||
func WithErr(ctx context.Context, err ...error) context.Context {
|
func WithErr(ctx context.Context, err ...error) context.Context {
|
||||||
if logCtx, ok := FromContext(ctx); ok {
|
if logCtx, ok := FromContext(ctx); ok {
|
||||||
logCtx.Errs = append(logCtx.Errs, err...)
|
logCtx.SetErrs(err...)
|
||||||
}
|
}
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
@ -45,7 +46,7 @@ func WithErr(ctx context.Context, err ...error) context.Context {
|
|||||||
// WithNoUser returns a context with logging.SkipUser set to true so user won't be logged
|
// WithNoUser returns a context with logging.SkipUser set to true so user won't be logged
|
||||||
func WithNoUser(ctx context.Context) context.Context {
|
func WithNoUser(ctx context.Context) context.Context {
|
||||||
if logCtx, ok := FromContext(ctx); ok {
|
if logCtx, ok := FromContext(ctx); ok {
|
||||||
logCtx.SkipUser = true
|
logCtx.SetSkipUser()
|
||||||
}
|
}
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
@ -53,20 +54,22 @@ func WithNoUser(ctx context.Context) context.Context {
|
|||||||
// WithExtras returns a context with logging.Extras set as the values provided
|
// WithExtras returns a context with logging.Extras set as the values provided
|
||||||
func WithExtras(ctx context.Context, extras ...interface{}) context.Context {
|
func WithExtras(ctx context.Context, extras ...interface{}) context.Context {
|
||||||
if logCtx, ok := FromContext(ctx); ok {
|
if logCtx, ok := FromContext(ctx); ok {
|
||||||
logCtx.Extras = append(logCtx.Extras, extras...)
|
logCtx.SetExtras(extras...)
|
||||||
}
|
}
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithLevel(ctx context.Context, level func(kitlog.Logger) kitlog.Logger) context.Context {
|
func WithLevel(ctx context.Context, level func(kitlog.Logger) kitlog.Logger) context.Context {
|
||||||
if logCtx, ok := FromContext(ctx); ok {
|
if logCtx, ok := FromContext(ctx); ok {
|
||||||
logCtx.ForceLevel = level
|
logCtx.SetForceLevel(level)
|
||||||
}
|
}
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoggingContext contains the context information for logging the current request
|
// LoggingContext contains the context information for logging the current request
|
||||||
type LoggingContext struct {
|
type LoggingContext struct {
|
||||||
|
l sync.Mutex
|
||||||
|
|
||||||
StartTime time.Time
|
StartTime time.Time
|
||||||
Errs []error
|
Errs []error
|
||||||
Extras []interface{}
|
Extras []interface{}
|
||||||
@ -74,8 +77,37 @@ type LoggingContext struct {
|
|||||||
ForceLevel func(kitlog.Logger) kitlog.Logger
|
ForceLevel func(kitlog.Logger) kitlog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *LoggingContext) SetForceLevel(level func(kitlog.Logger) kitlog.Logger) {
|
||||||
|
l.l.Lock()
|
||||||
|
defer l.l.Unlock()
|
||||||
|
l.ForceLevel = level
|
||||||
|
}
|
||||||
|
func (l *LoggingContext) SetExtras(extras ...interface{}) {
|
||||||
|
l.l.Lock()
|
||||||
|
defer l.l.Unlock()
|
||||||
|
l.Extras = append(l.Extras, extras...)
|
||||||
|
}
|
||||||
|
func (l *LoggingContext) SetSkipUser() {
|
||||||
|
l.l.Lock()
|
||||||
|
defer l.l.Unlock()
|
||||||
|
l.SkipUser = true
|
||||||
|
}
|
||||||
|
func (l *LoggingContext) SetStartTime() {
|
||||||
|
l.l.Lock()
|
||||||
|
defer l.l.Unlock()
|
||||||
|
l.StartTime = time.Now()
|
||||||
|
}
|
||||||
|
func (l *LoggingContext) SetErrs(err ...error) {
|
||||||
|
l.l.Lock()
|
||||||
|
defer l.l.Unlock()
|
||||||
|
l.Errs = append(l.Errs, err...)
|
||||||
|
}
|
||||||
|
|
||||||
// Log logs the data within the context
|
// Log logs the data within the context
|
||||||
func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) {
|
func (l *LoggingContext) Log(ctx context.Context, logger kitlog.Logger) {
|
||||||
|
l.l.Lock()
|
||||||
|
defer l.l.Unlock()
|
||||||
|
|
||||||
if l.ForceLevel != nil {
|
if l.ForceLevel != nil {
|
||||||
logger = l.ForceLevel(logger)
|
logger = l.ForceLevel(logger)
|
||||||
} else if l.Errs != nil || len(l.Extras) > 0 {
|
} else if l.Errs != nil || len(l.Extras) > 0 {
|
||||||
|
@ -64,6 +64,15 @@ func (d *Datastore) SaveDistributedQueryCampaign(ctx context.Context, camp *flee
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Datastore) DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error) {
|
||||||
|
var campaigns []*fleet.DistributedQueryCampaign
|
||||||
|
err := sqlx.SelectContext(ctx, d.reader, &campaigns, `SELECT * FROM distributed_query_campaigns WHERE query_id=?`, queryID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getting campaigns for query")
|
||||||
|
}
|
||||||
|
return campaigns, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Datastore) DistributedQueryCampaignTargetIDs(ctx context.Context, id uint) (*fleet.HostTargets, error) {
|
func (d *Datastore) DistributedQueryCampaignTargetIDs(ctx context.Context, id uint) (*fleet.HostTargets, error) {
|
||||||
sqlStatement := `
|
sqlStatement := `
|
||||||
SELECT * FROM distributed_query_campaign_targets WHERE distributed_query_campaign_id = ?
|
SELECT * FROM distributed_query_campaign_targets WHERE distributed_query_campaign_id = ?
|
||||||
|
@ -31,17 +31,6 @@ type DistributedQueryCampaignTarget struct {
|
|||||||
TargetID uint `db:"target_id"`
|
TargetID uint `db:"target_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DistributedQueryExecutionStatus is the status of a distributed query
|
|
||||||
// execution on a single host.
|
|
||||||
type DistributedQueryExecutionStatus int
|
|
||||||
|
|
||||||
const (
|
|
||||||
ExecutionWaiting DistributedQueryExecutionStatus = iota
|
|
||||||
ExecutionRequested
|
|
||||||
ExecutionSucceeded
|
|
||||||
ExecutionFailed
|
|
||||||
)
|
|
||||||
|
|
||||||
// DistributedQueryResult is the result returned from the execution of a
|
// DistributedQueryResult is the result returned from the execution of a
|
||||||
// distributed query on a single host.
|
// distributed query on a single host.
|
||||||
type DistributedQueryResult struct {
|
type DistributedQueryResult struct {
|
||||||
|
@ -86,6 +86,8 @@ type Datastore interface {
|
|||||||
// easier to test. The return values indicate how many campaigns were expired and any error.
|
// easier to test. The return values indicate how many campaigns were expired and any error.
|
||||||
CleanupDistributedQueryCampaigns(ctx context.Context, now time.Time) (expired uint, err error)
|
CleanupDistributedQueryCampaigns(ctx context.Context, now time.Time) (expired uint, err error)
|
||||||
|
|
||||||
|
DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*DistributedQueryCampaign, error)
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// PackStore is the datastore interface for managing query packs.
|
// PackStore is the datastore interface for managing query packs.
|
||||||
|
|
||||||
|
@ -221,6 +221,9 @@ type Service interface {
|
|||||||
// go-kit RPC style.
|
// go-kit RPC style.
|
||||||
StreamCampaignResults(ctx context.Context, conn *websocket.Conn, campaignID uint)
|
StreamCampaignResults(ctx context.Context, conn *websocket.Conn, campaignID uint)
|
||||||
|
|
||||||
|
GetCampaignReader(ctx context.Context, campaign *DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error)
|
||||||
|
CompleteCampaign(ctx context.Context, campaign *DistributedQueryCampaign) error
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// AgentOptionsService
|
// AgentOptionsService
|
||||||
|
|
||||||
|
@ -75,6 +75,8 @@ type NewDistributedQueryCampaignTargetFunc func(ctx context.Context, target *fle
|
|||||||
|
|
||||||
type CleanupDistributedQueryCampaignsFunc func(ctx context.Context, now time.Time) (expired uint, err error)
|
type CleanupDistributedQueryCampaignsFunc func(ctx context.Context, now time.Time) (expired uint, err error)
|
||||||
|
|
||||||
|
type DistributedQueryCampaignsForQueryFunc func(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error)
|
||||||
|
|
||||||
type ApplyPackSpecsFunc func(ctx context.Context, specs []*fleet.PackSpec) error
|
type ApplyPackSpecsFunc func(ctx context.Context, specs []*fleet.PackSpec) error
|
||||||
|
|
||||||
type GetPackSpecsFunc func(ctx context.Context) ([]*fleet.PackSpec, error)
|
type GetPackSpecsFunc func(ctx context.Context) ([]*fleet.PackSpec, error)
|
||||||
@ -406,6 +408,9 @@ type DataStore struct {
|
|||||||
CleanupDistributedQueryCampaignsFunc CleanupDistributedQueryCampaignsFunc
|
CleanupDistributedQueryCampaignsFunc CleanupDistributedQueryCampaignsFunc
|
||||||
CleanupDistributedQueryCampaignsFuncInvoked bool
|
CleanupDistributedQueryCampaignsFuncInvoked bool
|
||||||
|
|
||||||
|
DistributedQueryCampaignsForQueryFunc DistributedQueryCampaignsForQueryFunc
|
||||||
|
DistributedQueryCampaignsForQueryFuncInvoked bool
|
||||||
|
|
||||||
ApplyPackSpecsFunc ApplyPackSpecsFunc
|
ApplyPackSpecsFunc ApplyPackSpecsFunc
|
||||||
ApplyPackSpecsFuncInvoked bool
|
ApplyPackSpecsFuncInvoked bool
|
||||||
|
|
||||||
@ -918,6 +923,11 @@ func (s *DataStore) CleanupDistributedQueryCampaigns(ctx context.Context, now ti
|
|||||||
return s.CleanupDistributedQueryCampaignsFunc(ctx, now)
|
return s.CleanupDistributedQueryCampaignsFunc(ctx, now)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DataStore) DistributedQueryCampaignsForQuery(ctx context.Context, queryID uint) ([]*fleet.DistributedQueryCampaign, error) {
|
||||||
|
s.DistributedQueryCampaignsForQueryFuncInvoked = true
|
||||||
|
return s.DistributedQueryCampaignsForQueryFunc(ctx, queryID)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DataStore) ApplyPackSpecs(ctx context.Context, specs []*fleet.PackSpec) error {
|
func (s *DataStore) ApplyPackSpecs(ctx context.Context, specs []*fleet.PackSpec) error {
|
||||||
s.ApplyPackSpecsFuncInvoked = true
|
s.ApplyPackSpecsFuncInvoked = true
|
||||||
return s.ApplyPackSpecsFunc(ctx, specs)
|
return s.ApplyPackSpecsFunc(ctx, specs)
|
||||||
|
@ -570,15 +570,15 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
|
|||||||
r.Handle("/api/v1/fleet/users", h.ListUsers).Methods("GET").Name("list_users")
|
r.Handle("/api/v1/fleet/users", h.ListUsers).Methods("GET").Name("list_users")
|
||||||
r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite")
|
r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite")
|
||||||
r.Handle("/api/v1/fleet/users/admin", h.CreateUser).Methods("POST").Name("create_user")
|
r.Handle("/api/v1/fleet/users/admin", h.CreateUser).Methods("POST").Name("create_user")
|
||||||
r.Handle("/api/v1/fleet/users/{id}", h.GetUser).Methods("GET").Name("get_user")
|
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.GetUser).Methods("GET").Name("get_user")
|
||||||
r.Handle("/api/v1/fleet/users/{id}", h.ModifyUser).Methods("PATCH").Name("modify_user")
|
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.ModifyUser).Methods("PATCH").Name("modify_user")
|
||||||
r.Handle("/api/v1/fleet/users/{id}", h.DeleteUser).Methods("DELETE").Name("delete_user")
|
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.DeleteUser).Methods("DELETE").Name("delete_user")
|
||||||
r.Handle("/api/v1/fleet/users/{id}/require_password_reset", h.RequirePasswordReset).Methods("POST").Name("require_password_reset")
|
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/require_password_reset", h.RequirePasswordReset).Methods("POST").Name("require_password_reset")
|
||||||
r.Handle("/api/v1/fleet/users/{id}/sessions", h.GetSessionsForUserInfo).Methods("GET").Name("get_session_for_user")
|
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.GetSessionsForUserInfo).Methods("GET").Name("get_session_for_user")
|
||||||
r.Handle("/api/v1/fleet/users/{id}/sessions", h.DeleteSessionsForUser).Methods("DELETE").Name("delete_session_for_user")
|
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.DeleteSessionsForUser).Methods("DELETE").Name("delete_session_for_user")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/sessions/{id}", h.GetSessionInfo).Methods("GET").Name("get_session_info")
|
r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.GetSessionInfo).Methods("GET").Name("get_session_info")
|
||||||
r.Handle("/api/v1/fleet/sessions/{id}", h.DeleteSession).Methods("DELETE").Name("delete_session")
|
r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.DeleteSession).Methods("DELETE").Name("delete_session")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/config/certificate", h.GetCertificate).Methods("GET").Name("get_certificate")
|
r.Handle("/api/v1/fleet/config/certificate", h.GetCertificate).Methods("GET").Name("get_certificate")
|
||||||
r.Handle("/api/v1/fleet/config", h.GetAppConfig).Methods("GET").Name("get_app_config")
|
r.Handle("/api/v1/fleet/config", h.GetAppConfig).Methods("GET").Name("get_app_config")
|
||||||
@ -587,17 +587,17 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
|
|||||||
r.Handle("/api/v1/fleet/spec/enroll_secret", h.GetEnrollSecretSpec).Methods("GET").Name("get_enroll_secret_spec")
|
r.Handle("/api/v1/fleet/spec/enroll_secret", h.GetEnrollSecretSpec).Methods("GET").Name("get_enroll_secret_spec")
|
||||||
r.Handle("/api/v1/fleet/invites", h.CreateInvite).Methods("POST").Name("create_invite")
|
r.Handle("/api/v1/fleet/invites", h.CreateInvite).Methods("POST").Name("create_invite")
|
||||||
r.Handle("/api/v1/fleet/invites", h.ListInvites).Methods("GET").Name("list_invites")
|
r.Handle("/api/v1/fleet/invites", h.ListInvites).Methods("GET").Name("list_invites")
|
||||||
r.Handle("/api/v1/fleet/invites/{id}", h.DeleteInvite).Methods("DELETE").Name("delete_invite")
|
r.Handle("/api/v1/fleet/invites/{id:[0-9]+}", h.DeleteInvite).Methods("DELETE").Name("delete_invite")
|
||||||
r.Handle("/api/v1/fleet/invites/{token}", h.VerifyInvite).Methods("GET").Name("verify_invite")
|
r.Handle("/api/v1/fleet/invites/{token}", h.VerifyInvite).Methods("GET").Name("verify_invite")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/email/change/{token}", h.ChangeEmail).Methods("GET").Name("change_email")
|
r.Handle("/api/v1/fleet/email/change/{token}", h.ChangeEmail).Methods("GET").Name("change_email")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/queries/{id}", h.GetQuery).Methods("GET").Name("get_query")
|
r.Handle("/api/v1/fleet/queries/{id:[0-9]+}", h.GetQuery).Methods("GET").Name("get_query")
|
||||||
r.Handle("/api/v1/fleet/queries", h.ListQueries).Methods("GET").Name("list_queries")
|
r.Handle("/api/v1/fleet/queries", h.ListQueries).Methods("GET").Name("list_queries")
|
||||||
r.Handle("/api/v1/fleet/queries", h.CreateQuery).Methods("POST").Name("create_query")
|
r.Handle("/api/v1/fleet/queries", h.CreateQuery).Methods("POST").Name("create_query")
|
||||||
r.Handle("/api/v1/fleet/queries/{id}", h.ModifyQuery).Methods("PATCH").Name("modify_query")
|
r.Handle("/api/v1/fleet/queries/{id:[0-9]+}", h.ModifyQuery).Methods("PATCH").Name("modify_query")
|
||||||
r.Handle("/api/v1/fleet/queries/{name}", h.DeleteQuery).Methods("DELETE").Name("delete_query")
|
r.Handle("/api/v1/fleet/queries/{name}", h.DeleteQuery).Methods("DELETE").Name("delete_query")
|
||||||
r.Handle("/api/v1/fleet/queries/id/{id}", h.DeleteQueryByID).Methods("DELETE").Name("delete_query_by_id")
|
r.Handle("/api/v1/fleet/queries/id/{id:[0-9]+}", h.DeleteQueryByID).Methods("DELETE").Name("delete_query_by_id")
|
||||||
r.Handle("/api/v1/fleet/queries/delete", h.DeleteQueries).Methods("POST").Name("delete_queries")
|
r.Handle("/api/v1/fleet/queries/delete", h.DeleteQueries).Methods("POST").Name("delete_queries")
|
||||||
r.Handle("/api/v1/fleet/spec/queries", h.ApplyQuerySpecs).Methods("POST").Name("apply_query_specs")
|
r.Handle("/api/v1/fleet/spec/queries", h.ApplyQuerySpecs).Methods("POST").Name("apply_query_specs")
|
||||||
r.Handle("/api/v1/fleet/spec/queries", h.GetQuerySpecs).Methods("GET").Name("get_query_specs")
|
r.Handle("/api/v1/fleet/spec/queries", h.GetQuerySpecs).Methods("GET").Name("get_query_specs")
|
||||||
@ -606,41 +606,41 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
|
|||||||
r.Handle("/api/v1/fleet/queries/run_by_names", h.CreateDistributedQueryCampaignByNames).Methods("POST").Name("create_distributed_query_campaign_by_names")
|
r.Handle("/api/v1/fleet/queries/run_by_names", h.CreateDistributedQueryCampaignByNames).Methods("POST").Name("create_distributed_query_campaign_by_names")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/packs", h.CreatePack).Methods("POST").Name("create_pack")
|
r.Handle("/api/v1/fleet/packs", h.CreatePack).Methods("POST").Name("create_pack")
|
||||||
r.Handle("/api/v1/fleet/packs/{id}", h.ModifyPack).Methods("PATCH").Name("modify_pack")
|
r.Handle("/api/v1/fleet/packs/{id:[0-9]+}", h.ModifyPack).Methods("PATCH").Name("modify_pack")
|
||||||
r.Handle("/api/v1/fleet/packs", h.ListPacks).Methods("GET").Name("list_packs")
|
r.Handle("/api/v1/fleet/packs", h.ListPacks).Methods("GET").Name("list_packs")
|
||||||
r.Handle("/api/v1/fleet/packs/{name}", h.DeletePack).Methods("DELETE").Name("delete_pack")
|
r.Handle("/api/v1/fleet/packs/{name}", h.DeletePack).Methods("DELETE").Name("delete_pack")
|
||||||
r.Handle("/api/v1/fleet/packs/id/{id}", h.DeletePackByID).Methods("DELETE").Name("delete_pack_by_id")
|
r.Handle("/api/v1/fleet/packs/id/{id:[0-9]+}", h.DeletePackByID).Methods("DELETE").Name("delete_pack_by_id")
|
||||||
r.Handle("/api/v1/fleet/packs/{id}/scheduled", h.GetScheduledQueriesInPack).Methods("GET").Name("get_scheduled_queries_in_pack")
|
r.Handle("/api/v1/fleet/packs/{id:[0-9]+}/scheduled", h.GetScheduledQueriesInPack).Methods("GET").Name("get_scheduled_queries_in_pack")
|
||||||
r.Handle("/api/v1/fleet/schedule", h.ScheduleQuery).Methods("POST").Name("schedule_query")
|
r.Handle("/api/v1/fleet/schedule", h.ScheduleQuery).Methods("POST").Name("schedule_query")
|
||||||
r.Handle("/api/v1/fleet/schedule/{id}", h.GetScheduledQuery).Methods("GET").Name("get_scheduled_query")
|
r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.GetScheduledQuery).Methods("GET").Name("get_scheduled_query")
|
||||||
r.Handle("/api/v1/fleet/schedule/{id}", h.ModifyScheduledQuery).Methods("PATCH").Name("modify_scheduled_query")
|
r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.ModifyScheduledQuery).Methods("PATCH").Name("modify_scheduled_query")
|
||||||
r.Handle("/api/v1/fleet/schedule/{id}", h.DeleteScheduledQuery).Methods("DELETE").Name("delete_scheduled_query")
|
r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.DeleteScheduledQuery).Methods("DELETE").Name("delete_scheduled_query")
|
||||||
r.Handle("/api/v1/fleet/spec/packs", h.ApplyPackSpecs).Methods("POST").Name("apply_pack_specs")
|
r.Handle("/api/v1/fleet/spec/packs", h.ApplyPackSpecs).Methods("POST").Name("apply_pack_specs")
|
||||||
r.Handle("/api/v1/fleet/spec/packs", h.GetPackSpecs).Methods("GET").Name("get_pack_specs")
|
r.Handle("/api/v1/fleet/spec/packs", h.GetPackSpecs).Methods("GET").Name("get_pack_specs")
|
||||||
r.Handle("/api/v1/fleet/spec/packs/{name}", h.GetPackSpec).Methods("GET").Name("get_pack_spec")
|
r.Handle("/api/v1/fleet/spec/packs/{name}", h.GetPackSpec).Methods("GET").Name("get_pack_spec")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/global/schedule", h.GetGlobalSchedule).Methods("GET").Name("set_global_schedule")
|
r.Handle("/api/v1/fleet/global/schedule", h.GetGlobalSchedule).Methods("GET").Name("set_global_schedule")
|
||||||
r.Handle("/api/v1/fleet/global/schedule", h.GlobalScheduleQuery).Methods("POST").Name("add_to_global_schedule")
|
r.Handle("/api/v1/fleet/global/schedule", h.GlobalScheduleQuery).Methods("POST").Name("add_to_global_schedule")
|
||||||
r.Handle("/api/v1/fleet/global/schedule/{id}", h.ModifyGlobalSchedule).Methods("PATCH").Name("modify_global_schedule")
|
r.Handle("/api/v1/fleet/global/schedule/{id:[0-9]+}", h.ModifyGlobalSchedule).Methods("PATCH").Name("modify_global_schedule")
|
||||||
r.Handle("/api/v1/fleet/global/schedule/{id}", h.DeleteGlobalSchedule).Methods("DELETE").Name("delete_global_schedule")
|
r.Handle("/api/v1/fleet/global/schedule/{id:[0-9]+}", h.DeleteGlobalSchedule).Methods("DELETE").Name("delete_global_schedule")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/labels", h.CreateLabel).Methods("POST").Name("create_label")
|
r.Handle("/api/v1/fleet/labels", h.CreateLabel).Methods("POST").Name("create_label")
|
||||||
r.Handle("/api/v1/fleet/labels/{id}", h.ModifyLabel).Methods("PATCH").Name("modify_label")
|
r.Handle("/api/v1/fleet/labels/{id:[0-9]+}", h.ModifyLabel).Methods("PATCH").Name("modify_label")
|
||||||
r.Handle("/api/v1/fleet/labels/{id}", h.GetLabel).Methods("GET").Name("get_label")
|
r.Handle("/api/v1/fleet/labels/{id:[0-9]+}", h.GetLabel).Methods("GET").Name("get_label")
|
||||||
r.Handle("/api/v1/fleet/labels", h.ListLabels).Methods("GET").Name("list_labels")
|
r.Handle("/api/v1/fleet/labels", h.ListLabels).Methods("GET").Name("list_labels")
|
||||||
r.Handle("/api/v1/fleet/labels/{id}/hosts", h.ListHostsInLabel).Methods("GET").Name("list_hosts_in_label")
|
r.Handle("/api/v1/fleet/labels/{id:[0-9]+}/hosts", h.ListHostsInLabel).Methods("GET").Name("list_hosts_in_label")
|
||||||
r.Handle("/api/v1/fleet/labels/{name}", h.DeleteLabel).Methods("DELETE").Name("delete_label")
|
r.Handle("/api/v1/fleet/labels/{name}", h.DeleteLabel).Methods("DELETE").Name("delete_label")
|
||||||
r.Handle("/api/v1/fleet/labels/id/{id}", h.DeleteLabelByID).Methods("DELETE").Name("delete_label_by_id")
|
r.Handle("/api/v1/fleet/labels/id/{id:[0-9]+}", h.DeleteLabelByID).Methods("DELETE").Name("delete_label_by_id")
|
||||||
r.Handle("/api/v1/fleet/spec/labels", h.ApplyLabelSpecs).Methods("POST").Name("apply_label_specs")
|
r.Handle("/api/v1/fleet/spec/labels", h.ApplyLabelSpecs).Methods("POST").Name("apply_label_specs")
|
||||||
r.Handle("/api/v1/fleet/spec/labels", h.GetLabelSpecs).Methods("GET").Name("get_label_specs")
|
r.Handle("/api/v1/fleet/spec/labels", h.GetLabelSpecs).Methods("GET").Name("get_label_specs")
|
||||||
r.Handle("/api/v1/fleet/spec/labels/{name}", h.GetLabelSpec).Methods("GET").Name("get_label_spec")
|
r.Handle("/api/v1/fleet/spec/labels/{name}", h.GetLabelSpec).Methods("GET").Name("get_label_spec")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/host_summary", h.GetHostSummary).Methods("GET").Name("get_host_summary")
|
r.Handle("/api/v1/fleet/host_summary", h.GetHostSummary).Methods("GET").Name("get_host_summary")
|
||||||
r.Handle("/api/v1/fleet/hosts/identifier/{identifier}", h.HostByIdentifier).Methods("GET").Name("host_by_identifier")
|
r.Handle("/api/v1/fleet/hosts/identifier/{identifier}", h.HostByIdentifier).Methods("GET").Name("host_by_identifier")
|
||||||
r.Handle("/api/v1/fleet/hosts/{id}", h.DeleteHost).Methods("DELETE").Name("delete_host")
|
r.Handle("/api/v1/fleet/hosts/{id:[0-9]+}", h.DeleteHost).Methods("DELETE").Name("delete_host")
|
||||||
r.Handle("/api/v1/fleet/hosts/transfer", h.AddHostsToTeam).Methods("POST").Name("add_hosts_to_team")
|
r.Handle("/api/v1/fleet/hosts/transfer", h.AddHostsToTeam).Methods("POST").Name("add_hosts_to_team")
|
||||||
r.Handle("/api/v1/fleet/hosts/transfer/filter", h.AddHostsToTeamByFilter).Methods("POST").Name("add_hosts_to_team_by_filter")
|
r.Handle("/api/v1/fleet/hosts/transfer/filter", h.AddHostsToTeamByFilter).Methods("POST").Name("add_hosts_to_team_by_filter")
|
||||||
r.Handle("/api/v1/fleet/hosts/{id}/refetch", h.RefetchHost).Methods("POST").Name("refetch_host")
|
r.Handle("/api/v1/fleet/hosts/{id:[0-9]+}/refetch", h.RefetchHost).Methods("POST").Name("refetch_host")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/targets", h.SearchTargets).Methods("POST").Name("search_targets")
|
r.Handle("/api/v1/fleet/targets", h.SearchTargets).Methods("POST").Name("search_targets")
|
||||||
|
|
||||||
@ -650,18 +650,18 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
|
|||||||
r.Handle("/api/v1/fleet/status/live_query", h.StatusLiveQuery).Methods("GET").Name("status_live_query")
|
r.Handle("/api/v1/fleet/status/live_query", h.StatusLiveQuery).Methods("GET").Name("status_live_query")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/carves", h.ListCarves).Methods("GET").Name("list_carves")
|
r.Handle("/api/v1/fleet/carves", h.ListCarves).Methods("GET").Name("list_carves")
|
||||||
r.Handle("/api/v1/fleet/carves/{id}", h.GetCarve).Methods("GET").Name("get_carve")
|
r.Handle("/api/v1/fleet/carves/{id:[0-9]+}", h.GetCarve).Methods("GET").Name("get_carve")
|
||||||
r.Handle("/api/v1/fleet/carves/{id}/block/{block_id}", h.GetCarveBlock).Methods("GET").Name("get_carve_block")
|
r.Handle("/api/v1/fleet/carves/{id:[0-9]+}/block/{block_id}", h.GetCarveBlock).Methods("GET").Name("get_carve_block")
|
||||||
|
|
||||||
r.Handle("/api/v1/fleet/teams", h.CreateTeam).Methods("POST").Name("create_team")
|
r.Handle("/api/v1/fleet/teams", h.CreateTeam).Methods("POST").Name("create_team")
|
||||||
r.Handle("/api/v1/fleet/teams", h.ListTeams).Methods("GET").Name("list_teams")
|
r.Handle("/api/v1/fleet/teams", h.ListTeams).Methods("GET").Name("list_teams")
|
||||||
r.Handle("/api/v1/fleet/teams/{id}", h.ModifyTeam).Methods("PATCH").Name("modify_team")
|
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}", h.ModifyTeam).Methods("PATCH").Name("modify_team")
|
||||||
r.Handle("/api/v1/fleet/teams/{id}", h.DeleteTeam).Methods("DELETE").Name("delete_team")
|
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}", h.DeleteTeam).Methods("DELETE").Name("delete_team")
|
||||||
r.Handle("/api/v1/fleet/teams/{id}/agent_options", h.ModifyTeamAgentOptions).Methods("POST").Name("modify_team_agent_options")
|
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/agent_options", h.ModifyTeamAgentOptions).Methods("POST").Name("modify_team_agent_options")
|
||||||
r.Handle("/api/v1/fleet/teams/{id}/users", h.ListTeamUsers).Methods("GET").Name("team_users")
|
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/users", h.ListTeamUsers).Methods("GET").Name("team_users")
|
||||||
r.Handle("/api/v1/fleet/teams/{id}/users", h.AddTeamUsers).Methods("PATCH").Name("add_team_users")
|
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/users", h.AddTeamUsers).Methods("PATCH").Name("add_team_users")
|
||||||
r.Handle("/api/v1/fleet/teams/{id}/users", h.DeleteTeamUsers).Methods("DELETE").Name("delete_team_users")
|
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/users", h.DeleteTeamUsers).Methods("DELETE").Name("delete_team_users")
|
||||||
r.Handle("/api/v1/fleet/teams/{id}/secrets", h.TeamEnrollSecrets).Methods("GET").Name("get_team_enroll_secrets")
|
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}/secrets", h.TeamEnrollSecrets).Methods("GET").Name("get_team_enroll_secrets")
|
||||||
|
|
||||||
r.Handle("/api/v1/osquery/enroll", h.EnrollAgent).Methods("POST").Name("enroll_agent")
|
r.Handle("/api/v1/osquery/enroll", h.EnrollAgent).Methods("POST").Name("enroll_agent")
|
||||||
r.Handle("/api/v1/osquery/config", h.GetClientConfig).Methods("POST").Name("get_client_config")
|
r.Handle("/api/v1/osquery/config", h.GetClientConfig).Methods("POST").Name("get_client_config")
|
||||||
@ -719,6 +719,8 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht
|
|||||||
e.POST("/api/v1/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{})
|
e.POST("/api/v1/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{})
|
||||||
e.GET("/api/v1/fleet/hosts/{id:[0-9]+}", getHostEndpoint, getHostRequest{})
|
e.GET("/api/v1/fleet/hosts/{id:[0-9]+}", getHostEndpoint, getHostRequest{})
|
||||||
e.GET("/api/v1/fleet/hosts/count", countHostsEndpoint, countHostsRequest{})
|
e.GET("/api/v1/fleet/hosts/count", countHostsEndpoint, countHostsRequest{})
|
||||||
|
|
||||||
|
e.GET("/api/v1/fleet/queries/run", runLiveQueryEndpoint, runLiveQueryRequest{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: this duplicates the one in makeKitHandler
|
// TODO: this duplicates the one in makeKitHandler
|
||||||
|
398
server/service/integration_live_queries_test.go
Normal file
398
server/service/integration_live_queries_test.go
Normal file
@ -0,0 +1,398 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||||
|
"github.com/fleetdm/fleet/v4/server/live_query"
|
||||||
|
"github.com/fleetdm/fleet/v4/server/pubsub"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIntegrationLiveQueriesTestSuite(t *testing.T) {
|
||||||
|
testingSuite := new(liveQueriesTestSuite)
|
||||||
|
testingSuite.s = &testingSuite.Suite
|
||||||
|
suite.Run(t, testingSuite)
|
||||||
|
}
|
||||||
|
|
||||||
|
type liveQueriesTestSuite struct {
|
||||||
|
withServer
|
||||||
|
suite.Suite
|
||||||
|
|
||||||
|
lq *live_query.MockLiveQuery
|
||||||
|
hosts []*fleet.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *liveQueriesTestSuite) SetupSuite() {
|
||||||
|
require.NoError(s.T(), os.Setenv("FLEET_LIVE_QUERY_REST_PERIOD", "5s"))
|
||||||
|
|
||||||
|
s.withDS.SetupSuite("liveQueriesTestSuite")
|
||||||
|
|
||||||
|
rs := pubsub.NewInmemQueryResults()
|
||||||
|
lq := new(live_query.MockLiveQuery)
|
||||||
|
s.lq = lq
|
||||||
|
|
||||||
|
users, server := RunServerForTestsWithDS(s.T(), s.ds, TestServerOpts{Lq: lq, Rs: rs})
|
||||||
|
s.server = server
|
||||||
|
s.users = users
|
||||||
|
s.token = getTestAdminToken(s.T(), s.server)
|
||||||
|
|
||||||
|
t := s.T()
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
host, err := s.ds.NewHost(context.Background(), &fleet.Host{
|
||||||
|
DetailUpdatedAt: time.Now(),
|
||||||
|
LabelUpdatedAt: time.Now(),
|
||||||
|
PolicyUpdatedAt: time.Now(),
|
||||||
|
SeenTime: time.Now().Add(-time.Duration(i) * time.Minute),
|
||||||
|
OsqueryHostID: fmt.Sprintf("%s%d", t.Name(), i),
|
||||||
|
NodeKey: fmt.Sprintf("%s%d", t.Name(), i),
|
||||||
|
UUID: fmt.Sprintf("%s%d", t.Name(), i),
|
||||||
|
Hostname: fmt.Sprintf("%sfoo.local%d", t.Name(), i),
|
||||||
|
})
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
s.hosts = append(s.hosts, host)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
||||||
|
t := s.T()
|
||||||
|
|
||||||
|
host := s.hosts[0]
|
||||||
|
|
||||||
|
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
|
||||||
|
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
||||||
|
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil)
|
||||||
|
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
||||||
|
|
||||||
|
liveQueryRequest := runLiveQueryRequest{
|
||||||
|
QueryIDs: []uint{q1.ID},
|
||||||
|
HostIDs: []uint{host.ID},
|
||||||
|
}
|
||||||
|
liveQueryResp := runLiveQueryResponse{}
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give the above call a couple of seconds to create the campaign
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
cid := getCIDForQ(s, q1)
|
||||||
|
|
||||||
|
distributedReq := submitDistributedQueryResultsRequest{
|
||||||
|
NodeKey: host.NodeKey,
|
||||||
|
Results: map[string][]map[string]string{
|
||||||
|
hostDistributedQueryPrefix + cid: {{"col1": "a", "col2": "b"}},
|
||||||
|
},
|
||||||
|
Statuses: map[string]fleet.OsqueryStatus{
|
||||||
|
hostDistributedQueryPrefix + cid: 0,
|
||||||
|
},
|
||||||
|
Messages: map[string]string{
|
||||||
|
hostDistributedQueryPrefix + cid: "some msg",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
distributedResp := submitDistributedQueryResultsResponse{}
|
||||||
|
s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
require.Len(t, liveQueryResp.Results, 1)
|
||||||
|
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
|
||||||
|
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
|
||||||
|
require.Len(t, liveQueryResp.Results[0].Results[0].Rows, 1)
|
||||||
|
assert.Equal(t, "a", liveQueryResp.Results[0].Results[0].Rows[0]["col1"])
|
||||||
|
assert.Equal(t, "b", liveQueryResp.Results[0].Results[0].Rows[0]["col2"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
|
||||||
|
t := s.T()
|
||||||
|
|
||||||
|
host := s.hosts[0]
|
||||||
|
|
||||||
|
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
q2, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 2 from osquery;", Description: "desc2", Name: t.Name() + "query2"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s.lq.On("QueriesForHost", host.ID).Return(map[string]string{
|
||||||
|
fmt.Sprint(q1.ID): "select 1 from osquery;",
|
||||||
|
fmt.Sprint(q2.ID): "select 2 from osquery;",
|
||||||
|
}, nil)
|
||||||
|
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
||||||
|
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil)
|
||||||
|
s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{host.ID}).Return(nil)
|
||||||
|
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
||||||
|
|
||||||
|
liveQueryRequest := runLiveQueryRequest{
|
||||||
|
QueryIDs: []uint{q1.ID, q2.ID},
|
||||||
|
HostIDs: []uint{host.ID},
|
||||||
|
}
|
||||||
|
liveQueryResp := runLiveQueryResponse{}
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give the above call a couple of seconds to create the campaign
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
cid1 := getCIDForQ(s, q1)
|
||||||
|
cid2 := getCIDForQ(s, q2)
|
||||||
|
|
||||||
|
distributedReq := submitDistributedQueryResultsRequest{
|
||||||
|
NodeKey: host.NodeKey,
|
||||||
|
Results: map[string][]map[string]string{
|
||||||
|
hostDistributedQueryPrefix + cid1: {{"col1": "a", "col2": "b"}},
|
||||||
|
hostDistributedQueryPrefix + cid2: {{"col3": "c", "col4": "d"}, {"col3": "e", "col4": "f"}},
|
||||||
|
},
|
||||||
|
Statuses: map[string]fleet.OsqueryStatus{
|
||||||
|
hostDistributedQueryPrefix + cid1: 0,
|
||||||
|
hostDistributedQueryPrefix + cid2: 0,
|
||||||
|
},
|
||||||
|
Messages: map[string]string{
|
||||||
|
hostDistributedQueryPrefix + cid1: "some msg",
|
||||||
|
hostDistributedQueryPrefix + cid2: "some other msg",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
distributedResp := submitDistributedQueryResultsResponse{}
|
||||||
|
s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
require.Len(t, liveQueryResp.Results, 2)
|
||||||
|
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
|
||||||
|
|
||||||
|
sort.Slice(liveQueryResp.Results, func(i, j int) bool {
|
||||||
|
return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID
|
||||||
|
})
|
||||||
|
|
||||||
|
require.True(t, q1.ID < q2.ID)
|
||||||
|
|
||||||
|
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
|
||||||
|
require.Len(t, liveQueryResp.Results[0].Results, 1)
|
||||||
|
q1Results := liveQueryResp.Results[0].Results[0]
|
||||||
|
require.Len(t, q1Results.Rows, 1)
|
||||||
|
assert.Equal(t, "a", q1Results.Rows[0]["col1"])
|
||||||
|
assert.Equal(t, "b", q1Results.Rows[0]["col2"])
|
||||||
|
|
||||||
|
assert.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID)
|
||||||
|
require.Len(t, liveQueryResp.Results[1].Results, 1)
|
||||||
|
q2Results := liveQueryResp.Results[1].Results[0]
|
||||||
|
require.Len(t, q2Results.Rows, 2)
|
||||||
|
assert.Equal(t, "c", q2Results.Rows[0]["col3"])
|
||||||
|
assert.Equal(t, "d", q2Results.Rows[0]["col4"])
|
||||||
|
assert.Equal(t, "e", q2Results.Rows[1]["col3"])
|
||||||
|
assert.Equal(t, "f", q2Results.Rows[1]["col4"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCIDForQ(s *liveQueriesTestSuite, q1 *fleet.Query) string {
|
||||||
|
t := s.T()
|
||||||
|
campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, campaigns, 1)
|
||||||
|
cid1 := fmt.Sprint(campaigns[0].ID)
|
||||||
|
return cid1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
|
||||||
|
t := s.T()
|
||||||
|
|
||||||
|
h1 := s.hosts[0]
|
||||||
|
h2 := s.hosts[1]
|
||||||
|
|
||||||
|
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
q2, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 2 from osquery;", Description: "desc2", Name: t.Name() + "query2"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{
|
||||||
|
fmt.Sprint(q1.ID): "select 1 from osquery;",
|
||||||
|
fmt.Sprint(q2.ID): "select 2 from osquery;",
|
||||||
|
}, nil)
|
||||||
|
s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{
|
||||||
|
fmt.Sprint(q1.ID): "select 1 from osquery;",
|
||||||
|
fmt.Sprint(q2.ID): "select 2 from osquery;",
|
||||||
|
}, nil)
|
||||||
|
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
||||||
|
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
|
||||||
|
s.lq.On("RunQuery", mock.Anything, "select 2 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
|
||||||
|
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
||||||
|
|
||||||
|
liveQueryRequest := runLiveQueryRequest{
|
||||||
|
QueryIDs: []uint{q1.ID, q2.ID},
|
||||||
|
HostIDs: []uint{h1.ID, h2.ID},
|
||||||
|
}
|
||||||
|
liveQueryResp := runLiveQueryResponse{}
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give the above call a couple of seconds to create the campaign
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
cid1 := getCIDForQ(s, q1)
|
||||||
|
cid2 := getCIDForQ(s, q2)
|
||||||
|
for i, h := range []*fleet.Host{h1, h2} {
|
||||||
|
distributedReq := submitDistributedQueryResultsRequest{
|
||||||
|
NodeKey: h.NodeKey,
|
||||||
|
Results: map[string][]map[string]string{
|
||||||
|
hostDistributedQueryPrefix + cid1: {{"col1": fmt.Sprintf("a%d", i), "col2": fmt.Sprintf("b%d", i)}},
|
||||||
|
hostDistributedQueryPrefix + cid2: {{"col3": fmt.Sprintf("c%d", i), "col4": fmt.Sprintf("d%d", i)}, {"col3": fmt.Sprintf("e%d", i), "col4": fmt.Sprintf("f%d", i)}},
|
||||||
|
},
|
||||||
|
Statuses: map[string]fleet.OsqueryStatus{
|
||||||
|
hostDistributedQueryPrefix + cid1: 0,
|
||||||
|
hostDistributedQueryPrefix + cid2: 0,
|
||||||
|
},
|
||||||
|
Messages: map[string]string{
|
||||||
|
hostDistributedQueryPrefix + cid1: "some msg",
|
||||||
|
hostDistributedQueryPrefix + cid2: "some other msg",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
distributedResp := submitDistributedQueryResultsResponse{}
|
||||||
|
s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
require.Len(t, liveQueryResp.Results, 2) // 2 queries
|
||||||
|
assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount)
|
||||||
|
|
||||||
|
sort.Slice(liveQueryResp.Results, func(i, j int) bool {
|
||||||
|
return liveQueryResp.Results[i].QueryID < liveQueryResp.Results[j].QueryID
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
|
||||||
|
require.Len(t, liveQueryResp.Results[0].Results, 2)
|
||||||
|
for i, r := range liveQueryResp.Results[0].Results {
|
||||||
|
require.Len(t, r.Rows, 1)
|
||||||
|
assert.Equal(t, fmt.Sprintf("a%d", i), r.Rows[0]["col1"])
|
||||||
|
assert.Equal(t, fmt.Sprintf("b%d", i), r.Rows[0]["col2"])
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, q2.ID, liveQueryResp.Results[1].QueryID)
|
||||||
|
require.Len(t, liveQueryResp.Results[1].Results, 2)
|
||||||
|
for i, r := range liveQueryResp.Results[1].Results {
|
||||||
|
require.Len(t, r.Rows, 2)
|
||||||
|
assert.Equal(t, fmt.Sprintf("c%d", i), r.Rows[0]["col3"])
|
||||||
|
assert.Equal(t, fmt.Sprintf("d%d", i), r.Rows[0]["col4"])
|
||||||
|
assert.Equal(t, fmt.Sprintf("e%d", i), r.Rows[1]["col3"])
|
||||||
|
assert.Equal(t, fmt.Sprintf("f%d", i), r.Rows[1]["col4"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() {
|
||||||
|
t := s.T()
|
||||||
|
|
||||||
|
liveQueryRequest := runLiveQueryRequest{
|
||||||
|
QueryIDs: []uint{999},
|
||||||
|
HostIDs: []uint{888},
|
||||||
|
}
|
||||||
|
liveQueryResp := runLiveQueryResponse{}
|
||||||
|
|
||||||
|
s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||||
|
|
||||||
|
require.Len(t, liveQueryResp.Results, 1)
|
||||||
|
assert.Equal(t, 0, liveQueryResp.Summary.RespondedHostCount)
|
||||||
|
require.NotNil(t, liveQueryResp.Results[0].Error)
|
||||||
|
assert.Equal(t, "selecting query: sql: no rows in result set", *liveQueryResp.Results[0].Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() {
|
||||||
|
t := s.T()
|
||||||
|
|
||||||
|
h1 := s.hosts[0]
|
||||||
|
h2 := s.hosts[1]
|
||||||
|
|
||||||
|
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{Query: "select 1 from osquery;", Description: "desc1", Name: t.Name() + "query1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s.lq.On("QueriesForHost", h1.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
|
||||||
|
s.lq.On("QueriesForHost", h2.ID).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
|
||||||
|
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
|
||||||
|
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{h1.ID, h2.ID}).Return(nil)
|
||||||
|
s.lq.On("StopQuery", mock.Anything).Return(nil)
|
||||||
|
|
||||||
|
liveQueryRequest := runLiveQueryRequest{
|
||||||
|
QueryIDs: []uint{q1.ID},
|
||||||
|
HostIDs: []uint{h1.ID, h2.ID},
|
||||||
|
}
|
||||||
|
liveQueryResp := runLiveQueryResponse{}
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
s.DoJSON("GET", "/api/v1/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give the above call a couple of seconds to create the campaign
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
cid1 := getCIDForQ(s, q1)
|
||||||
|
distributedReq := submitDistributedQueryResultsRequest{
|
||||||
|
NodeKey: h1.NodeKey,
|
||||||
|
Results: map[string][]map[string]string{
|
||||||
|
hostDistributedQueryPrefix + cid1: {{"col1": "a", "col2": "b"}},
|
||||||
|
},
|
||||||
|
Statuses: map[string]fleet.OsqueryStatus{
|
||||||
|
hostDistributedQueryPrefix + cid1: 0,
|
||||||
|
},
|
||||||
|
Messages: map[string]string{
|
||||||
|
hostDistributedQueryPrefix + cid1: "some msg",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
distributedResp := submitDistributedQueryResultsResponse{}
|
||||||
|
s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
||||||
|
|
||||||
|
distributedReq = submitDistributedQueryResultsRequest{
|
||||||
|
NodeKey: h2.NodeKey,
|
||||||
|
Results: map[string][]map[string]string{
|
||||||
|
hostDistributedQueryPrefix + cid1: {},
|
||||||
|
},
|
||||||
|
Statuses: map[string]fleet.OsqueryStatus{
|
||||||
|
hostDistributedQueryPrefix + cid1: 123,
|
||||||
|
},
|
||||||
|
Messages: map[string]string{
|
||||||
|
hostDistributedQueryPrefix + cid1: "some error!",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
distributedResp = submitDistributedQueryResultsResponse{}
|
||||||
|
s.DoJSON("POST", "/api/v1/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
require.Len(t, liveQueryResp.Results, 1)
|
||||||
|
assert.Equal(t, 2, liveQueryResp.Summary.RespondedHostCount)
|
||||||
|
|
||||||
|
result := liveQueryResp.Results[0]
|
||||||
|
require.Len(t, result.Results, 2)
|
||||||
|
require.Len(t, result.Results[0].Rows, 1)
|
||||||
|
assert.Equal(t, "a", result.Results[0].Rows[0]["col1"])
|
||||||
|
assert.Equal(t, "b", result.Results[0].Rows[0]["col2"])
|
||||||
|
require.Len(t, result.Results[1].Rows, 0)
|
||||||
|
require.NotNil(t, result.Results[1].Error)
|
||||||
|
assert.Equal(t, "some error!", *result.Results[1].Error)
|
||||||
|
}
|
168
server/service/live_queries.go
Normal file
168
server/service/live_queries.go
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||||
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type runLiveQueryRequest struct {
|
||||||
|
QueryIDs []uint `json:"query_ids"`
|
||||||
|
HostIDs []uint `json:"host_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type summaryPayload struct {
|
||||||
|
TargetedHostCount int `json:"targeted_host_count"`
|
||||||
|
RespondedHostCount int `json:"responded_host_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type runLiveQueryResponse struct {
|
||||||
|
Summary summaryPayload `json:"summary"`
|
||||||
|
Err error `json:"error,omitempty"`
|
||||||
|
|
||||||
|
Results []queryCampaignResult `json:"live_query_results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runLiveQueryResponse) error() error { return r.Err }
|
||||||
|
|
||||||
|
type queryResult struct {
|
||||||
|
HostID uint `json:"host_id"`
|
||||||
|
Rows []map[string]string `json:"rows"`
|
||||||
|
Error *string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type queryCampaignResult struct {
|
||||||
|
QueryID uint `json:"query_id"`
|
||||||
|
Error *string `json:"error,omitempty"`
|
||||||
|
Results []queryResult `json:"results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||||
|
req := request.(*runLiveQueryRequest)
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
|
||||||
|
resultsCh := make(chan queryCampaignResult)
|
||||||
|
|
||||||
|
counterMutex := sync.Mutex{}
|
||||||
|
counter := make(map[uint]struct{})
|
||||||
|
|
||||||
|
period := os.Getenv("FLEET_LIVE_QUERY_REST_PERIOD")
|
||||||
|
if period == "" {
|
||||||
|
period = "90s"
|
||||||
|
}
|
||||||
|
duration, err := time.ParseDuration(period)
|
||||||
|
if err != nil {
|
||||||
|
duration = 90 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, queryID := range req.QueryIDs {
|
||||||
|
queryID := queryID
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: req.HostIDs})
|
||||||
|
if err != nil {
|
||||||
|
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
|
||||||
|
if err != nil {
|
||||||
|
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := svc.CompleteCampaign(ctx, campaign)
|
||||||
|
if err != nil {
|
||||||
|
resultsCh <- queryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error())}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(duration)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
var results []queryResult
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case res := <-readChan:
|
||||||
|
// Receive a result and push it over the websocket
|
||||||
|
switch res := res.(type) {
|
||||||
|
case fleet.DistributedQueryResult:
|
||||||
|
results = append(results, queryResult{HostID: res.Host.ID, Rows: res.Rows, Error: res.Error})
|
||||||
|
counterMutex.Lock()
|
||||||
|
counter[res.Host.ID] = struct{}{}
|
||||||
|
counterMutex.Unlock()
|
||||||
|
}
|
||||||
|
case <-ticker.C:
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resultsCh <- queryCampaignResult{QueryID: queryID, Results: results}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(resultsCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
res := runLiveQueryResponse{
|
||||||
|
Summary: summaryPayload{
|
||||||
|
TargetedHostCount: len(req.HostIDs),
|
||||||
|
RespondedHostCount: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for result := range resultsCh {
|
||||||
|
res.Results = append(res.Results, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Summary.RespondedHostCount = len(counter)
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (svc *Service) GetCampaignReader(ctx context.Context, campaign *fleet.DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) {
|
||||||
|
// Open the channel from which we will receive incoming query results
|
||||||
|
// (probably from the redis pubsub implementation)
|
||||||
|
cancelCtx, cancelFunc := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
readChan, err := svc.resultStore.ReadChannel(cancelCtx, *campaign)
|
||||||
|
if err != nil {
|
||||||
|
cancelFunc()
|
||||||
|
return nil, nil, fmt.Errorf("cannot open read channel for campaign %d ", campaign.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setting status to running will cause the query to be returned to the
|
||||||
|
// targets when they check in for their queries
|
||||||
|
campaign.Status = fleet.QueryRunning
|
||||||
|
if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil {
|
||||||
|
cancelFunc()
|
||||||
|
return nil, nil, errors.Wrap(err, "error saving campaign state")
|
||||||
|
}
|
||||||
|
|
||||||
|
return readChan, cancelFunc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (svc *Service) CompleteCampaign(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error {
|
||||||
|
campaign.Status = fleet.QueryComplete
|
||||||
|
err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "saving distributed campaign after complete")
|
||||||
|
}
|
||||||
|
err = svc.liveQueryStore.StopQuery(strconv.Itoa(int(campaign.ID)))
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "stopping query after after complete")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -44,7 +44,7 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware {
|
|||||||
|
|
||||||
// If authorization was not checked, return a response that will
|
// If authorization was not checked, return a response that will
|
||||||
// marshal to a generic error and log that the check was missed.
|
// marshal to a generic error and log that the check was missed.
|
||||||
if !authzctx.Checked {
|
if !authzctx.Checked() {
|
||||||
return nil, authz.CheckMissingWithResponse(response)
|
return nil, authz.CheckMissingWithResponse(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ func TestAuthzCheck(t *testing.T) {
|
|||||||
check := func(ctx context.Context, req interface{}) (interface{}, error) {
|
check := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
authCtx, ok := authz.FromContext(ctx)
|
authCtx, ok := authz.FromContext(ctx)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
authCtx.Checked = true
|
authCtx.SetChecked()
|
||||||
return struct{}{}, nil
|
return struct{}{}, nil
|
||||||
}
|
}
|
||||||
check = checker.AuthzCheck()(check)
|
check = checker.AuthzCheck()(check)
|
||||||
|
@ -222,31 +222,17 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co
|
|||||||
|
|
||||||
// Open the channel from which we will receive incoming query results
|
// Open the channel from which we will receive incoming query results
|
||||||
// (probably from the redis pubsub implementation)
|
// (probably from the redis pubsub implementation)
|
||||||
cancelCtx, cancelFunc := context.WithCancel(ctx)
|
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
|
||||||
defer cancelFunc()
|
|
||||||
|
|
||||||
readChan, err := svc.resultStore.ReadChannel(cancelCtx, *campaign)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.WriteJSONError(fmt.Sprintf("cannot open read channel for campaign %d ", campaignID))
|
conn.WriteJSONError("error getting campaign reader: " + err.Error())
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setting status to running will cause the query to be returned to the
|
|
||||||
// targets when they check in for their queries
|
|
||||||
campaign.Status = fleet.QueryRunning
|
|
||||||
if err := svc.ds.SaveDistributedQueryCampaign(ctx, campaign); err != nil {
|
|
||||||
conn.WriteJSONError("error saving campaign state")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
// Setting the status to completed stops the query from being sent to
|
// Setting the status to completed stops the query from being sent to
|
||||||
// targets. If this fails, there is a background job that will clean up
|
// targets. If this fails, there is a background job that will clean up
|
||||||
// this campaign.
|
// this campaign.
|
||||||
defer func() {
|
defer svc.CompleteCampaign(ctx, campaign)
|
||||||
campaign.Status = fleet.QueryComplete
|
|
||||||
_ = svc.ds.SaveDistributedQueryCampaign(ctx, campaign)
|
|
||||||
_ = svc.liveQueryStore.StopQuery(strconv.Itoa(int(campaign.ID)))
|
|
||||||
}()
|
|
||||||
|
|
||||||
status := campaignStatus{
|
status := campaignStatus{
|
||||||
Status: campaignStatusPending,
|
Status: campaignStatusPending,
|
||||||
|
Loading…
Reference in New Issue
Block a user