fleet/server/service/client_live_query_test.go

121 lines
2.8 KiB
Go

package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestLiveQueryWithContext exercises Client.LiveQueryWithContext against a
// stub HTTP server that serves two endpoints:
//
//   - /api/v1/fleet/queries/run_by_names returns a canned campaign (ID 99).
//   - /api/v1/fleet/results/websocket upgrades to a websocket and, for each
//     frame it reads (other than the two exact-match frames checked below),
//     replies with one "result" message for campaign 99.
//
// The test passes when a value arrives on res.Results() before the 10 second
// context deadline and no non-nil error is delivered on res.Errors() first.
func TestLiveQueryWithContext(t *testing.T) {
	upgrader := websocket.Upgrader{}
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/v1/fleet/queries/run_by_names":
			resp := createDistributedQueryCampaignResponse{
				Campaign: &fleet.DistributedQueryCampaign{
					UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
						CreateTimestamp: fleet.CreateTimestamp{CreatedAt: time.Now()},
						UpdateTimestamp: fleet.UpdateTimestamp{UpdatedAt: time.Now()},
					},
					Metrics: fleet.TargetMetrics{
						TotalHosts:           1,
						OnlineHosts:          1,
						OfflineHosts:         0,
						MissingInActionHosts: 0,
						NewHosts:             0,
					},
					ID:      99,
					QueryID: 42,
					Status:  0,
					UserID:  23,
				},
			}
			err := json.NewEncoder(w).Encode(resp)
			assert.NoError(t, err)
		case "/api/v1/fleet/results/websocket":
			ws, err := upgrader.Upgrade(w, r, nil)
			if !assert.NoError(t, err) {
				// Upgrade already replied with an HTTP error and ws is nil, so
				// there is no connection to close or serve.
				return
			}
			defer ws.Close()
			for {
				// Delays every reply by one second. NOTE(review): presumably
				// simulates result latency; it dominates this test's runtime.
				time.Sleep(1 * time.Second)
				mt, message, err := ws.ReadMessage()
				if err != nil {
					// The client went away (or the connection failed). Stop here
					// instead of looping forever on a dead socket.
					return
				}
				// NOTE(review): these are exact byte comparisons. If the client
				// writes frames with a trailing newline (as json.Encoder does),
				// they never match and every frame, including auth and
				// select_campaign, is answered with a result. Confirm against
				// Client.LiveQueryWithContext.
				if string(message) == `{"type":"auth","data":{"token":"1234"}}` {
					return
				}
				if string(message) == `{"type":"select_campaign","data":{"campaign_id":99}}` {
					return
				}
				result := struct {
					Type string                       `json:"type"`
					Data fleet.DistributedQueryResult `json:"data"`
				}{
					Type: "result",
					Data: fleet.DistributedQueryResult{
						DistributedQueryCampaignID: 99,
						Host: fleet.Host{
							ID:       23,
							Hostname: "somehostaaa",
						},
						Rows: []map[string]string{
							{
								"col1": "aaa",
								"col2": "bbb",
							},
						},
						Error: nil,
					},
				}
				b, err := json.Marshal(result)
				if !assert.NoError(t, err) {
					return
				}
				// Reply with the same message type the client used.
				if err := ws.WriteMessage(mt, b); err != nil {
					return
				}
			}
		}
	}))
	defer ts.Close()

	baseURL, err := url.Parse(ts.URL)
	require.NoError(t, err)
	client := &Client{
		baseURL:            baseURL,
		urlPrefix:          "",
		token:              "1234",
		http:               fleethttp.NewClient(),
		insecureSkipVerify: false,
		writer:             nil,
	}

	ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancelFunc()

	res, err := client.LiveQueryWithContext(ctx, "select 1;", nil, []string{"host1"})
	require.NoError(t, err)

	// Consume on the test goroutine itself. require.* (t.FailNow) may only be
	// called from the goroutine running the test, and keeping gotResults local
	// to this goroutine means it needs no synchronization.
	gotResults := false
wait:
	for !gotResults {
		select {
		case <-res.Results():
			gotResults = true
		case err := <-res.Errors():
			require.NoError(t, err)
		case <-ctx.Done():
			// Deadline reached without a result; the assertion below reports it.
			break wait
		}
	}
	// Tear down the live query once we have what we need, mirroring the
	// original behaviour of cancelling right after the first result.
	cancelFunc()
	assert.True(t, gotResults, "no live query result received before the context deadline")
}