mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
083b781c2c
This needed to be set specifically for the live query websocket dialer, since that doesn't use the HTTP client that the rest of the API client methods use. Co-authored-by: Tomas Touceda <chiiph@gmail.com>
183 lines
5.0 KiB
Go
183 lines
5.0 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"net/http"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
|
|
ws "github.com/fleetdm/fleet/v4/server/websocket"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// LiveQueryResultsHandler provides access to all of the information about an
|
|
// incoming stream of live query results.
|
|
type LiveQueryResultsHandler struct {
|
|
errors chan error
|
|
results chan fleet.DistributedQueryResult
|
|
totals atomic.Value // real type: targetTotals
|
|
status atomic.Value // real type: campaignStatus
|
|
}
|
|
|
|
func NewLiveQueryResultsHandler() *LiveQueryResultsHandler {
|
|
return &LiveQueryResultsHandler{
|
|
errors: make(chan error),
|
|
results: make(chan fleet.DistributedQueryResult),
|
|
}
|
|
}
|
|
|
|
// Errors returns a read channel that includes any errors returned by the
|
|
// server or receiving the results.
|
|
func (h *LiveQueryResultsHandler) Errors() <-chan error {
|
|
return h.errors
|
|
}
|
|
|
|
// Results returns a read channel including any received results
|
|
func (h *LiveQueryResultsHandler) Results() <-chan fleet.DistributedQueryResult {
|
|
return h.results
|
|
}
|
|
|
|
// Totals returns the current metadata of hosts targeted by the query
|
|
func (h *LiveQueryResultsHandler) Totals() *targetTotals {
|
|
t := h.totals.Load()
|
|
if t != nil {
|
|
return t.(*targetTotals)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *LiveQueryResultsHandler) Status() *campaignStatus {
|
|
s := h.status.Load()
|
|
if s != nil {
|
|
return s.(*campaignStatus)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// LiveQuery creates a new live query and begins streaming its results. It is
// shorthand for LiveQueryWithContext with context.Background(), so the
// resulting stream cannot be stopped through context cancellation.
func (c *Client) LiveQuery(query string, labels []string, hosts []string) (*LiveQueryResultsHandler, error) {
	ctx := context.Background()
	return c.LiveQueryWithContext(ctx, query, labels, hosts)
}
|
|
|
|
// LiveQueryWithContext creates a live query campaign that targets the given
// label and host names. It then opens a websocket to the server and streams
// the campaign's messages into the returned handler from a background
// goroutine.
//
// The background goroutine owns the websocket and closes it when it stops.
// It stops when:
//   - ctx is cancelled, or
//   - a read fails for any reason other than a single malformed JSON message.
//
// ctx also bounds the websocket handshake. The handler's channels are
// unbuffered and this function never closes them. Callers must keep
// receiving from both Results and Errors, or cancel ctx, for the stream to
// make progress.
func (c *Client) LiveQueryWithContext(ctx context.Context, query string, labels []string, hosts []string) (*LiveQueryResultsHandler, error) {
	req := createDistributedQueryCampaignByNamesRequest{
		QuerySQL: query,
		Selected: distributedQueryCampaignTargetsByNames{Labels: labels, Hosts: hosts},
	}
	verb, path := "POST", "/api/latest/fleet/queries/run_by_names"
	var responseBody createDistributedQueryCampaignResponse
	err := c.authenticatedRequest(req, verb, path, &responseBody)
	if err != nil {
		// Wrap rather than format with %v so callers can still inspect the
		// underlying error with errors.Is / errors.As.
		return nil, ctxerr.Wrap(ctx, err, "create live query")
	}

	// Same settings as websocket.DefaultDialer, plus the client's TLS
	// verification preference.
	dialer := &websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: 45 * time.Second,
		TLSClientConfig:  &tls.Config{InsecureSkipVerify: c.insecureSkipVerify},
	}

	wssURL := *c.baseURL
	wssURL.Scheme = "wss"
	if flag.Lookup("test.v") != nil {
		// Running under `go test` (which registers test.v); presumably the
		// test servers do not serve TLS, so use a plain websocket.
		wssURL.Scheme = "ws"
	}
	wssURL.Path = c.urlPrefix + "/api/latest/fleet/results/websocket"

	// The websocket dialer does not go through the client's HTTP client, so
	// the configured custom headers must be copied onto the handshake request.
	headers := make(http.Header)
	for k, v := range c.customHeaders {
		headers.Set(k, v)
	}
	conn, _, err := dialer.DialContext(ctx, wssURL.String(), headers)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "upgrade live query result websocket")
	}
	// Cannot defer closing the connection here because it must stay open for
	// the goroutine below; close it manually on the error paths until then.

	err = conn.WriteJSON(ws.JSONMessage{
		Type: "auth",
		Data: map[string]interface{}{"token": c.token},
	})
	if err != nil {
		_ = conn.Close()
		return nil, ctxerr.Wrap(ctx, err, "auth for results")
	}

	err = conn.WriteJSON(ws.JSONMessage{
		Type: "select_campaign",
		Data: map[string]interface{}{"campaign_id": responseBody.Campaign.ID},
	})
	if err != nil {
		_ = conn.Close()
		return nil, ctxerr.Wrap(ctx, err, "selecting results")
	}

	resHandler := NewLiveQueryResultsHandler()
	go func() {
		defer conn.Close()

		// sendErr and sendResult hand a value to the consumer, but give up
		// and report false if ctx is cancelled first. This way a consumer
		// that stops receiving cannot pin this goroutine (and the websocket)
		// forever.
		sendErr := func(err error) bool {
			select {
			case resHandler.errors <- err:
				return true
			case <-ctx.Done():
				return false
			}
		}
		sendResult := func(res fleet.DistributedQueryResult) bool {
			select {
			case resHandler.results <- res:
				return true
			case <-ctx.Done():
				return false
			}
		}

		for {
			msg := struct {
				Type string          `json:"type"`
				Data json.RawMessage `json:"data"`
			}{}

			// Read in a separate goroutine so that cancelling ctx interrupts
			// a blocked read: returning closes conn, which unblocks ReadJSON.
			// The channel is buffered so the reader can always deliver its
			// result and exit, even after this loop has stopped listening.
			doneReadingChan := make(chan error, 1)
			go func() {
				doneReadingChan <- conn.ReadJSON(&msg)
			}()

			var readErr error
			select {
			case <-ctx.Done():
				return
			case readErr = <-doneReadingChan:
			}

			if readErr != nil {
				if !sendErr(ctxerr.Wrap(ctx, readErr, "receive ws message")) {
					return
				}
				// A JSON syntax or type error means only this one message was
				// malformed. The websocket is still usable, so skip the message.
				var syntaxErr *json.SyntaxError
				var typeErr *json.UnmarshalTypeError
				if errors.As(readErr, &syntaxErr) || errors.As(readErr, &typeErr) {
					continue
				}
				// Any other read error is permanent for a gorilla/websocket
				// connection, and repeated reads after one eventually panic.
				// Stop the stream here instead of spinning on the dead socket.
				return
			}

			switch msg.Type {
			case "result":
				var res fleet.DistributedQueryResult
				if err := json.Unmarshal(msg.Data, &res); err != nil {
					// Report the error but do not forward a partially decoded result.
					if !sendErr(ctxerr.Wrap(ctx, err, "unmarshal results")) {
						return
					}
					continue
				}
				if !sendResult(res) {
					return
				}

			case "totals":
				var totals targetTotals
				if err := json.Unmarshal(msg.Data, &totals); err != nil {
					// Keep the last good totals instead of storing a partial value.
					if !sendErr(ctxerr.Wrap(ctx, err, "unmarshal totals")) {
						return
					}
					continue
				}
				resHandler.totals.Store(&totals)

			case "status":
				var status campaignStatus
				if err := json.Unmarshal(msg.Data, &status); err != nil {
					// Keep the last good status instead of storing a partial value.
					if !sendErr(ctxerr.Wrap(ctx, err, "unmarshal status")) {
						return
					}
					continue
				}
				resHandler.status.Store(&status)

			default:
				if !sendErr(ctxerr.Errorf(ctx, "unknown msg type %s", msg.Type)) {
					return
				}
			}
		}
	}()

	return resHandler, nil
}
|