fleet/server/service/client_live_query.go

178 lines
4.8 KiB
Go

package service
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"flag"
"net/http"
"sync/atomic"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
ws "github.com/fleetdm/fleet/v4/server/websocket"
"github.com/gorilla/websocket"
)
// LiveQueryResultsHandler provides access to all of the information about an
// incoming stream of live query results.
//
// Results and errors are delivered over channels (see Results and Errors).
// The most recent totals and campaign status received from the server are
// kept in atomic values and read through Totals and Status.
type LiveQueryResultsHandler struct {
	errors  chan error
	results chan fleet.DistributedQueryResult
	totals  atomic.Value // stored type: *targetTotals
	status  atomic.Value // stored type: *campaignStatus
}
// NewLiveQueryResultsHandler returns a handler whose errors and results
// channels are freshly allocated and unbuffered. No totals or status are
// stored yet.
func NewLiveQueryResultsHandler() *LiveQueryResultsHandler {
	h := new(LiveQueryResultsHandler)
	h.results = make(chan fleet.DistributedQueryResult)
	h.errors = make(chan error)
	return h
}
// Errors exposes, as a receive-only channel, any errors reported by the
// server or encountered while receiving results.
func (h *LiveQueryResultsHandler) Errors() <-chan error {
	var out <-chan error = h.errors
	return out
}
// Results exposes, as a receive-only channel, every query result received
// from the server.
func (h *LiveQueryResultsHandler) Results() <-chan fleet.DistributedQueryResult {
	var out <-chan fleet.DistributedQueryResult = h.results
	return out
}
// Totals returns the most recently received metadata about the hosts targeted
// by the query, or nil if no totals have been stored yet.
func (h *LiveQueryResultsHandler) Totals() *targetTotals {
	v := h.totals.Load()
	if v == nil {
		return nil
	}
	return v.(*targetTotals)
}
// Status returns the most recently received campaign status, or nil if no
// status has been stored yet.
func (h *LiveQueryResultsHandler) Status() *campaignStatus {
	v := h.status.Load()
	if v == nil {
		return nil
	}
	return v.(*campaignStatus)
}
// LiveQuery creates a new live query and begins streaming results. It is a
// convenience wrapper around LiveQueryWithContext using context.Background(),
// which is never cancelled.
func (c *Client) LiveQuery(query string, labels []string, hosts []string) (*LiveQueryResultsHandler, error) {
	ctx := context.Background()
	return c.LiveQueryWithContext(ctx, query, labels, hosts)
}
// LiveQueryWithContext creates a live query campaign for query, targeting the
// given label and host names, and begins streaming its results over a
// websocket.
//
// A background goroutine owns the websocket connection and feeds the returned
// handler's Results and Errors channels; "totals" and "status" messages update
// the values returned by Totals and Status. That goroutine closes the
// connection and exits when ctx is cancelled, or when the connection fails
// with a non-recoverable read error. The channels themselves are never closed,
// so callers should stop consuming them by cancelling ctx or by applying their
// own timeout.
func (c *Client) LiveQueryWithContext(ctx context.Context, query string, labels []string, hosts []string) (*LiveQueryResultsHandler, error) {
	req := createDistributedQueryCampaignByNamesRequest{
		QuerySQL: query,
		Selected: distributedQueryCampaignTargetsByNames{Labels: labels, Hosts: hosts},
	}
	verb, path := "POST", "/api/latest/fleet/queries/run_by_names"
	var responseBody createDistributedQueryCampaignResponse
	if err := c.authenticatedRequest(req, verb, path, &responseBody); err != nil {
		// Wrap (rather than format with %v) so the cause stays inspectable
		// with errors.Is / errors.As.
		return nil, ctxerr.Wrap(ctx, err, "create live query")
	}

	// Copy default dialer but skip cert verification if set.
	dialer := &websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: 45 * time.Second,
		TLSClientConfig:  &tls.Config{InsecureSkipVerify: c.insecureSkipVerify},
	}
	wssURL := *c.baseURL
	wssURL.Scheme = "wss"
	if flag.Lookup("test.v") != nil {
		wssURL.Scheme = "ws"
	}
	wssURL.Path = c.urlPrefix + "/api/latest/fleet/results/websocket"
	// DialContext so that cancelling ctx also aborts the handshake.
	conn, _, err := dialer.DialContext(ctx, wssURL.String(), nil)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "upgrade live query result websocket")
	}
	// Cannot defer connection closing here because we need it to remain
	// open for the goroutine below. Manually close for the couple of error
	// cases below until we enter that goroutine.
	err = conn.WriteJSON(ws.JSONMessage{
		Type: "auth",
		Data: map[string]interface{}{"token": c.token},
	})
	if err != nil {
		_ = conn.Close()
		return nil, ctxerr.Wrap(ctx, err, "auth for results")
	}
	err = conn.WriteJSON(ws.JSONMessage{
		Type: "select_campaign",
		Data: map[string]interface{}{"campaign_id": responseBody.Campaign.ID},
	})
	if err != nil {
		_ = conn.Close()
		return nil, ctxerr.Wrap(ctx, err, "selecting results")
	}

	resHandler := NewLiveQueryResultsHandler()
	go func() {
		defer conn.Close()

		// emitErr delivers err to the consumer unless ctx is cancelled first.
		// Without the ctx case, an abandoned handler would leave this
		// goroutine (and the connection) blocked forever on the send. It
		// reports whether the error was delivered.
		emitErr := func(err error) bool {
			select {
			case resHandler.errors <- err:
				return true
			case <-ctx.Done():
				return false
			}
		}

		for {
			msg := struct {
				Type string          `json:"type"`
				Data json.RawMessage `json:"data"`
			}{}

			// Read in a separate goroutine so ctx cancellation can interrupt a
			// blocking read. The channel has capacity 1 so the reader can always
			// complete its send and exit, even after this goroutine has
			// returned on ctx.Done(); the deferred conn.Close() then unblocks
			// the pending ReadJSON.
			doneReadingChan := make(chan error, 1)
			go func() {
				doneReadingChan <- conn.ReadJSON(&msg)
			}()

			var readErr error
			select {
			case <-ctx.Done():
				return
			case readErr = <-doneReadingChan:
			}

			if readErr != nil {
				if !emitErr(ctxerr.Wrap(ctx, readErr, "receive ws message")) {
					return
				}
				// A malformed JSON payload only spoils this one message, so
				// skip it. Any other read error means the connection is
				// unusable: gorilla/websocket read errors are permanent, and
				// re-reading a failed connection eventually panics. So stop
				// streaming.
				var syntaxErr *json.SyntaxError
				var typeErr *json.UnmarshalTypeError
				if errors.As(readErr, &syntaxErr) || errors.As(readErr, &typeErr) {
					continue
				}
				return
			}

			switch msg.Type {
			case "result":
				var res fleet.DistributedQueryResult
				if err := json.Unmarshal(msg.Data, &res); err != nil {
					// Report and skip; do not publish a zero-value result.
					if !emitErr(ctxerr.Wrap(ctx, err, "unmarshal results")) {
						return
					}
					continue
				}
				select {
				case resHandler.results <- res:
				case <-ctx.Done():
					return
				}
			case "totals":
				var totals targetTotals
				if err := json.Unmarshal(msg.Data, &totals); err != nil {
					// Report and keep the previously stored totals.
					if !emitErr(ctxerr.Wrap(ctx, err, "unmarshal totals")) {
						return
					}
					continue
				}
				resHandler.totals.Store(&totals)
			case "status":
				var status campaignStatus
				if err := json.Unmarshal(msg.Data, &status); err != nil {
					// Report and keep the previously stored status.
					if !emitErr(ctxerr.Wrap(ctx, err, "unmarshal status")) {
						return
					}
					continue
				}
				resHandler.status.Store(&status)
			default:
				if !emitErr(ctxerr.Errorf(ctx, "unknown msg type %s", msg.Type)) {
					return
				}
			}
		}
	}()
	return resHandler, nil
}