mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 17:05:18 +00:00
0e040cc7b0
📺 Looom: https://www.loom.com/share/1aec4616fa4449e7abac579084aef0ba?sid=0884f742-feb3-48bb-82dc-b7834bc9a6e1 Fixed fleetctl issue where it was creating a new query when running a query by name, as opposed to using the existing saved query. #15630 API change will be in a separate PR: https://github.com/fleetdm/fleet/pull/15673 # Checklist for submitter If some of the following don't apply, delete the relevant line. <!-- Note that API documentation changes are now addressed by the product design team. --> - [x] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [x] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [x] Added/updated tests - [x] Manual QA for all new/changed functionality
186 lines
5.1 KiB
Go
186 lines
5.1 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"net/http"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
|
|
ws "github.com/fleetdm/fleet/v4/server/websocket"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// LiveQueryResultsHandler provides access to all of the information about an
|
|
// incoming stream of live query results.
|
|
type LiveQueryResultsHandler struct {
|
|
errors chan error
|
|
results chan fleet.DistributedQueryResult
|
|
totals atomic.Value // real type: targetTotals
|
|
status atomic.Value // real type: campaignStatus
|
|
}
|
|
|
|
func NewLiveQueryResultsHandler() *LiveQueryResultsHandler {
|
|
return &LiveQueryResultsHandler{
|
|
errors: make(chan error),
|
|
results: make(chan fleet.DistributedQueryResult),
|
|
}
|
|
}
|
|
|
|
// Errors returns a read channel that includes any errors returned by the
|
|
// server or receiving the results.
|
|
func (h *LiveQueryResultsHandler) Errors() <-chan error {
|
|
return h.errors
|
|
}
|
|
|
|
// Results returns a read channel including any received results
|
|
func (h *LiveQueryResultsHandler) Results() <-chan fleet.DistributedQueryResult {
|
|
return h.results
|
|
}
|
|
|
|
// Totals returns the current metadata of hosts targeted by the query
|
|
func (h *LiveQueryResultsHandler) Totals() *targetTotals {
|
|
t := h.totals.Load()
|
|
if t != nil {
|
|
return t.(*targetTotals)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *LiveQueryResultsHandler) Status() *campaignStatus {
|
|
s := h.status.Load()
|
|
if s != nil {
|
|
return s.(*campaignStatus)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// LiveQuery creates a new live query and begins streaming results.
|
|
func (c *Client) LiveQuery(query string, queryID *uint, labels []string, hosts []string) (*LiveQueryResultsHandler, error) {
|
|
return c.LiveQueryWithContext(context.Background(), query, queryID, labels, hosts)
|
|
}
|
|
|
|
func (c *Client) LiveQueryWithContext(
|
|
ctx context.Context, query string, queryID *uint, labels []string, hosts []string,
|
|
) (*LiveQueryResultsHandler, error) {
|
|
req := createDistributedQueryCampaignByNamesRequest{
|
|
QueryID: queryID,
|
|
QuerySQL: query,
|
|
Selected: distributedQueryCampaignTargetsByNames{Labels: labels, Hosts: hosts},
|
|
}
|
|
verb, path := "POST", "/api/latest/fleet/queries/run_by_names"
|
|
var responseBody createDistributedQueryCampaignResponse
|
|
err := c.authenticatedRequest(req, verb, path, &responseBody)
|
|
if err != nil {
|
|
return nil, ctxerr.Errorf(ctx, "create live query: %v", err)
|
|
}
|
|
|
|
// Copy default dialer but skip cert verification if set.
|
|
dialer := &websocket.Dialer{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
HandshakeTimeout: 45 * time.Second,
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: c.insecureSkipVerify},
|
|
}
|
|
|
|
wssURL := *c.baseURL
|
|
wssURL.Scheme = "wss"
|
|
if flag.Lookup("test.v") != nil {
|
|
wssURL.Scheme = "ws"
|
|
}
|
|
wssURL.Path = c.urlPrefix + "/api/latest/fleet/results/websocket"
|
|
// Ensure custom headers (set by config) are added to websocket request
|
|
headers := make(http.Header)
|
|
for k, v := range c.customHeaders {
|
|
headers.Set(k, v)
|
|
}
|
|
conn, _, err := dialer.Dial(wssURL.String(), headers)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "upgrade live query result websocket")
|
|
}
|
|
// Cannot defer connection closing here because we need it to remain
|
|
// open for the goroutine below. Manually close for the couple of error
|
|
// cases below until we enter that goroutine.
|
|
|
|
err = conn.WriteJSON(ws.JSONMessage{
|
|
Type: "auth",
|
|
Data: map[string]interface{}{"token": c.token},
|
|
})
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return nil, ctxerr.Wrap(ctx, err, "auth for results")
|
|
}
|
|
|
|
err = conn.WriteJSON(ws.JSONMessage{
|
|
Type: "select_campaign",
|
|
Data: map[string]interface{}{"campaign_id": responseBody.Campaign.ID},
|
|
})
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return nil, ctxerr.Wrap(ctx, err, "selecting results")
|
|
}
|
|
|
|
resHandler := NewLiveQueryResultsHandler()
|
|
go func() {
|
|
defer conn.Close()
|
|
for {
|
|
msg := struct {
|
|
Type string `json:"type"`
|
|
Data json.RawMessage `json:"data"`
|
|
}{}
|
|
|
|
doneReadingChan := make(chan error)
|
|
|
|
go func() {
|
|
doneReadingChan <- conn.ReadJSON(&msg)
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case err := <-doneReadingChan:
|
|
if err != nil {
|
|
resHandler.errors <- ctxerr.Wrap(ctx, err, "receive ws message")
|
|
if errors.Is(err, websocket.ErrCloseSent) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
close(doneReadingChan)
|
|
|
|
switch msg.Type {
|
|
case "result":
|
|
var res fleet.DistributedQueryResult
|
|
if err := json.Unmarshal(msg.Data, &res); err != nil {
|
|
resHandler.errors <- ctxerr.Wrap(ctx, err, "unmarshal results")
|
|
}
|
|
resHandler.results <- res
|
|
|
|
case "totals":
|
|
var totals targetTotals
|
|
if err := json.Unmarshal(msg.Data, &totals); err != nil {
|
|
resHandler.errors <- ctxerr.Wrap(ctx, err, "unmarshal totals")
|
|
}
|
|
resHandler.totals.Store(&totals)
|
|
|
|
case "status":
|
|
var status campaignStatus
|
|
if err := json.Unmarshal(msg.Data, &status); err != nil {
|
|
resHandler.errors <- ctxerr.Wrap(ctx, err, "unmarshal status")
|
|
}
|
|
resHandler.status.Store(&status)
|
|
|
|
default:
|
|
resHandler.errors <- ctxerr.Errorf(ctx, "unknown msg type %s", msg.Type)
|
|
}
|
|
}
|
|
}()
|
|
|
|
return resHandler, nil
|
|
}
|