mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 00:45:19 +00:00
4cbafafc91
#15703: `wall_time_ms` (not `wall_time`) reported by osquery is now stored in `wall_time` in our DB. As a result, `wall_time` returned by the Fleet API is now in milliseconds rather than seconds. # Checklist for submitter If some of the following don't apply, delete the relevant line. <!-- Note that API documentation changes are now addressed by the product design team. --> - [x] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [x] Added/updated tests - [x] If database migrations are included, checked table schema to confirm autoupdate - For database migrations: - [x] Checked schema for all modified table for columns that will auto-update timestamps during migration. - [x] Confirmed that updating the timestamps is acceptable, and will not cause unwanted side effects. - [x] Manual QA for all new/changed functionality
353 lines
12 KiB
Go
353 lines
12 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"math/rand"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/WatchBeam/clock"
|
|
"github.com/fleetdm/fleet/v4/server/config"
|
|
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
|
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/live_query/live_query_mock"
|
|
"github.com/fleetdm/fleet/v4/server/mock"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/fleetdm/fleet/v4/server/pubsub"
|
|
ws "github.com/fleetdm/fleet/v4/server/websocket"
|
|
kitlog "github.com/go-kit/kit/log"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestStreamCampaignResultsClosesReditOnWSClose(t *testing.T) {
|
|
t.Skip("Seems to be a bit problematic in CI")
|
|
|
|
store := pubsub.SetupRedisForTest(t, false, false)
|
|
|
|
mockClock := clock.NewMockClock()
|
|
ds := new(mock.Store)
|
|
lq := live_query_mock.New(t)
|
|
svc, ctx := newTestServiceWithClock(t, ds, store, lq, mockClock)
|
|
|
|
campaign := &fleet.DistributedQueryCampaign{ID: 42}
|
|
|
|
ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) {
|
|
return map[string]string{}, nil
|
|
}
|
|
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
|
return &fleet.AppConfig{}, nil
|
|
}
|
|
ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) {
|
|
return query, nil
|
|
}
|
|
ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) {
|
|
return camp, nil
|
|
}
|
|
ds.NewDistributedQueryCampaignTargetFunc = func(ctx context.Context, target *fleet.DistributedQueryCampaignTarget) (*fleet.DistributedQueryCampaignTarget, error) {
|
|
return target, nil
|
|
}
|
|
ds.HostIDsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets) ([]uint, error) {
|
|
return []uint{1}, nil
|
|
}
|
|
ds.CountHostsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) {
|
|
return fleet.TargetMetrics{TotalHosts: 1}, nil
|
|
}
|
|
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error {
|
|
return nil
|
|
}
|
|
ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) {
|
|
return &fleet.Session{
|
|
CreateTimestamp: fleet.CreateTimestamp{CreatedAt: time.Now()},
|
|
ID: 42,
|
|
AccessedAt: time.Now(),
|
|
UserID: 999,
|
|
Key: "asd",
|
|
}, nil
|
|
}
|
|
|
|
host := &fleet.Host{ID: 1, Platform: "windows"}
|
|
|
|
lq.On("QueriesForHost", uint(1)).Return(
|
|
map[string]string{
|
|
strconv.Itoa(int(campaign.ID)): "select * from time",
|
|
},
|
|
nil,
|
|
)
|
|
lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(nil)
|
|
lq.On("RunQuery", "0", "select year, month, day, hour, minutes, seconds from time", []uint{1}).Return(nil)
|
|
viewerCtx := viewer.NewContext(ctx, viewer.Viewer{
|
|
User: &fleet.User{
|
|
ID: 0,
|
|
GlobalRole: ptr.String(fleet.RoleAdmin),
|
|
},
|
|
})
|
|
q := "select year, month, day, hour, minutes, seconds from time"
|
|
_, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}})
|
|
require.NoError(t, err)
|
|
|
|
pathHandler := makeStreamDistributedQueryCampaignResultsHandler(config.TestConfig().Server, svc, kitlog.NewNopLogger())
|
|
s := httptest.NewServer(pathHandler("/api/latest/fleet/results/"))
|
|
defer s.Close()
|
|
// Convert http://127.0.0.1 to ws://127.0.0.1
|
|
u := "ws" + strings.TrimPrefix(s.URL, "http") + "/api/latest/fleet/results/websocket"
|
|
|
|
// Connect to the server
|
|
dialer := &websocket.Dialer{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
HandshakeTimeout: 45 * time.Second,
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
}
|
|
|
|
conn, _, err := dialer.Dial(u, nil)
|
|
if err != nil {
|
|
t.Fatalf("%v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
err = conn.WriteJSON(ws.JSONMessage{
|
|
Type: "auth",
|
|
Data: map[string]interface{}{"token": "asd"},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = conn.WriteJSON(ws.JSONMessage{
|
|
Type: "select_campaign",
|
|
Data: map[string]interface{}{"campaign_id": campaign.ID},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
ds.MarkSessionAccessedFunc = func(context.Context, *fleet.Session) error {
|
|
return nil
|
|
}
|
|
ds.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
|
return &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}, nil
|
|
}
|
|
ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) {
|
|
return campaign, nil
|
|
}
|
|
ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) error {
|
|
return nil
|
|
}
|
|
ds.DistributedQueryCampaignTargetIDsFunc = func(ctx context.Context, id uint) (targets *fleet.HostTargets, err error) {
|
|
return &fleet.HostTargets{HostIDs: []uint{1}}, nil
|
|
}
|
|
ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) {
|
|
return &fleet.Query{}, nil
|
|
}
|
|
|
|
/*****************************************************************************************/
|
|
/* THE ACTUAL TEST BEGINS HERE */
|
|
/*****************************************************************************************/
|
|
prevActiveConn := 0
|
|
for prevActiveConn < 3 {
|
|
time.Sleep(2 * time.Second)
|
|
|
|
for _, s := range store.Pool().Stats() {
|
|
prevActiveConn = s.ActiveCount
|
|
}
|
|
}
|
|
|
|
conn.Close()
|
|
time.Sleep(10 * time.Second)
|
|
|
|
newActiveConn := prevActiveConn
|
|
for _, s := range store.Pool().Stats() {
|
|
newActiveConn = s.ActiveCount
|
|
}
|
|
require.Equal(t, prevActiveConn-1, newActiveConn)
|
|
}
|
|
|
|
func TestUpdateStats(t *testing.T) {
|
|
ds := mysql.CreateMySQLDS(t)
|
|
defer mysql.TruncateTables(t, ds)
|
|
s, ctx := newTestService(t, ds, nil, nil)
|
|
svc := s.(validationMiddleware).Service.(*Service)
|
|
|
|
tracker := statsTracker{}
|
|
// NOOP cases
|
|
svc.updateStats(ctx, 0, svc.logger, nil, false)
|
|
svc.updateStats(ctx, 0, svc.logger, &tracker, false)
|
|
|
|
// More NOOP cases
|
|
tracker.saveStats = true
|
|
svc.updateStats(ctx, 0, svc.logger, nil, false)
|
|
assert.True(t, tracker.saveStats)
|
|
svc.updateStats(ctx, 0, svc.logger, nil, true)
|
|
assert.True(t, tracker.saveStats)
|
|
|
|
// Populate a batch of data
|
|
hostIDs := []uint{}
|
|
queryID := uint(1)
|
|
myHostID := uint(10000)
|
|
myWallTime := uint64(5)
|
|
myUserTime := uint64(6)
|
|
mySystemTime := uint64(7)
|
|
myMemory := uint64(8)
|
|
myOutputSize := uint64(9)
|
|
tracker.stats = append(
|
|
tracker.stats, statsToSave{
|
|
hostID: myHostID,
|
|
Stats: &fleet.Stats{
|
|
WallTimeMs: myWallTime,
|
|
UserTime: myUserTime,
|
|
SystemTime: mySystemTime,
|
|
Memory: myMemory,
|
|
},
|
|
outputSize: myOutputSize,
|
|
},
|
|
)
|
|
hostIDs = append(hostIDs, myHostID)
|
|
|
|
for i := uint(1); i < statsBatchSize; i++ {
|
|
tracker.stats = append(
|
|
tracker.stats, statsToSave{
|
|
hostID: i,
|
|
Stats: &fleet.Stats{
|
|
WallTimeMs: rand.Uint64(),
|
|
UserTime: rand.Uint64(),
|
|
SystemTime: rand.Uint64(),
|
|
Memory: rand.Uint64(),
|
|
},
|
|
outputSize: rand.Uint64(),
|
|
},
|
|
)
|
|
hostIDs = append(hostIDs, i)
|
|
}
|
|
tracker.saveStats = true
|
|
svc.updateStats(ctx, queryID, svc.logger, &tracker, false)
|
|
assert.True(t, tracker.saveStats)
|
|
assert.Equal(t, 0, len(tracker.stats))
|
|
assert.True(t, tracker.aggregationNeeded)
|
|
|
|
// Get the stats from DB and make sure they match
|
|
currentStats, err := svc.ds.GetLiveQueryStats(ctx, queryID, hostIDs)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, statsBatchSize, len(currentStats))
|
|
currentStats, err = svc.ds.GetLiveQueryStats(ctx, queryID, []uint{myHostID})
|
|
assert.NoError(t, err)
|
|
require.Equal(t, 1, len(currentStats))
|
|
myStat := currentStats[0]
|
|
assert.Equal(t, myHostID, myStat.HostID)
|
|
assert.Equal(t, uint64(1), myStat.Executions)
|
|
assert.Equal(t, myWallTime, myStat.WallTime)
|
|
assert.Equal(t, myUserTime, myStat.UserTime)
|
|
assert.Equal(t, mySystemTime, myStat.SystemTime)
|
|
assert.Equal(t, myMemory, myStat.AverageMemory)
|
|
assert.Equal(t, myOutputSize, myStat.OutputSize)
|
|
|
|
// Aggregate stats
|
|
svc.updateStats(ctx, queryID, svc.logger, &tracker, true)
|
|
aggStats, err := mysql.GetAggregatedStats(ctx, svc.ds.(*mysql.Datastore), fleet.AggregatedStatsTypeScheduledQuery, queryID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, statsBatchSize, int(*aggStats.TotalExecutions))
|
|
// Sanity checks. Complete testing done in aggregated_stats_test.go
|
|
assert.True(t, *aggStats.SystemTimeP50 > 0)
|
|
assert.True(t, *aggStats.SystemTimeP95 > 0)
|
|
assert.True(t, *aggStats.UserTimeP50 > 0)
|
|
assert.True(t, *aggStats.UserTimeP95 > 0)
|
|
|
|
// Write new stats (update) for the same query/hosts
|
|
myNewWallTime := uint64(15)
|
|
myNewUserTime := uint64(16)
|
|
myNewSystemTime := uint64(17)
|
|
myNewMemory := uint64(18)
|
|
myNewOutputSize := uint64(19)
|
|
tracker.stats = append(
|
|
tracker.stats, statsToSave{
|
|
hostID: myHostID,
|
|
Stats: &fleet.Stats{
|
|
WallTimeMs: myNewWallTime,
|
|
UserTime: myNewUserTime,
|
|
SystemTime: myNewSystemTime,
|
|
Memory: myNewMemory,
|
|
},
|
|
outputSize: myNewOutputSize,
|
|
},
|
|
)
|
|
|
|
for i := uint(1); i < statsBatchSize; i++ {
|
|
tracker.stats = append(
|
|
tracker.stats, statsToSave{
|
|
hostID: i,
|
|
Stats: &fleet.Stats{
|
|
WallTimeMs: rand.Uint64(),
|
|
UserTime: rand.Uint64(),
|
|
SystemTime: rand.Uint64(),
|
|
Memory: rand.Uint64(),
|
|
},
|
|
outputSize: rand.Uint64(),
|
|
},
|
|
)
|
|
}
|
|
tracker.saveStats = true
|
|
svc.updateStats(ctx, queryID, svc.logger, &tracker, true)
|
|
assert.True(t, tracker.saveStats)
|
|
assert.Equal(t, 0, len(tracker.stats))
|
|
assert.False(t, tracker.aggregationNeeded)
|
|
|
|
// Check that stats were updated
|
|
currentStats, err = svc.ds.GetLiveQueryStats(ctx, queryID, []uint{myHostID})
|
|
assert.NoError(t, err)
|
|
require.Equal(t, 1, len(currentStats))
|
|
myStat = currentStats[0]
|
|
assert.Equal(t, myHostID, myStat.HostID)
|
|
assert.Equal(t, uint64(2), myStat.Executions)
|
|
assert.Equal(t, myWallTime+myNewWallTime, myStat.WallTime)
|
|
assert.Equal(t, myUserTime+myNewUserTime, myStat.UserTime)
|
|
assert.Equal(t, mySystemTime+myNewSystemTime, myStat.SystemTime)
|
|
assert.Equal(t, (myMemory+myNewMemory)/2, myStat.AverageMemory)
|
|
assert.Equal(t, myOutputSize+myNewOutputSize, myStat.OutputSize)
|
|
|
|
// Check that aggregated stats were updated
|
|
aggStats, err = mysql.GetAggregatedStats(ctx, svc.ds.(*mysql.Datastore), fleet.AggregatedStatsTypeScheduledQuery, queryID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, statsBatchSize*2, int(*aggStats.TotalExecutions))
|
|
// Sanity checks. Complete testing done in aggregated_stats_test.go
|
|
assert.True(t, *aggStats.SystemTimeP50 > 0)
|
|
assert.True(t, *aggStats.SystemTimeP95 > 0)
|
|
assert.True(t, *aggStats.UserTimeP50 > 0)
|
|
assert.True(t, *aggStats.UserTimeP95 > 0)
|
|
}
|
|
|
|
func TestCalculateOutputSize(t *testing.T) {
|
|
createResult := func() *fleet.DistributedQueryResult {
|
|
result := fleet.DistributedQueryResult{}
|
|
result.Rows = append(result.Rows, nil)
|
|
result.Rows = append(result.Rows, map[string]string{})
|
|
result.Rows = append(result.Rows, map[string]string{"a": "b", "a1": "b1"})
|
|
result.Rows = append(result.Rows, map[string]string{"c": "d"})
|
|
result.Stats = &fleet.Stats{}
|
|
return &result
|
|
}
|
|
t.Run(
|
|
"output size save disabled", func(t *testing.T) {
|
|
tracker := statsTracker{saveStats: false}
|
|
size := calculateOutputSize(&tracker, createResult())
|
|
require.Equal(t, uint64(0), size)
|
|
},
|
|
)
|
|
t.Run(
|
|
"output size empty", func(t *testing.T) {
|
|
tracker := statsTracker{saveStats: true}
|
|
size := calculateOutputSize(&tracker, &fleet.DistributedQueryResult{})
|
|
require.Equal(t, uint64(0), size)
|
|
},
|
|
)
|
|
t.Run(
|
|
"output size calculate", func(t *testing.T) {
|
|
tracker := statsTracker{saveStats: true}
|
|
size := calculateOutputSize(&tracker, createResult())
|
|
expected := uint64(8) // manually calculated
|
|
require.Equal(t, expected, size)
|
|
},
|
|
)
|
|
}
|