Performance stats for live queries (#15440)

📺 Loom explaining the PR:
https://www.loom.com/share/8c6ec0c362014910931f183d68307525?sid=7de51e6f-a59d-4d8f-b06b-f703f1692f17

#467 
Live queries now collect performance stats. Stats are collected for saved
queries present in the Queries tab.
- After running a live query, users will see updated stats in the Queries tab.
- Query stats on the Host Details page are no longer cleared after a
host/agent reboot.
- Query stats are now deleted when the query is deleted.
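
For context on the transport: the agent reports these stats alongside results in its `distributed/write` request, keyed by query name, matching the `fleet.Stats` JSON fields added in this PR. A minimal sketch of the payload, assuming Fleet's live query key prefix and an illustrative campaign ID:

```json
{
  "node_key": "<host node key>",
  "queries": { "fleet_distributed_query_42": [{ "col1": "a", "col2": "b" }] },
  "statuses": { "fleet_distributed_query_42": 0 },
  "messages": {},
  "stats": {
    "fleet_distributed_query_42": {
      "wall_time_ms": 10,
      "user_time": 20,
      "system_time": 30,
      "memory": 40
    }
  }
}
```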

# Checklist for submitter
- [x] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [x] Added/updated tests
- [x] Manual QA for all new/changed functionality
Victor Lyuboslavsky committed on 2023-12-13 14:46:59 -06:00 (committed by GitHub)
commit 835eedae42 (parent 2520fc2179)
31 changed files with 1120 additions and 183 deletions

---

@ -0,0 +1,4 @@
Live queries now collect performance stats. Stats are collected for saved queries present in the Queries tab.
- After running a live query, users will see updated stats in the Queries tab.
- Query stats on the Host Details page are no longer cleared after a host/agent reboot.
- Query stats are now deleted when the query is deleted.

---

@ -92,6 +92,18 @@ func TestLiveQuery(t *testing.T) {
ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) {
return &fleet.Query{}, nil
}
ds.IsSavedQueryFunc = func(ctx context.Context, queryID uint) (bool, error) {
return true, nil
}
ds.GetLiveQueryStatsFunc = func(ctx context.Context, queryID uint, hostIDs []uint) ([]*fleet.LiveQueryStats, error) {
return nil, nil
}
ds.UpdateLiveQueryStatsFunc = func(ctx context.Context, queryID uint, stats []*fleet.LiveQueryStats) error {
return nil
}
ds.CalculateAggregatedPerfStatsPercentilesFunc = func(ctx context.Context, aggregate fleet.AggregatedStatsType, queryID uint) error {
return nil
}
go func() {
time.Sleep(2 * time.Second)
@ -104,6 +116,12 @@ func TestLiveQuery(t *testing.T) {
Hostname: "somehostname",
DisplayName: "somehostname",
},
Stats: &fleet.Stats{
WallTimeMs: 10,
UserTime: 20,
SystemTime: 30,
Memory: 40,
},
},
))
}()
@ -111,4 +129,7 @@ func TestLiveQuery(t *testing.T) {
expected := `{"host":"somehostname","rows":[{"bing":"fds","host_display_name":"somehostname","host_hostname":"somehostname"}]}
`
assert.Equal(t, expected, runAppForTest(t, []string{"query", "--hosts", "1234", "--query", "select 42, * from time"}))
assert.True(t, ds.GetLiveQueryStatsFuncInvoked)
assert.True(t, ds.UpdateLiveQueryStatsFuncInvoked)
assert.True(t, ds.CalculateAggregatedPerfStatsPercentilesFuncInvoked)
}

---

@ -1292,31 +1292,39 @@ func (a *agent) diskEncryptionLinux() []map[string]string {
}
}
func (a *agent) runLiveQuery(query string) (results []map[string]string, status *fleet.OsqueryStatus, message *string) {
func (a *agent) runLiveQuery(query string) (results []map[string]string, status *fleet.OsqueryStatus, message *string, stats *fleet.Stats) {
if a.liveQueryFailProb > 0.0 && rand.Float64() <= a.liveQueryFailProb {
ss := fleet.OsqueryStatus(1)
return []map[string]string{}, &ss, ptr.String("live query failed with error foobar")
return []map[string]string{}, &ss, ptr.String("live query failed with error foobar"), nil
}
ss := fleet.OsqueryStatus(0)
if a.liveQueryNoResultsProb > 0.0 && rand.Float64() <= a.liveQueryNoResultsProb {
return []map[string]string{}, &ss, nil
return []map[string]string{}, &ss, nil, nil
}
return []map[string]string{{
"admindir": "/var/lib/dpkg",
"arch": "amd64",
"maintainer": "foobar",
"name": "netconf",
"priority": "optional",
"revision": "",
"section": "default",
"size": "112594",
"source": "",
"status": "install ok installed",
"version": "20230224000000",
}}, &ss, nil
"admindir": "/var/lib/dpkg",
"arch": "amd64",
"maintainer": "foobar",
"name": "netconf",
"priority": "optional",
"revision": "",
"section": "default",
"size": "112594",
"source": "",
"status": "install ok installed",
"version": "20230224000000",
},
}, &ss, nil, &fleet.Stats{
WallTimeMs: uint64(rand.Intn(1000) * 1000),
UserTime: uint64(rand.Intn(1000)),
SystemTime: uint64(rand.Intn(1000)),
Memory: uint64(rand.Intn(1000)),
}
}
func (a *agent) processQuery(name, query string) (handled bool, results []map[string]string, status *fleet.OsqueryStatus, message *string) {
func (a *agent) processQuery(name, query string) (
handled bool, results []map[string]string, status *fleet.OsqueryStatus, message *string, stats *fleet.Stats,
) {
const (
hostPolicyQueryPrefix = "fleet_policy_query_"
hostDetailQueryPrefix = "fleet_detail_query_"
@ -1327,60 +1335,60 @@ func (a *agent) processQuery(name, query string) (handled bool, results []map[st
switch {
case strings.HasPrefix(name, liveQueryPrefix):
results, status, message = a.runLiveQuery(query)
return true, results, status, message
results, status, message, stats = a.runLiveQuery(query)
return true, results, status, message, stats
case strings.HasPrefix(name, hostPolicyQueryPrefix):
return true, a.runPolicy(query), &statusOK, nil
return true, a.runPolicy(query), &statusOK, nil, nil
case name == hostDetailQueryPrefix+"scheduled_query_stats":
return true, a.randomQueryStats(), &statusOK, nil
return true, a.randomQueryStats(), &statusOK, nil, nil
case name == hostDetailQueryPrefix+"mdm":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = a.mdmMac()
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"mdm_windows":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = a.mdmWindows()
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"munki_info":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = a.munkiInfo()
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"google_chrome_profiles":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = a.googleChromeProfiles()
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"battery":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = a.batteries()
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"users":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = a.hostUsers()
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"software_macos":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = a.softwareMacOS()
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"software_windows":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = windowsSoftware
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"software_linux":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
@ -1389,37 +1397,37 @@ func (a *agent) processQuery(name, query string) (handled bool, results []map[st
results = ubuntuSoftware
}
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"disk_space_unix" || name == hostDetailQueryPrefix+"disk_space_windows":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = a.diskSpace()
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case strings.HasPrefix(name, hostDetailQueryPrefix+"disk_encryption_linux"):
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = a.diskEncryptionLinux()
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"disk_encryption_darwin" ||
name == hostDetailQueryPrefix+"disk_encryption_windows":
ss := fleet.OsqueryStatus(rand.Intn(2))
if ss == fleet.StatusOK {
results = a.diskEncryption()
}
return true, results, &ss, nil
return true, results, &ss, nil, nil
case name == hostDetailQueryPrefix+"kubequery_info" && a.os != "kubequery":
// Real osquery running on hosts would return no results if it was not
// running kubequery (due to discovery query). Returning true here so that
// the caller knows it is handled, will not try to return lorem-ipsum-style
// results.
return true, nil, &statusNotOK, nil
return true, nil, &statusNotOK, nil, nil
default:
// Look for results in the template file.
if t := a.templates.Lookup(name); t == nil {
return false, nil, nil, nil
return false, nil, nil, nil, nil
}
var ni bytes.Buffer
err := a.templates.ExecuteTemplate(&ni, name, a)
@ -1431,7 +1439,7 @@ func (a *agent) processQuery(name, query string) (handled bool, results []map[st
panic(err)
}
return true, results, &statusOK, nil
return true, results, &statusOK, nil, nil
}
}
@ -1440,10 +1448,11 @@ func (a *agent) DistributedWrite(queries map[string]string) {
Results: make(fleet.OsqueryDistributedQueryResults),
Statuses: make(map[string]fleet.OsqueryStatus),
Messages: make(map[string]string),
Stats: make(map[string]*fleet.Stats),
}
r.NodeKey = a.nodeKey
for name, query := range queries {
handled, results, status, message := a.processQuery(name, query)
handled, results, status, message, stats := a.processQuery(name, query)
if !handled {
// If osquery-perf does not handle the incoming query,
// always return status OK and the default query result.
@ -1459,6 +1468,9 @@ func (a *agent) DistributedWrite(queries map[string]string) {
if message != nil {
r.Messages[name] = *message
}
if stats != nil {
r.Stats[name] = stats
}
}
}
body, err := json.Marshal(r)

---

@ -7,13 +7,11 @@ import (
"fmt"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/jmoiron/sqlx"
)
type aggregatedStatsType string
const (
aggregatedStatsTypeScheduledQuery = "scheduled_query"
aggregatedStatsTypeMunkiVersions = "munki_versions"
aggregatedStatsTypeMunkiIssues = "munki_issues"
aggregatedStatsTypeOSVersions = "os_versions"
@ -29,37 +27,38 @@ const (
// a slightly simpler version but that adds the rownum before sorting.
const scheduledQueryPercentileQuery = `
SELECT
coalesce((t1.%s / t1.executions), 0)
FROM (
SELECT (@rownum := @rownum + 1) AS row_number_value, mm.* FROM (
SELECT d.scheduled_query_id, d.%s, d.executions
FROM scheduled_query_stats d
WHERE d.scheduled_query_id=?
ORDER BY (d.%s / d.executions) ASC
) AS mm
) AS t1,
(SELECT @rownum := 0) AS r,
(
SELECT count(*) AS total_rows
FROM scheduled_query_stats d
WHERE d.scheduled_query_id=?
) AS t2
WHERE t1.row_number_value = floor(total_rows * %s) + 1;`
SELECT COALESCE((t1.%[1]s_total / t1.executions_total), 0)
FROM (SELECT (@rownum := @rownum + 1) AS row_number_value, sum1.*
FROM (SELECT SUM(d.%[1]s) as %[1]s_total, SUM(d.executions) as executions_total
FROM scheduled_query_stats d
WHERE d.scheduled_query_id = ?
AND d.executions > 0
GROUP BY d.host_id) as sum1
ORDER BY (%[1]s_total / executions_total)) AS t1,
(SELECT @rownum := 0) AS r,
(SELECT COUNT(*) AS total_rows
FROM (SELECT COUNT(*)
FROM scheduled_query_stats d
WHERE d.scheduled_query_id = ?
AND d.executions > 0
GROUP BY d.host_id) as sum2) AS t2
WHERE t1.row_number_value = FLOOR(total_rows * %[2]s) + 1`
const (
scheduledQueryTotalExecutions = `SELECT coalesce(sum(executions), 0) FROM scheduled_query_stats WHERE scheduled_query_id=?`
)
func getPercentileQuery(aggregate aggregatedStatsType, time string, percentile string) string {
func getPercentileQuery(aggregate fleet.AggregatedStatsType, time string, percentile string) string {
switch aggregate {
case aggregatedStatsTypeScheduledQuery:
return fmt.Sprintf(scheduledQueryPercentileQuery, time, time, time, percentile)
case fleet.AggregatedStatsTypeScheduledQuery:
return fmt.Sprintf(scheduledQueryPercentileQuery, time, percentile)
}
return ""
}
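
For illustration, calling `getPercentileQuery(fleet.AggregatedStatsTypeScheduledQuery, "user_time", "0.95")` expands the indexed `%[1]s`/`%[2]s` verbs into roughly the SQL below: each host's lifetime average (`SUM(user_time) / SUM(executions)`) is computed first, the per-host averages are sorted, and the p95 is the value at row `FLOOR(total_rows * 0.95) + 1`:

```sql
SELECT COALESCE((t1.user_time_total / t1.executions_total), 0)
FROM (SELECT (@rownum := @rownum + 1) AS row_number_value, sum1.*
      FROM (SELECT SUM(d.user_time) AS user_time_total, SUM(d.executions) AS executions_total
            FROM scheduled_query_stats d
            WHERE d.scheduled_query_id = ? AND d.executions > 0
            GROUP BY d.host_id) AS sum1
      ORDER BY (user_time_total / executions_total)) AS t1,
     (SELECT @rownum := 0) AS r,
     (SELECT COUNT(*) AS total_rows
      FROM (SELECT COUNT(*)
            FROM scheduled_query_stats d
            WHERE d.scheduled_query_id = ? AND d.executions > 0
            GROUP BY d.host_id) AS sum2) AS t2
WHERE t1.row_number_value = FLOOR(total_rows * 0.95) + 1
```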
func setP50AndP95Map(ctx context.Context, tx sqlx.QueryerContext, aggregate aggregatedStatsType, time string, id uint, statsMap map[string]interface{}) error {
func setP50AndP95Map(
ctx context.Context, tx sqlx.QueryerContext, aggregate fleet.AggregatedStatsType, time string, id uint, statsMap map[string]interface{},
) error {
var p50, p95 float64
err := sqlx.GetContext(ctx, tx, &p50, getPercentileQuery(aggregate, time, "0.5"), id, id)
@ -83,9 +82,10 @@ func setP50AndP95Map(ctx context.Context, tx sqlx.QueryerContext, aggregate aggr
}
func (ds *Datastore) UpdateQueryAggregatedStats(ctx context.Context) error {
err := walkIdsInTable(ctx, ds.reader(ctx), "queries", func(id uint) error {
return calculatePercentiles(ctx, ds.writer(ctx), aggregatedStatsTypeScheduledQuery, id)
})
err := walkIdsInTable(
ctx, ds.reader(ctx), "queries", func(queryID uint) error {
return ds.CalculateAggregatedPerfStatsPercentiles(ctx, fleet.AggregatedStatsTypeScheduledQuery, queryID)
})
if err != nil {
return ctxerr.Wrap(ctx, err, "looping through query ids")
}
@ -93,22 +93,24 @@ func (ds *Datastore) UpdateQueryAggregatedStats(ctx context.Context) error {
return nil
}
func calculatePercentiles(ctx context.Context, tx sqlx.ExtContext, aggregate aggregatedStatsType, id uint) error {
// CalculateAggregatedPerfStatsPercentiles calculates the aggregated user/system time performance statistics for the given query.
func (ds *Datastore) CalculateAggregatedPerfStatsPercentiles(ctx context.Context, aggregate fleet.AggregatedStatsType, queryID uint) error {
tx := ds.writer(ctx)
var totalExecutions int
statsMap := make(map[string]interface{})
// many queries is not ideal, but getting both values and totals in the same query was a bit more complicated
// so I went for the simpler approach first, we can optimize later
if err := setP50AndP95Map(ctx, tx, aggregate, "user_time", id, statsMap); err != nil {
if err := setP50AndP95Map(ctx, tx, aggregate, "user_time", queryID, statsMap); err != nil {
return err
}
if err := setP50AndP95Map(ctx, tx, aggregate, "system_time", id, statsMap); err != nil {
if err := setP50AndP95Map(ctx, tx, aggregate, "system_time", queryID, statsMap); err != nil {
return err
}
err := sqlx.GetContext(ctx, tx, &totalExecutions, getTotalExecutionsQuery(aggregate), id)
err := sqlx.GetContext(ctx, tx, &totalExecutions, getTotalExecutionsQuery(aggregate), queryID)
if err != nil {
return ctxerr.Wrapf(ctx, err, "getting total executions for %s %d", aggregate, id)
return ctxerr.Wrapf(ctx, err, "getting total executions for %s %d", aggregate, queryID)
}
statsMap["total_executions"] = totalExecutions
@ -126,17 +128,17 @@ func calculatePercentiles(ctx context.Context, tx sqlx.ExtContext, aggregate agg
VALUES (?, ?, 0, ?)
ON DUPLICATE KEY UPDATE json_value=VALUES(json_value)
`,
id, aggregate, statsJson,
queryID, aggregate, statsJson,
)
if err != nil {
return ctxerr.Wrapf(ctx, err, "inserting stats for %s id %d", aggregate, id)
return ctxerr.Wrapf(ctx, err, "inserting stats for %s id %d", aggregate, queryID)
}
return nil
}
func getTotalExecutionsQuery(aggregate aggregatedStatsType) string {
func getTotalExecutionsQuery(aggregate fleet.AggregatedStatsType) string {
switch aggregate {
case aggregatedStatsTypeScheduledQuery:
case fleet.AggregatedStatsTypeScheduledQuery:
return scheduledQueryTotalExecutions
}
return ""

---

@ -15,7 +15,15 @@ import (
)
func slowStats(t *testing.T, ds *Datastore, id uint, percentile int, column string) float64 {
queriesSQL := fmt.Sprintf(`SELECT d.%s / d.executions FROM scheduled_query_stats d JOIN queries q ON (d.scheduled_query_id=q.id) WHERE q.id=? ORDER BY (d.%s / d.executions) ASC`, column, column)
queriesSQL := fmt.Sprintf(
`
SELECT SUM(d.%[1]s) / SUM(d.executions)
FROM scheduled_query_stats d
JOIN queries q ON (d.scheduled_query_id=q.id)
WHERE q.id=? AND d.executions > 0
GROUP BY d.host_id
ORDER BY (SUM(d.%[1]s) / SUM(d.executions))`, column,
)
rows, err := ds.writer(context.Background()).Queryx(queriesSQL, id)
require.NoError(t, err)
defer rows.Close()
@ -56,19 +64,27 @@ func TestAggregatedStats(t *testing.T) {
_, err := ds.writer(context.Background()).Exec(`INSERT INTO scheduled_queries(query_id, name, query_name) VALUES (?,?,?)`, rand.Intn(queryCount)+1, fmt.Sprint(i), fmt.Sprint(i))
require.NoError(t, err)
}
insertScheduledQuerySQL := `INSERT IGNORE INTO scheduled_query_stats(host_id, scheduled_query_id, system_time, user_time, executions) VALUES %s`
insertScheduledQuerySQL := `INSERT IGNORE INTO scheduled_query_stats(host_id, scheduled_query_id, system_time, user_time, executions, query_type) VALUES %s`
scheduledQueryStatsCount := 100 // 1000000
for i := 0; i < scheduledQueryStatsCount; i++ {
if len(args) > batchSize {
values := strings.TrimSuffix(strings.Repeat("(?,?,?,?,?),", len(args)/5), ",")
values := strings.TrimSuffix(strings.Repeat("(?,?,?,?,?,?),", len(args)/6), ",")
_, err := ds.writer(context.Background()).Exec(fmt.Sprintf(insertScheduledQuerySQL, values), args...)
require.NoError(t, err)
args = []interface{}{}
}
args = append(args, rand.Intn(hostCount)+1, rand.Intn(queryCount)+1, rand.Intn(10000)+100, rand.Intn(10000)+100, rand.Intn(10000)+100)
// Occasionally set 0 executions
executions := rand.Intn(10000) + 100
if rand.Intn(100) < 5 {
executions = 0
}
args = append(
args, rand.Intn(hostCount)+1, rand.Intn(queryCount)+1, rand.Intn(10000)+100, rand.Intn(10000)+100, executions,
rand.Intn(2),
)
}
if len(args) > 0 {
values := strings.TrimSuffix(strings.Repeat("(?,?,?,?,?),", len(args)/5), ",")
values := strings.TrimSuffix(strings.Repeat("(?,?,?,?,?,?),", len(args)/6), ",")
_, err := ds.writer(context.Background()).Exec(fmt.Sprintf(insertScheduledQuerySQL, values), args...)
require.NoError(t, err)
}
@ -87,10 +103,10 @@ func TestAggregatedStats(t *testing.T) {
testcases := []struct {
table string
aggregate aggregatedStatsType
aggregate fleet.AggregatedStatsType
aggFunc func(ctx context.Context) error
}{
{"queries", aggregatedStatsTypeScheduledQuery, ds.UpdateQueryAggregatedStats},
{"queries", fleet.AggregatedStatsTypeScheduledQuery, ds.UpdateQueryAggregatedStats},
}
for _, tt := range testcases {
t.Run(tt.table, func(t *testing.T) {

---

@ -328,8 +328,20 @@ func loadHostPackStatsDB(ctx context.Context, db sqlx.QueryerContext, hid uint,
goqu.I("queries").As("q"),
goqu.On(goqu.I("sq.query_id").Eq(goqu.I("q.id"))),
).LeftJoin(
dialect.From("scheduled_query_stats").As("sqs").Where(
goqu.I("host_id").Eq(hid),
goqu.L(
`
(SELECT
stats.scheduled_query_id,
CAST(AVG(stats.average_memory) AS UNSIGNED) AS average_memory,
MAX(stats.denylisted) AS denylisted,
SUM(stats.executions) AS executions,
MAX(stats.last_executed) AS last_executed,
SUM(stats.output_size) AS output_size,
SUM(stats.system_time) AS system_time,
SUM(stats.user_time) AS user_time,
SUM(stats.wall_time) AS wall_time
FROM scheduled_query_stats stats WHERE stats.host_id = ? GROUP BY stats.scheduled_query_id) as sqs
`, hid,
),
goqu.On(goqu.I("sqs.scheduled_query_id").Eq(goqu.I("sq.query_id"))),
).Where(
@ -383,17 +395,28 @@ func loadHostScheduledQueryStatsDB(ctx context.Context, db sqlx.QueryerContext,
q.discard_data,
q.automations_enabled,
MAX(qr.last_fetched) as last_fetched,
COALESCE(MAX(sqs.average_memory), 0) AS average_memory,
COALESCE(MAX(sqs.denylisted), false) AS denylisted,
COALESCE(MAX(sqs.executions), 0) AS executions,
COALESCE(MAX(sqs.last_executed), TIMESTAMP(?)) AS last_executed,
COALESCE(MAX(sqs.output_size), 0) AS output_size,
COALESCE(MAX(sqs.system_time), 0) AS system_time,
COALESCE(MAX(sqs.user_time), 0) AS user_time,
COALESCE(MAX(sqs.wall_time), 0) AS wall_time
COALESCE(sqs.average_memory, 0) AS average_memory,
COALESCE(sqs.denylisted, false) AS denylisted,
COALESCE(sqs.executions, 0) AS executions,
COALESCE(sqs.last_executed, TIMESTAMP(?)) AS last_executed,
COALESCE(sqs.output_size, 0) AS output_size,
COALESCE(sqs.system_time, 0) AS system_time,
COALESCE(sqs.user_time, 0) AS user_time,
COALESCE(sqs.wall_time, 0) AS wall_time
FROM
queries q
LEFT JOIN scheduled_query_stats sqs ON (q.id = sqs.scheduled_query_id AND sqs.host_id = ?)
LEFT JOIN
(SELECT
stats.scheduled_query_id,
CAST(AVG(stats.average_memory) AS UNSIGNED) AS average_memory,
MAX(stats.denylisted) AS denylisted,
SUM(stats.executions) AS executions,
MAX(stats.last_executed) AS last_executed,
SUM(stats.output_size) AS output_size,
SUM(stats.system_time) AS system_time,
SUM(stats.user_time) AS user_time,
SUM(stats.wall_time) AS wall_time
FROM scheduled_query_stats stats WHERE stats.host_id = ? GROUP BY stats.scheduled_query_id) as sqs ON (q.id = sqs.scheduled_query_id)
LEFT JOIN query_results qr ON (q.id = qr.query_id AND qr.host_id = ?)
WHERE
(q.platform = '' OR q.platform IS NULL OR FIND_IN_SET(?, q.platform) != 0)
@ -417,7 +440,6 @@ func loadHostScheduledQueryStatsDB(ctx context.Context, db sqlx.QueryerContext,
teamID_,
hid,
}
var stats []fleet.QueryStats
if err := sqlx.SelectContext(ctx, db, &stats, sqlQuery, args...); err != nil {
return nil, ctxerr.Wrap(ctx, err, "load query stats")
@ -3685,11 +3707,11 @@ func (ds *Datastore) AggregatedMDMStatus(ctx context.Context, teamID *uint, plat
return status, statusJson.UpdatedAt, nil
}
func platformKey(key aggregatedStatsType, platform string) aggregatedStatsType {
func platformKey(key fleet.AggregatedStatsType, platform string) fleet.AggregatedStatsType {
if platform == "" {
return key
}
return key + "_" + aggregatedStatsType(platform)
return key + "_" + fleet.AggregatedStatsType(platform)
}
func (ds *Datastore) AggregatedMDMSolutions(ctx context.Context, teamID *uint, platform string) ([]fleet.AggregatedMDMSolutions, time.Time, error) {

---

@ -0,0 +1,43 @@
package tables
import (
"database/sql"
"fmt"
)
func init() {
MigrationClient.AddMigration(Up_20231212161121, Down_20231212161121)
}
func Up_20231212161121(tx *sql.Tx) error {
stmt := `
ALTER TABLE scheduled_query_stats
ADD COLUMN query_type TINYINT NOT NULL DEFAULT 0;
`
if _, err := tx.Exec(stmt); err != nil {
return fmt.Errorf("add query_type to scheduled_query_stats: %w", err)
}
// Add query_type to primary key
stmt = `
ALTER TABLE scheduled_query_stats
DROP PRIMARY KEY,
ADD PRIMARY KEY (host_id, scheduled_query_id, query_type);
`
if _, err := tx.Exec(stmt); err != nil {
return fmt.Errorf("add query_type to scheduled_query_stats primary key: %w", err)
}
return nil
}
func Down_20231212161121(*sql.Tx) error {
/*
ALTER TABLE scheduled_query_stats
DROP PRIMARY KEY,
ADD PRIMARY KEY (host_id, scheduled_query_id);
ALTER TABLE scheduled_query_stats
DROP COLUMN query_type;
*/
return nil
}
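
To make the effect of the new composite key concrete, here is an assumed illustration (column values are arbitrary; the table's NOT NULL columns are taken from the migration test below). The same (host_id, scheduled_query_id) pair may now hold one row per `query_type` (0 = scheduled, 1 = live, matching the `statsScheduledQueryType`/`statsLiveQueryType` constants added later in this diff), while a duplicate of the same type is rejected:

```sql
-- Both rows succeed: same (host_id, scheduled_query_id), different query_type.
INSERT INTO scheduled_query_stats
  (host_id, scheduled_query_id, average_memory, denylisted, executions,
   schedule_interval, output_size, system_time, user_time, wall_time, query_type)
VALUES (1, 2, 3, 0, 5, 60, 7, 8, 9, 10, 0),  -- scheduled stats row
       (1, 2, 3, 0, 5, 60, 7, 8, 9, 10, 1);  -- live stats row for the same host/query

-- Fails with a duplicate-key error: (1, 2, 1) already exists in the primary key.
INSERT INTO scheduled_query_stats
  (host_id, scheduled_query_id, average_memory, denylisted, executions,
   schedule_interval, output_size, system_time, user_time, wall_time, query_type)
VALUES (1, 2, 30, 0, 50, 60, 70, 80, 90, 100, 1);
```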

---

@ -0,0 +1,72 @@
package tables
import (
"fmt"
"github.com/stretchr/testify/require"
"testing"
)
func TestUp_20231212161121(t *testing.T) {
db := applyUpToPrev(t)
insertStmt := `
INSERT INTO scheduled_query_stats (host_id, scheduled_query_id, average_memory, denylisted, executions, schedule_interval, output_size, system_time, user_time, wall_time) VALUES
(%d,%d,%d,%d,%d,%d,%d,%d,%d,%d);
`
setupStmt := fmt.Sprintf(insertStmt, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
_, err := db.Exec(setupStmt)
require.NoError(t, err)
// Apply current migration.
applyNext(t, db)
stmt := `
SELECT host_id, query_type FROM scheduled_query_stats WHERE host_id = 1;
`
rows, err := db.Query(stmt)
require.NoError(t, rows.Err())
require.NoError(t, err)
defer rows.Close()
count := 0
for rows.Next() {
count += 1
var hostId, queryType int
err := rows.Scan(&hostId, &queryType)
require.NoError(t, err)
require.Equal(t, 1, hostId)
require.Equal(t, 0, queryType)
}
require.Equal(t, 1, count)
insertStmt = `
INSERT INTO scheduled_query_stats (host_id, scheduled_query_id, average_memory, denylisted, executions, schedule_interval, output_size, system_time, user_time, wall_time, query_type) VALUES
(%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d);
`
stmt = fmt.Sprintf(insertStmt, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1)
_, err = db.Exec(stmt)
require.NoError(t, err)
stmt = `
SELECT host_id, query_type FROM scheduled_query_stats WHERE host_id = 1 AND query_type = 1;
`
rows, err = db.Query(stmt)
require.NoError(t, rows.Err())
require.NoError(t, err)
defer rows.Close()
count = 0
for rows.Next() {
count += 1
var hostId, queryType int
err := rows.Scan(&hostId, &queryType)
require.NoError(t, err)
require.Equal(t, 1, hostId)
require.Equal(t, 1, queryType)
}
require.Equal(t, 1, count)
// Testing unique constraint -- expect error due to duplicate entry for primary key
stmt = fmt.Sprintf(insertStmt, 1, 2, 30, 40, 50, 60, 70, 80, 90, 100, 1)
_, err = db.Exec(stmt)
require.Error(t, err)
}

---

@ -4,10 +4,16 @@ import (
"context"
"database/sql"
"fmt"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/log/level"
"github.com/jmoiron/sqlx"
"strings"
)
const (
statsScheduledQueryType = iota
statsLiveQueryType
)
func (ds *Datastore) ApplyQueries(ctx context.Context, authorID uint, queries []*fleet.Query, queriesToDiscardResults map[uint]struct{}) error {
@ -342,6 +348,16 @@ func (ds *Datastore) DeleteQuery(ctx context.Context, teamID *uint, name string)
return ctxerr.Wrap(ctx, notFound("queries").WithName(name))
}
// Delete any associated stats asynchronously.
ctxWithoutCancel := context.WithoutCancel(ctx)
go func() {
stmt := "DELETE FROM scheduled_query_stats WHERE scheduled_query_id = ?"
_, err := ds.writer(ctxWithoutCancel).ExecContext(ctxWithoutCancel, stmt, queryID)
if err != nil {
level.Error(ds.logger).Log("msg", "error deleting query stats", "err", err)
}
}()
// Opportunistically delete associated query_results.
//
// TODO(lucas): We should run this on a transaction but we found
@ -361,6 +377,21 @@ func (ds *Datastore) DeleteQueries(ctx context.Context, ids []uint) (uint, error
return deleted, err
}
// Delete any associated stats asynchronously.
ctxWithoutCancel := context.WithoutCancel(ctx)
go func() {
stmt := "DELETE FROM scheduled_query_stats WHERE scheduled_query_id IN (?)"
stmt, args, err := sqlx.In(stmt, ids)
if err != nil {
level.Error(ds.logger).Log("msg", "error creating delete query statement", "err", err)
return
}
_, err = ds.writer(ctxWithoutCancel).ExecContext(ctxWithoutCancel, stmt, args...)
if err != nil {
level.Error(ds.logger).Log("msg", "error deleting multiple query stats", "err", err)
}
}()
// Opportunistically delete associated query_results.
//
// TODO(lucas): We should run this on a transaction but we found
@ -407,7 +438,7 @@ func (ds *Datastore) Query(ctx context.Context, id uint) (*fleet.Query, error) {
WHERE q.id = ?
`
query := &fleet.Query{}
if err := sqlx.GetContext(ctx, ds.reader(ctx), query, sqlQuery, false, aggregatedStatsTypeScheduledQuery, id); err != nil {
if err := sqlx.GetContext(ctx, ds.reader(ctx), query, sqlQuery, false, fleet.AggregatedStatsTypeScheduledQuery, id); err != nil {
if err == sql.ErrNoRows {
return nil, ctxerr.Wrap(ctx, notFound("Query").WithID(id))
}
@ -455,7 +486,7 @@ func (ds *Datastore) ListQueries(ctx context.Context, opt fleet.ListQueryOptions
LEFT JOIN aggregated_stats ag ON (ag.id = q.id AND ag.global_stats = ? AND ag.type = ?)
`
args := []interface{}{false, aggregatedStatsTypeScheduledQuery}
args := []interface{}{false, fleet.AggregatedStatsTypeScheduledQuery}
whereClauses := "WHERE saved = true"
if opt.OnlyObserverCanRun {
@ -606,3 +637,60 @@ func (ds *Datastore) CleanupGlobalDiscardQueryResults(ctx context.Context) error
return nil
}
// IsSavedQuery returns true if the given query is a saved query.
func (ds *Datastore) IsSavedQuery(ctx context.Context, queryID uint) (bool, error) {
stmt := `
SELECT saved
FROM queries
WHERE id = ?
`
var result bool
err := sqlx.GetContext(ctx, ds.reader(ctx), &result, stmt, queryID)
return result, err
}
// GetLiveQueryStats returns the live query stats for the given query and hosts.
func (ds *Datastore) GetLiveQueryStats(ctx context.Context, queryID uint, hostIDs []uint) ([]*fleet.LiveQueryStats, error) {
stmt, args, err := sqlx.In(
`SELECT host_id, average_memory, executions, system_time, user_time, wall_time, output_size
FROM scheduled_query_stats
WHERE host_id IN (?) AND scheduled_query_id = ? AND query_type = ?
`, hostIDs, queryID, statsLiveQueryType,
)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "building get live query stats stmt")
}
results := []*fleet.LiveQueryStats{}
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, stmt, args...); err != nil {
return nil, ctxerr.Wrap(ctx, err, "get live query stats")
}
return results, nil
}
// UpdateLiveQueryStats writes new stats as a batch
func (ds *Datastore) UpdateLiveQueryStats(ctx context.Context, queryID uint, stats []*fleet.LiveQueryStats) error {
if len(stats) == 0 {
return nil
}
// Bulk insert/update
const valueStr = "(?,?,?,?,?,?,?,?,?,?,?),"
stmt := "REPLACE INTO scheduled_query_stats (scheduled_query_id, host_id, query_type, executions, average_memory, system_time, user_time, wall_time, output_size, denylisted, schedule_interval) VALUES " +
strings.Repeat(valueStr, len(stats))
stmt = strings.TrimSuffix(stmt, ",")
var args []interface{}
for _, s := range stats {
args = append(
args, queryID, s.HostID, statsLiveQueryType, s.Executions, s.AverageMemory, s.SystemTime, s.UserTime, s.WallTime, s.OutputSize,
0, 0,
)
}
_, err := ds.writer(ctx).ExecContext(ctx, stmt, args...)
if err != nil {
return ctxerr.Wrap(ctx, err, "update live query stats")
}
return nil
}
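
Putting the new datastore methods together, a minimal usage sketch (host IDs and stat values are illustrative; this mirrors the datastore tests later in the diff):

```go
import (
	"context"

	"github.com/fleetdm/fleet/v4/server/fleet"
)

func recordLiveQueryStats(ctx context.Context, ds fleet.Datastore, queryID uint) error {
	// Stats are only persisted for saved queries.
	saved, err := ds.IsSavedQuery(ctx, queryID)
	if err != nil || !saved {
		return err
	}
	// Batch-write the stats reported by two hosts for this live query.
	err = ds.UpdateLiveQueryStats(ctx, queryID, []*fleet.LiveQueryStats{
		{HostID: 10, Executions: 1, UserTime: 20, SystemTime: 30, WallTime: 10},
		{HostID: 20, Executions: 1, UserTime: 5, SystemTime: 7, WallTime: 3},
	})
	if err != nil {
		return err
	}
	// Recompute the p50/p95 aggregates surfaced in the Queries tab.
	return ds.CalculateAggregatedPerfStatsPercentiles(ctx, fleet.AggregatedStatsTypeScheduledQuery, queryID)
}
```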

---

@ -3,6 +3,7 @@ package mysql
import (
"context"
"fmt"
"math"
"sort"
"testing"
@ -33,6 +34,7 @@ func TestQueries(t *testing.T) {
{"ListQueriesFiltersByTeamID", testListQueriesFiltersByTeamID},
{"ListQueriesFiltersByIsScheduled", testListQueriesFiltersByIsScheduled},
{"ListScheduledQueriesForAgents", testListScheduledQueriesForAgents},
{"IsSavedQuery", testIsSavedQuery},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@ -162,6 +164,7 @@ func testQueriesApply(t *testing.T, ds *Datastore) {
func testQueriesDelete(t *testing.T, ds *Datastore) {
user := test.NewUser(t, ds, "Zach", "zwass@fleet.co", true)
hostID := uint(1)
query := &fleet.Query{
Name: "foo",
Query: "bar",
@ -172,6 +175,14 @@ func testQueriesDelete(t *testing.T, ds *Datastore) {
require.NoError(t, err)
require.NotNil(t, query)
assert.NotEqual(t, query.ID, 0)
err = ds.UpdateLiveQueryStats(
context.Background(), query.ID, []*fleet.LiveQueryStats{
&fleet.LiveQueryStats{
HostID: hostID,
},
},
)
require.NoError(t, err)
err = ds.DeleteQuery(context.Background(), query.TeamID, query.Name)
require.NoError(t, err)
@ -180,6 +191,13 @@ func testQueriesDelete(t *testing.T, ds *Datastore) {
_, err = ds.Query(context.Background(), query.ID)
require.Error(t, err)
require.True(t, fleet.IsNotFound(err))
// Ensure stats were deleted.
// The actual delete occurs asynchronously, but enough time should have passed
// given the above DB access to ensure the original query completed.
stats, err := ds.GetLiveQueryStats(context.Background(), query.ID, []uint{hostID})
require.NoError(t, err)
require.Equal(t, 0, len(stats))
}
func testQueriesGetByName(t *testing.T, ds *Datastore) {
@ -230,6 +248,28 @@ func testQueriesDeleteMany(t *testing.T, ds *Datastore) {
require.Nil(t, err)
assert.Len(t, queries, 4)
// Add query stats
hostIDs := []uint{10, 20}
err = ds.UpdateLiveQueryStats(
context.Background(), q1.ID, []*fleet.LiveQueryStats{
{
HostID: hostIDs[0],
},
{
HostID: hostIDs[1],
},
},
)
require.NoError(t, err)
err = ds.UpdateLiveQueryStats(
context.Background(), q3.ID, []*fleet.LiveQueryStats{
{
HostID: hostIDs[0],
},
},
)
require.NoError(t, err)
deleted, err := ds.DeleteQueries(context.Background(), []uint{q1.ID, q3.ID})
require.Nil(t, err)
assert.Equal(t, uint(2), deleted)
@ -237,6 +277,15 @@ func testQueriesDeleteMany(t *testing.T, ds *Datastore) {
queries, err = ds.ListQueries(context.Background(), fleet.ListQueryOptions{})
require.Nil(t, err)
assert.Len(t, queries, 2)
// Ensure stats were deleted.
// The actual delete occurs asynchronously, but enough time should have passed
// given the above DB access to ensure the original query completed.
stats, err := ds.GetLiveQueryStats(context.Background(), q1.ID, hostIDs)
require.NoError(t, err)
require.Equal(t, 0, len(stats))
stats, err = ds.GetLiveQueryStats(context.Background(), q3.ID, hostIDs)
require.NoError(t, err)
require.Equal(t, 0, len(stats))
deleted, err = ds.DeleteQueries(context.Background(), []uint{q2.ID})
require.Nil(t, err)
@ -336,7 +385,8 @@ func testQueriesList(t *testing.T, ds *Datastore) {
_, err = ds.writer(context.Background()).Exec(
`INSERT INTO aggregated_stats(id,global_stats,type,json_value) VALUES (?,?,?,?)`,
idWithAgg, false, aggregatedStatsTypeScheduledQuery, `{"user_time_p50": 10.5777, "user_time_p95": 111.7308, "system_time_p50": 0.6936, "system_time_p95": 95.8654, "total_executions": 5038}`,
idWithAgg, false, fleet.AggregatedStatsTypeScheduledQuery,
`{"user_time_p50": 10.5777, "user_time_p95": 111.7308, "system_time_p50": 0.6936, "system_time_p95": 95.8654, "total_executions": 5038}`,
)
require.NoError(t, err)
@ -972,3 +1022,40 @@ func testListScheduledQueriesForAgents(t *testing.T, ds *Datastore) {
test.QueryElementsMatch(t, result, []*fleet.Query{q14, q15, q16, q17}, i)
}
}
func testIsSavedQuery(t *testing.T, ds *Datastore) {
user := test.NewUser(t, ds, "Zach", "zwass@fleet.co", true)
// NOT saved query
query := &fleet.Query{
Name: "foo",
Query: "bar",
AuthorID: &user.ID,
Logging: fleet.LoggingSnapshot,
Saved: false,
}
query, err := ds.NewQuery(context.Background(), query)
require.NoError(t, err)
isSaved, err := ds.IsSavedQuery(context.Background(), query.ID)
require.NoError(t, err)
assert.False(t, isSaved)
// Saved query
query = &fleet.Query{
Name: "foo2",
Query: "bar",
AuthorID: &user.ID,
Logging: fleet.LoggingSnapshot,
Saved: true,
}
query, err = ds.NewQuery(context.Background(), query)
require.NoError(t, err)
isSaved, err = ds.IsSavedQuery(context.Background(), query.ID)
require.NoError(t, err)
assert.True(t, isSaved)
// error case
_, err = ds.IsSavedQuery(context.Background(), math.MaxUint)
require.Error(t, err)
}

---

@ -41,7 +41,7 @@ func (ds *Datastore) ListScheduledQueriesInPackWithStats(ctx context.Context, id
LEFT JOIN aggregated_stats ag ON (ag.id = sq.id AND ag.global_stats = ? AND ag.type = ?)
WHERE sq.pack_id = ?
`
params := []interface{}{false, aggregatedStatsTypeScheduledQuery, id}
params := []interface{}{false, fleet.AggregatedStatsTypeScheduledQuery, id}
query, params = appendListOptionsWithCursorToSQL(query, params, &opts)
results := []*fleet.ScheduledQuery{}

---

@ -107,7 +107,8 @@ func testScheduledQueriesListInPackWithStats(t *testing.T, ds *Datastore) {
_, err = ds.writer(context.Background()).Exec(
`INSERT INTO aggregated_stats(id,global_stats,type,json_value) VALUES (?,?,?,?)`,
idWithAgg, false, aggregatedStatsTypeScheduledQuery, `{"user_time_p50": 10.5777, "user_time_p95": 111.7308, "system_time_p50": 0.6936, "system_time_p95": 95.8654, "total_executions": 5038}`,
idWithAgg, false, fleet.AggregatedStatsTypeScheduledQuery,
`{"user_time_p50": 10.5777, "user_time_p95": 111.7308, "system_time_p50": 0.6936, "system_time_p95": 95.8654, "total_executions": 5038}`,
)
require.NoError(t, err)

File diff suppressed because one or more lines are too long

---

@ -16,6 +16,7 @@ import (
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/log"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
@ -436,3 +437,19 @@ func InsertWindowsProfileForTest(t *testing.T, ds *Datastore, teamID uint) strin
})
return profUUID
}
// GetAggregatedStats retrieves aggregated stats for the given query
func GetAggregatedStats(ctx context.Context, ds *Datastore, aggregate fleet.AggregatedStatsType, id uint) (fleet.AggregatedStats, error) {
result := fleet.AggregatedStats{}
stmt := `
SELECT
JSON_EXTRACT(json_value, '$.user_time_p50') as user_time_p50,
JSON_EXTRACT(json_value, '$.user_time_p95') as user_time_p95,
JSON_EXTRACT(json_value, '$.system_time_p50') as system_time_p50,
JSON_EXTRACT(json_value, '$.system_time_p95') as system_time_p95,
JSON_EXTRACT(json_value, '$.total_executions') as total_executions
FROM aggregated_stats WHERE id=? AND type=?
`
err := sqlx.GetContext(ctx, ds.reader(ctx), &result, stmt, id, aggregate)
return result, err
}

---

@ -0,0 +1,7 @@
package fleet
const (
AggregatedStatsTypeScheduledQuery = "scheduled_query"
)
type AggregatedStatsType string

---

@ -41,8 +41,9 @@ type DistributedQueryResult struct {
// DistributedQueryCampaignID is the unique ID of the live query campaign.
DistributedQueryCampaignID uint `json:"distributed_query_execution_id"`
// Host holds the host's data from where the query result comes from.
Host ResultHostData `json:"host"`
Rows []map[string]string `json:"rows"`
Host ResultHostData `json:"host"`
Rows []map[string]string `json:"rows"`
Stats *Stats `json:"stats"`
// Error contains any error reported by osquery when running the query.
// Note we can't use the error interface here because something
// implementing that interface may not (un)marshal properly

---

@ -93,6 +93,14 @@ type Datastore interface {
ObserverCanRunQuery(ctx context.Context, queryID uint) (bool, error)
// CleanupGlobalDiscardQueryResults deletes all cached query results. Used in cleanups_then_aggregation cron.
CleanupGlobalDiscardQueryResults(ctx context.Context) error
// IsSavedQuery returns true if the given query is a saved query.
IsSavedQuery(ctx context.Context, queryID uint) (bool, error)
// GetLiveQueryStats returns the live query stats for the given query and hosts.
GetLiveQueryStats(ctx context.Context, queryID uint, hostIDs []uint) ([]*LiveQueryStats, error)
// UpdateLiveQueryStats writes new live query stats as a single operation.
UpdateLiveQueryStats(ctx context.Context, queryID uint, stats []*LiveQueryStats) error
// CalculateAggregatedPerfStatsPercentiles calculates the aggregated user/system time performance statistics for the given query.
CalculateAggregatedPerfStatsPercentiles(ctx context.Context, aggregate AggregatedStatsType, queryID uint) error
///////////////////////////////////////////////////////////////////////////////
// CampaignStore defines the distributed query campaign related datastore methods

---

@ -8,6 +8,14 @@ type OsqueryDistributedQueryResults map[string][]map[string]string
// failure)
type OsqueryStatus int
// Stats contains the performance statistics about the execution of a specific osquery query.
type Stats struct {
WallTimeMs uint64 `json:"wall_time_ms"`
UserTime uint64 `json:"user_time"`
SystemTime uint64 `json:"system_time"`
Memory uint64 `json:"memory"`
}
const (
// StatusOK is the success code returned by osquery
StatusOK OsqueryStatus = 0

---

@ -155,6 +155,17 @@ func (q *Query) Copy() *Query {
return &clone
}
// LiveQueryStats holds a host's accumulated live query performance stats, backed by the scheduled_query_stats table.
type LiveQueryStats struct {
HostID uint `db:"host_id"`
Executions uint64 `db:"executions"`
AverageMemory uint64 `db:"average_memory"`
SystemTime uint64 `db:"system_time"`
UserTime uint64 `db:"user_time"`
WallTime uint64 `db:"wall_time"`
OutputSize uint64 `db:"output_size"`
}
var (
LoggingSnapshot = "snapshot"
LoggingDifferential = "differential"

---

@ -60,6 +60,7 @@ type OsqueryService interface {
results OsqueryDistributedQueryResults,
statuses map[string]OsqueryStatus,
messages map[string]string,
stats map[string]*Stats,
) (err error)
SubmitStatusLogs(ctx context.Context, logs []json.RawMessage) (err error)
SubmitResultLogs(ctx context.Context, logs []json.RawMessage) (err error)

---

@ -124,15 +124,24 @@ func (svc *launcherWrapper) PublishResults(ctx context.Context, nodeKey string,
osqueryResults := make(fleet.OsqueryDistributedQueryResults, len(results))
statuses := make(map[string]fleet.OsqueryStatus, len(results))
stats := make(map[string]*fleet.Stats, len(results))
for _, result := range results {
statuses[result.QueryName] = fleet.OsqueryStatus(result.Status)
osqueryResults[result.QueryName] = result.Rows
if result.QueryStats != nil {
stats[result.QueryName] = &fleet.Stats{
WallTimeMs: uint64(result.QueryStats.WallTimeMs),
UserTime: uint64(result.QueryStats.UserTime),
SystemTime: uint64(result.QueryStats.SystemTime),
Memory: uint64(result.QueryStats.Memory),
}
}
}
// TODO can Launcher expose the error messages?
messages := make(map[string]string)
err = svc.tls.SubmitDistributedQueryResults(newCtx, osqueryResults, statuses, messages)
err = svc.tls.SubmitDistributedQueryResults(newCtx, osqueryResults, statuses, messages, stats)
return "", "", false, ctxerr.Wrap(ctx, err, "submit launcher results")
}

---

@ -68,6 +68,7 @@ func TestLauncherPublishResults(t *testing.T) {
results fleet.OsqueryDistributedQueryResults,
statuses map[string]fleet.OsqueryStatus,
messages map[string]string,
stats map[string]*fleet.Stats,
) (err error) {
assert.Equal(t, results["query"][0], result)
return nil
@ -151,6 +152,7 @@ func newTLSService(t *testing.T) *mock.TLSService {
results fleet.OsqueryDistributedQueryResults,
statuses map[string]fleet.OsqueryStatus,
messages map[string]string,
stats map[string]*fleet.Stats,
) (err error) {
return
},

---

@ -78,6 +78,14 @@ type ObserverCanRunQueryFunc func(ctx context.Context, queryID uint) (bool, erro
type CleanupGlobalDiscardQueryResultsFunc func(ctx context.Context) error
type IsSavedQueryFunc func(ctx context.Context, queryID uint) (bool, error)
type GetLiveQueryStatsFunc func(ctx context.Context, queryID uint, hostIDs []uint) ([]*fleet.LiveQueryStats, error)
type UpdateLiveQueryStatsFunc func(ctx context.Context, queryID uint, stats []*fleet.LiveQueryStats) error
type CalculateAggregatedPerfStatsPercentilesFunc func(ctx context.Context, aggregate fleet.AggregatedStatsType, queryID uint) error
type NewDistributedQueryCampaignFunc func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error)
type DistributedQueryCampaignFunc func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error)
@ -867,6 +875,18 @@ type DataStore struct {
CleanupGlobalDiscardQueryResultsFunc CleanupGlobalDiscardQueryResultsFunc
CleanupGlobalDiscardQueryResultsFuncInvoked bool
IsSavedQueryFunc IsSavedQueryFunc
IsSavedQueryFuncInvoked bool
GetLiveQueryStatsFunc GetLiveQueryStatsFunc
GetLiveQueryStatsFuncInvoked bool
UpdateLiveQueryStatsFunc UpdateLiveQueryStatsFunc
UpdateLiveQueryStatsFuncInvoked bool
CalculateAggregatedPerfStatsPercentilesFunc CalculateAggregatedPerfStatsPercentilesFunc
CalculateAggregatedPerfStatsPercentilesFuncInvoked bool
NewDistributedQueryCampaignFunc NewDistributedQueryCampaignFunc
NewDistributedQueryCampaignFuncInvoked bool
@ -2127,6 +2147,34 @@ func (s *DataStore) CleanupGlobalDiscardQueryResults(ctx context.Context) error
return s.CleanupGlobalDiscardQueryResultsFunc(ctx)
}
func (s *DataStore) IsSavedQuery(ctx context.Context, queryID uint) (bool, error) {
s.mu.Lock()
s.IsSavedQueryFuncInvoked = true
s.mu.Unlock()
return s.IsSavedQueryFunc(ctx, queryID)
}
func (s *DataStore) GetLiveQueryStats(ctx context.Context, queryID uint, hostIDs []uint) ([]*fleet.LiveQueryStats, error) {
s.mu.Lock()
s.GetLiveQueryStatsFuncInvoked = true
s.mu.Unlock()
return s.GetLiveQueryStatsFunc(ctx, queryID, hostIDs)
}
func (s *DataStore) UpdateLiveQueryStats(ctx context.Context, queryID uint, stats []*fleet.LiveQueryStats) error {
s.mu.Lock()
s.UpdateLiveQueryStatsFuncInvoked = true
s.mu.Unlock()
return s.UpdateLiveQueryStatsFunc(ctx, queryID, stats)
}
func (s *DataStore) CalculateAggregatedPerfStatsPercentiles(ctx context.Context, aggregate fleet.AggregatedStatsType, queryID uint) error {
s.mu.Lock()
s.CalculateAggregatedPerfStatsPercentilesFuncInvoked = true
s.mu.Unlock()
return s.CalculateAggregatedPerfStatsPercentilesFunc(ctx, aggregate, queryID)
}
func (s *DataStore) NewDistributedQueryCampaign(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) {
s.mu.Lock()
s.NewDistributedQueryCampaignFuncInvoked = true

---

@ -2,8 +2,10 @@ package service
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"sort"
"sync"
@ -11,6 +13,7 @@ import (
"time"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/live_query/live_query_mock"
"github.com/fleetdm/fleet/v4/server/ptr"
@ -80,67 +83,100 @@ func (s *liveQueriesTestSuite) TearDownTest() {
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
t := s.T()
test := func(savedQuery bool, hasStats bool) {
t := s.T()
host := s.hosts[0]
host := s.hosts[0]
q1, err := s.ds.NewQuery(context.Background(), &fleet.Query{
Query: "select 1 from osquery;",
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
})
require.NoError(t, err)
q1, err := s.ds.NewQuery(
context.Background(), &fleet.Query{
Query: "select 1 from osquery;",
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
Saved: savedQuery,
},
)
require.NoError(t, err)
s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil)
s.lq.On("StopQuery", mock.Anything).Return(nil)
s.lq.On("QueriesForHost", uint(1)).Return(map[string]string{fmt.Sprint(q1.ID): "select 1 from osquery;"}, nil)
s.lq.On("QueryCompletedByHost", mock.Anything, mock.Anything).Return(nil)
s.lq.On("RunQuery", mock.Anything, "select 1 from osquery;", []uint{host.ID}).Return(nil)
s.lq.On("StopQuery", mock.Anything).Return(nil)
liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()
// Give the above call a couple of seconds to create the campaign
time.Sleep(2 * time.Second)
cid := getCIDForQ(s, q1)
var stats *fleet.Stats
if hasStats {
stats = &fleet.Stats{
UserTime: uint64(1),
SystemTime: uint64(2),
}
}
distributedReq := submitDistributedQueryResultsRequestShim{
NodeKey: *host.NodeKey,
Results: map[string]json.RawMessage{
hostDistributedQueryPrefix + cid: json.RawMessage(`[{"col1": "a", "col2": "b"}]`),
hostDistributedQueryPrefix + "invalidcid": json.RawMessage(`""`), // empty string is sometimes sent for no results
hostDistributedQueryPrefix + "9999": json.RawMessage(`""`),
},
Statuses: map[string]interface{}{
hostDistributedQueryPrefix + cid: 0,
hostDistributedQueryPrefix + "9999": "0",
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid: "some msg",
},
Stats: map[string]*fleet.Stats{
hostDistributedQueryPrefix + cid: stats,
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
wg.Wait()
require.Len(t, liveQueryResp.Results, 1)
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
require.Len(t, liveQueryResp.Results[0].Results[0].Rows, 1)
assert.Equal(t, "a", liveQueryResp.Results[0].Results[0].Rows[0]["col1"])
assert.Equal(t, "b", liveQueryResp.Results[0].Results[0].Rows[0]["col2"])
// Allow time for aggregated stats to update
time.Sleep(500 * time.Millisecond)
aggStats, err := mysql.GetAggregatedStats(context.Background(), s.ds, fleet.AggregatedStatsTypeScheduledQuery, q1.ID)
if savedQuery && hasStats {
require.NoError(t, err)
assert.Equal(t, 1, int(*aggStats.TotalExecutions))
assert.Equal(t, float64(2), *aggStats.SystemTimeP50)
assert.Equal(t, float64(2), *aggStats.SystemTimeP95)
assert.Equal(t, float64(1), *aggStats.UserTimeP50)
assert.Equal(t, float64(1), *aggStats.UserTimeP95)
} else {
require.ErrorIs(t, err, sql.ErrNoRows)
}
liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: []uint{host.ID},
}
liveQueryResp := runLiveQueryResponse{}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()
// Give the above call a couple of seconds to create the campaign
time.Sleep(2 * time.Second)
cid := getCIDForQ(s, q1)
distributedReq := submitDistributedQueryResultsRequestShim{
NodeKey: *host.NodeKey,
Results: map[string]json.RawMessage{
hostDistributedQueryPrefix + cid: json.RawMessage(`[{"col1": "a", "col2": "b"}]`),
hostDistributedQueryPrefix + "invalidcid": json.RawMessage(`""`), // empty string is sometimes sent for no results
hostDistributedQueryPrefix + "9999": json.RawMessage(`""`),
},
Statuses: map[string]interface{}{
hostDistributedQueryPrefix + cid: 0,
hostDistributedQueryPrefix + "9999": "0",
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid: "some msg",
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
wg.Wait()
require.Len(t, liveQueryResp.Results, 1)
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
require.Len(t, liveQueryResp.Results[0].Results[0].Rows, 1)
assert.Equal(t, "a", liveQueryResp.Results[0].Results[0].Rows[0]["col1"])
assert.Equal(t, "b", liveQueryResp.Results[0].Results[0].Rows[0]["col2"])
s.Run("not saved query", func() { test(false, true) })
s.Run("saved query without stats", func() { test(true, false) })
s.Run("saved query with stats", func() { test(true, true) })
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
@ -153,6 +189,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
Saved: rand.Intn(2) == 1, //nolint:gosec
})
require.NoError(t, err)
@ -161,6 +198,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
Description: "desc2",
Name: t.Name() + "query2",
Logging: fleet.LoggingSnapshot,
Saved: rand.Intn(2) == 1, //nolint:gosec
})
require.NoError(t, err)
@ -206,6 +244,12 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
hostDistributedQueryPrefix + cid1: "some msg",
hostDistributedQueryPrefix + cid2: "some other msg",
},
Stats: map[string]*fleet.Stats{
hostDistributedQueryPrefix + cid1: &fleet.Stats{
UserTime: uint64(1),
SystemTime: uint64(2),
},
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)
@ -258,6 +302,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
Description: "desc1",
Name: t.Name() + "query1",
Logging: fleet.LoggingSnapshot,
Saved: rand.Intn(2) == 1, //nolint:gosec
})
require.NoError(t, err)
@ -266,6 +311,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
Description: "desc2",
Name: t.Name() + "query2",
Logging: fleet.LoggingSnapshot,
Saved: rand.Intn(2) == 1, //nolint:gosec
})
require.NoError(t, err)
@ -314,6 +360,12 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
hostDistributedQueryPrefix + cid1: "some msg",
hostDistributedQueryPrefix + cid2: "some other msg",
},
Stats: map[string]*fleet.Stats{
hostDistributedQueryPrefix + cid1: &fleet.Stats{
UserTime: uint64(1),
SystemTime: uint64(2),
},
},
}
distributedResp := submitDistributedQueryResultsResponse{}
s.DoJSON("POST", "/api/osquery/distributed/write", distributedReq, http.StatusOK, &distributedResp)

---

@ -14,6 +14,7 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/go-kit/log/level"
)
type runLiveQueryRequest struct {
@ -50,6 +51,8 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
logging.WithExtras(ctx, "live_query_rest_period_err", err)
}
// Only allow a query to be specified once
req.QueryIDs = server.RemoveDuplicatesFromSlice(req.QueryIDs)
// Only allow a host to be specified once in HostIDs
req.HostIDs = server.RemoveDuplicatesFromSlice(req.HostIDs)
res := runLiveQueryResponse{
@ -123,6 +126,17 @@ func (svc *Service) RunLiveQueryDeadline(
var results []fleet.QueryResult
timeout := time.After(deadline)
// We process stats along with results as they are sent back to the user.
// We do a batch update of the stats.
// We update aggregated stats once online hosts have reported.
const statsBatchSize = 1000
perfStatsTracker := statsTracker{}
perfStatsTracker.saveStats, err = svc.ds.IsSavedQuery(ctx, campaign.QueryID)
if err != nil {
level.Error(svc.logger).Log("msg", "error checking saved query", "query.id", campaign.QueryID, "err", err)
perfStatsTracker.saveStats = false
}
loop:
for {
select {
@ -133,11 +147,29 @@ func (svc *Service) RunLiveQueryDeadline(
counterMutex.Lock()
respondedHostIDs[res.Host.ID] = struct{}{}
counterMutex.Unlock()
if perfStatsTracker.saveStats && res.Stats != nil {
perfStatsTracker.stats = append(
perfStatsTracker.stats,
statsToSave{
hostID: res.Host.ID, Stats: res.Stats, outputSize: calculateOutputSize(&perfStatsTracker, &res),
},
)
if len(perfStatsTracker.stats) >= statsBatchSize {
svc.updateStats(ctx, campaign.QueryID, svc.logger, &perfStatsTracker, false)
}
}
case error:
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(res.Error())}
return
}
case <-timeout:
// This is the normal path for returning results. We only update aggregated stats here, and without blocking.
if perfStatsTracker.saveStats {
ctxWithoutCancel := context.WithoutCancel(ctx) // to make sure stats DB operations don't get killed after we return results.
go func() {
svc.updateStats(ctxWithoutCancel, campaign.QueryID, svc.logger, &perfStatsTracker, true)
}()
}
break loop
case <-ctx.Done():
break loop

---

@ -20,7 +20,7 @@ type GetClientConfigFunc func(ctx context.Context) (config map[string]interface{
type GetDistributedQueriesFunc func(ctx context.Context) (queries map[string]string, discovery map[string]string, accelerate uint, err error)
type SubmitDistributedQueryResultsFunc func(ctx context.Context, results fleet.OsqueryDistributedQueryResults, statuses map[string]fleet.OsqueryStatus, messages map[string]string) (err error)
type SubmitDistributedQueryResultsFunc func(ctx context.Context, results fleet.OsqueryDistributedQueryResults, statuses map[string]fleet.OsqueryStatus, messages map[string]string, stats map[string]*fleet.Stats) (err error)
type SubmitStatusLogsFunc func(ctx context.Context, logs []json.RawMessage) (err error)
@ -79,11 +79,11 @@ func (s *TLSService) GetDistributedQueries(ctx context.Context) (queries map[str
return s.GetDistributedQueriesFunc(ctx)
}
func (s *TLSService) SubmitDistributedQueryResults(ctx context.Context, results fleet.OsqueryDistributedQueryResults, statuses map[string]fleet.OsqueryStatus, messages map[string]string) (err error) {
func (s *TLSService) SubmitDistributedQueryResults(ctx context.Context, results fleet.OsqueryDistributedQueryResults, statuses map[string]fleet.OsqueryStatus, messages map[string]string, stats map[string]*fleet.Stats) (err error) {
s.mu.Lock()
s.SubmitDistributedQueryResultsFuncInvoked = true
s.mu.Unlock()
return s.SubmitDistributedQueryResultsFunc(ctx, results, statuses, messages)
return s.SubmitDistributedQueryResultsFunc(ctx, results, statuses, messages, stats)
}
func (s *TLSService) SubmitStatusLogs(ctx context.Context, logs []json.RawMessage) (err error) {

---

@ -804,6 +804,7 @@ type submitDistributedQueryResultsRequestShim struct {
Results map[string]json.RawMessage `json:"queries"`
Statuses map[string]interface{} `json:"statuses"`
Messages map[string]string `json:"messages"`
Stats map[string]*fleet.Stats `json:"stats"`
}
func (shim *submitDistributedQueryResultsRequestShim) hostNodeKey() string {
@ -845,6 +846,7 @@ func (shim *submitDistributedQueryResultsRequestShim) toRequest(ctx context.Cont
Results: results,
Statuses: statuses,
Messages: shim.Messages,
Stats: shim.Stats,
}, nil
}
@ -853,6 +855,7 @@ type SubmitDistributedQueryResultsRequest struct {
Results fleet.OsqueryDistributedQueryResults `json:"queries"`
Statuses map[string]fleet.OsqueryStatus `json:"statuses"`
Messages map[string]string `json:"messages"`
Stats map[string]*fleet.Stats `json:"stats"`
}
type submitDistributedQueryResultsResponse struct {
@ -868,7 +871,7 @@ func submitDistributedQueryResultsEndpoint(ctx context.Context, request interfac
return submitDistributedQueryResultsResponse{Err: err}, nil
}
err = svc.SubmitDistributedQueryResults(ctx, req.Results, req.Statuses, req.Messages)
err = svc.SubmitDistributedQueryResults(ctx, req.Results, req.Statuses, req.Messages, req.Stats)
if err != nil {
return submitDistributedQueryResultsResponse{Err: err}, nil
}
@ -913,6 +916,7 @@ func (svc *Service) SubmitDistributedQueryResults(
results fleet.OsqueryDistributedQueryResults,
statuses map[string]fleet.OsqueryStatus,
messages map[string]string,
stats map[string]*fleet.Stats,
) error {
// skipauth: Authorization is currently for user endpoints only.
svc.authz.SkipAuthorization(ctx)
@ -929,7 +933,7 @@ func (svc *Service) SubmitDistributedQueryResults(
policyResults := map[uint]*bool{}
refetchCriticalSet := host.RefetchCriticalQueriesUntil != nil
svc.maybeDebugHost(ctx, host, results, statuses, messages)
svc.maybeDebugHost(ctx, host, results, statuses, messages, stats)
var hostWithoutPolicies bool
for query, rows := range results {
@ -951,9 +955,10 @@ func (svc *Service) SubmitDistributedQueryResults(
}
ll.Log("query", query, "message", messages[query], "hostID", host.ID)
}
queryStats := stats[query]
ingestedDetailUpdated, ingestedAdditionalUpdated, err := svc.ingestQueryResults(
ctx, query, host, rows, failed, messages, policyResults, labelResults, additionalResults,
ctx, query, host, rows, failed, messages, policyResults, labelResults, additionalResults, queryStats,
)
if err != nil {
logging.WithErr(ctx, ctxerr.New(ctx, "error in query ingestion"))
@ -1076,6 +1081,7 @@ func (svc *Service) ingestQueryResults(
policyResults map[uint]*bool,
labelResults map[uint]*bool,
additionalResults fleet.OsqueryDistributedQueryResults,
stats *fleet.Stats,
) (bool, bool, error) {
var detailUpdated, additionalUpdated bool
@ -1087,7 +1093,7 @@ func (svc *Service) ingestQueryResults(
var err error
switch {
case strings.HasPrefix(query, hostDistributedQueryPrefix):
err = svc.ingestDistributedQuery(ctx, *host, query, rows, messages[query])
err = svc.ingestDistributedQuery(ctx, *host, query, rows, messages[query], stats)
case strings.HasPrefix(query, hostPolicyQueryPrefix):
err = ingestMembershipQuery(hostPolicyQueryPrefix, query, rows, policyResults, failed)
case strings.HasPrefix(query, hostLabelQueryPrefix):
@ -1155,7 +1161,9 @@ func (svc *Service) directIngestDetailQuery(ctx context.Context, host *fleet.Hos
// ingestDistributedQuery takes the results of a distributed query and modifies the
// provided fleet.Host appropriately.
func (svc *Service) ingestDistributedQuery(ctx context.Context, host fleet.Host, name string, rows []map[string]string, errMsg string) error {
func (svc *Service) ingestDistributedQuery(
ctx context.Context, host fleet.Host, name string, rows []map[string]string, errMsg string, stats *fleet.Stats,
) error {
trimmedQuery := strings.TrimPrefix(name, hostDistributedQueryPrefix)
campaignID, err := strconv.Atoi(osquery_utils.EmptyToZero(trimmedQuery))
@ -1171,7 +1179,8 @@ func (svc *Service) ingestDistributedQuery(ctx context.Context, host fleet.Host,
Hostname: host.Hostname,
DisplayName: host.DisplayName(),
},
Rows: rows,
Rows: rows,
Stats: stats,
}
if errMsg != "" {
res.Error = &errMsg
@ -1328,6 +1337,7 @@ func (svc *Service) maybeDebugHost(
results fleet.OsqueryDistributedQueryResults,
statuses map[string]fleet.OsqueryStatus,
messages map[string]string,
stats map[string]*fleet.Stats,
) {
if svc.debugEnabledForHost(ctx, host.ID) {
hlogger := log.With(svc.logger, "host-id", host.ID)
@ -1336,6 +1346,7 @@ func (svc *Service) maybeDebugHost(
logJSON(hlogger, results, "results")
logJSON(hlogger, statuses, "statuses")
logJSON(hlogger, messages, "messages")
logJSON(hlogger, stats, "stats")
}
}


@ -1249,6 +1249,7 @@ func TestLabelQueries(t *testing.T) {
},
map[string]fleet.OsqueryStatus{},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
host.LabelUpdatedAt = mockClock.Now()
@ -1268,6 +1269,7 @@ func TestLabelQueries(t *testing.T) {
},
map[string]fleet.OsqueryStatus{},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
host.LabelUpdatedAt = mockClock.Now()
@ -1305,6 +1307,7 @@ func TestLabelQueries(t *testing.T) {
},
map[string]fleet.OsqueryStatus{},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
host.LabelUpdatedAt = mockClock.Now()
@ -1465,7 +1468,9 @@ func TestDetailQueriesWithEmptyStrings(t *testing.T) {
}
// Verify that results are ingested properly
require.NoError(t, svc.SubmitDistributedQueryResults(ctx, results, map[string]fleet.OsqueryStatus{}, map[string]string{}))
require.NoError(
t, svc.SubmitDistributedQueryResults(ctx, results, map[string]fleet.OsqueryStatus{}, map[string]string{}, map[string]*fleet.Stats{}),
)
// osquery_info
assert.Equal(t, "darwin", gotHost.Platform)
@ -1755,7 +1760,9 @@ func TestDetailQueries(t *testing.T) {
}
// Verify that results are ingested properly
require.NoError(t, svc.SubmitDistributedQueryResults(ctx, results, map[string]fleet.OsqueryStatus{}, map[string]string{}))
require.NoError(
t, svc.SubmitDistributedQueryResults(ctx, results, map[string]fleet.OsqueryStatus{}, map[string]string{}, map[string]*fleet.Stats{}),
)
require.NotNil(t, gotHost)
require.True(t, ds.SetOrUpdateMDMDataFuncInvoked)
@ -2058,6 +2065,12 @@ func TestDistributedQueryResults(t *testing.T) {
results := map[string][]map[string]string{
queryKey: expectedRows,
}
expectedStats := fleet.Stats{
UserTime: uint64(1),
}
stats := map[string]*fleet.Stats{
queryKey: &expectedStats,
}
// TODO use service method
readChan, err := rs.ReadChannel(context.Background(), *campaign)
@ -2078,6 +2091,7 @@ func TestDistributedQueryResults(t *testing.T) {
assert.Equal(t, host.ID, res.Host.ID)
assert.Equal(t, host.Hostname, res.Host.Hostname)
assert.Equal(t, host.DisplayName(), res.Host.DisplayName)
assert.Equal(t, &expectedStats, res.Stats)
} else {
t.Error("Wrong result type")
}
@ -2097,8 +2111,12 @@ func TestDistributedQueryResults(t *testing.T) {
// this test.
time.Sleep(10 * time.Millisecond)
err = svc.SubmitDistributedQueryResults(hostCtx, results, map[string]fleet.OsqueryStatus{}, map[string]string{})
err = svc.SubmitDistributedQueryResults(
hostCtx, results, map[string]fleet.OsqueryStatus{}, map[string]string{}, stats,
)
require.NoError(t, err)
// Sleep to ensure checks in the goroutine are actually done.
time.Sleep(10 * time.Millisecond)
}
func TestIngestDistributedQueryParseIdError(t *testing.T) {
@ -2115,7 +2133,7 @@ func TestIngestDistributedQueryParseIdError(t *testing.T) {
}
host := fleet.Host{ID: 1}
err := svc.ingestDistributedQuery(context.Background(), host, "bad_name", []map[string]string{}, "")
err := svc.ingestDistributedQuery(context.Background(), host, "bad_name", []map[string]string{}, "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to parse campaign")
}
@ -2141,7 +2159,7 @@ func TestIngestDistributedQueryOrphanedCampaignLoadError(t *testing.T) {
host := fleet.Host{ID: 1}
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "")
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "loading orphaned campaign")
}
@ -2174,7 +2192,7 @@ func TestIngestDistributedQueryOrphanedCampaignWaitListener(t *testing.T) {
host := fleet.Host{ID: 1}
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "")
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "campaignID=42 waiting for listener")
}
@ -2210,7 +2228,7 @@ func TestIngestDistributedQueryOrphanedCloseError(t *testing.T) {
host := fleet.Host{ID: 1}
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "")
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "closing orphaned campaign")
}
@ -2247,7 +2265,7 @@ func TestIngestDistributedQueryOrphanedStopError(t *testing.T) {
host := fleet.Host{ID: 1}
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "")
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "stopping orphaned campaign")
}
@ -2284,7 +2302,7 @@ func TestIngestDistributedQueryOrphanedStop(t *testing.T) {
host := fleet.Host{ID: 1}
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "")
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "campaignID=42 stopped")
lq.AssertExpectations(t)
@ -2315,7 +2333,7 @@ func TestIngestDistributedQueryRecordCompletionError(t *testing.T) {
}()
time.Sleep(10 * time.Millisecond)
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "")
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "record query completion")
lq.AssertExpectations(t)
@ -2346,7 +2364,7 @@ func TestIngestDistributedQuery(t *testing.T) {
}()
time.Sleep(10 * time.Millisecond)
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "")
err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, "", nil)
require.NoError(t, err)
lq.AssertExpectations(t)
}
@ -2662,6 +2680,7 @@ func TestDistributedQueriesLogsManyErrors(t *testing.T) {
},
map[string]fleet.OsqueryStatus{},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
@ -2703,6 +2722,7 @@ func TestDistributedQueriesReloadsHostIfDetailsAreIn(t *testing.T) {
},
map[string]fleet.OsqueryStatus{},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
assert.True(t, ds.UpdateHostFuncInvoked)
@ -2915,6 +2935,7 @@ func TestPolicyQueries(t *testing.T) {
hostPolicyQueryPrefix + "2": 1,
},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
require.Len(t, recordedResults, 2)
@ -2964,6 +2985,7 @@ func TestPolicyQueries(t *testing.T) {
hostPolicyQueryPrefix + "2": 1,
},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
require.Len(t, recordedResults, 2)
@ -3002,6 +3024,7 @@ func TestPolicyQueries(t *testing.T) {
hostPolicyQueryPrefix + "2": 1,
},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
require.NotNil(t, recordedResults[1])
@ -3123,6 +3146,7 @@ func TestPolicyWebhooks(t *testing.T) {
hostPolicyQueryPrefix + "2": 1, // didn't execute
},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
require.Len(t, recordedResults, 3)
@ -3226,6 +3250,7 @@ func TestPolicyWebhooks(t *testing.T) {
hostPolicyQueryPrefix + "2": 1, // didn't execute
},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
require.Len(t, recordedResults, 3)
@ -3270,6 +3295,7 @@ func TestPolicyWebhooks(t *testing.T) {
},
map[string]fleet.OsqueryStatus{},
map[string]string{},
map[string]*fleet.Stats{},
)
require.NoError(t, err)
require.Len(t, recordedResults, 3)


@ -1158,6 +1158,17 @@ func directIngestScheduledQueryStats(ctx context.Context, logger log.Logger, hos
continue
}
// Do not save stats without executions so that we do not overwrite existing stats.
// It is normal for a host to report no executions when the query has only just been scheduled.
executions := cast.ToUint64(row["executions"])
if executions == 0 {
level.Debug(logger).Log(
"msg", "host reported scheduled query with no executions",
"host", host.Hostname,
)
continue
}
// Split with a limit of 2 in case the query name includes the
// delimiter. Not much we can do if the pack name includes the
// delimiter.
@ -1179,7 +1190,7 @@ func directIngestScheduledQueryStats(ctx context.Context, logger log.Logger, hos
PackName: packName,
AverageMemory: cast.ToUint64(row["average_memory"]),
Denylisted: cast.ToBool(row["denylisted"]),
Executions: cast.ToUint64(row["executions"]),
Executions: executions,
Interval: cast.ToInt(row["interval"]),
// Cast to int first to allow cast.ToTime to interpret the unix timestamp.
LastExecuted: time.Unix(cast.ToInt64(row["last_executed"]), 0).UTC(),
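
One side effect of the guard above: `cast.ToUint64` returns the zero value for missing, empty, or unparseable input, so the `executions == 0` check also filters malformed rows, not only genuine zero-execution rows. A quick illustration (assuming the `github.com/spf13/cast` package already used in this file):

```go
package main

import (
	"fmt"

	"github.com/spf13/cast"
)

func main() {
	row := map[string]string{"average_memory": "2048"} // no "executions" key
	fmt.Println(cast.ToUint64(row["executions"])) // 0: missing key yields ""
	fmt.Println(cast.ToUint64("not-a-number"))    // 0: parse failure is swallowed
	fmt.Println(cast.ToUint64("3"))               // 3
}
```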


@ -3,7 +3,7 @@ package service
import (
"context"
"crypto/tls"
"github.com/fleetdm/fleet/v4/server/config"
"math/rand"
"net/http"
"net/http/httptest"
"strconv"
@ -12,7 +12,9 @@ import (
"time"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/live_query/live_query_mock"
"github.com/fleetdm/fleet/v4/server/mock"
@ -21,6 +23,7 @@ import (
ws "github.com/fleetdm/fleet/v4/server/websocket"
kitlog "github.com/go-kit/kit/log"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -161,3 +164,188 @@ func TestStreamCampaignResultsClosesReditOnWSClose(t *testing.T) {
}
require.Equal(t, prevActiveConn-1, newActiveConn)
}
func TestUpdateStats(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer mysql.TruncateTables(t, ds)
s, ctx := newTestService(t, ds, nil, nil)
svc := s.(validationMiddleware).Service.(*Service)
tracker := statsTracker{}
// NOOP cases
svc.updateStats(ctx, 0, svc.logger, nil, false)
svc.updateStats(ctx, 0, svc.logger, &tracker, false)
// More NOOP cases
tracker.saveStats = true
svc.updateStats(ctx, 0, svc.logger, nil, false)
assert.True(t, tracker.saveStats)
svc.updateStats(ctx, 0, svc.logger, nil, true)
assert.True(t, tracker.saveStats)
// Populate a batch of data
hostIDs := []uint{}
queryID := uint(1)
myHostID := uint(10000)
myWallTime := uint64(5)
myUserTime := uint64(6)
mySystemTime := uint64(7)
myMemory := uint64(8)
myOutputSize := uint64(9)
tracker.stats = append(
tracker.stats, statsToSave{
hostID: myHostID,
Stats: &fleet.Stats{
WallTimeMs: myWallTime * 1000,
UserTime: myUserTime,
SystemTime: mySystemTime,
Memory: myMemory,
},
outputSize: myOutputSize,
},
)
hostIDs = append(hostIDs, myHostID)
for i := uint(1); i < statsBatchSize; i++ {
tracker.stats = append(
tracker.stats, statsToSave{
hostID: i,
Stats: &fleet.Stats{
WallTimeMs: rand.Uint64(),
UserTime: rand.Uint64(),
SystemTime: rand.Uint64(),
Memory: rand.Uint64(),
},
outputSize: rand.Uint64(),
},
)
hostIDs = append(hostIDs, i)
}
tracker.saveStats = true
svc.updateStats(ctx, queryID, svc.logger, &tracker, false)
assert.True(t, tracker.saveStats)
assert.Equal(t, 0, len(tracker.stats))
assert.True(t, tracker.aggregationNeeded)
// Get the stats from DB and make sure they match
currentStats, err := svc.ds.GetLiveQueryStats(ctx, queryID, hostIDs)
assert.NoError(t, err)
assert.Equal(t, statsBatchSize, len(currentStats))
currentStats, err = svc.ds.GetLiveQueryStats(ctx, queryID, []uint{myHostID})
assert.NoError(t, err)
require.Equal(t, 1, len(currentStats))
myStat := currentStats[0]
assert.Equal(t, myHostID, myStat.HostID)
assert.Equal(t, uint64(1), myStat.Executions)
assert.Equal(t, myWallTime, myStat.WallTime)
assert.Equal(t, myUserTime, myStat.UserTime)
assert.Equal(t, mySystemTime, myStat.SystemTime)
assert.Equal(t, myMemory, myStat.AverageMemory)
assert.Equal(t, myOutputSize, myStat.OutputSize)
// Aggregate stats
svc.updateStats(ctx, queryID, svc.logger, &tracker, true)
aggStats, err := mysql.GetAggregatedStats(ctx, svc.ds.(*mysql.Datastore), fleet.AggregatedStatsTypeScheduledQuery, queryID)
require.NoError(t, err)
assert.Equal(t, statsBatchSize, int(*aggStats.TotalExecutions))
// Sanity checks. Complete testing done in aggregated_stats_test.go
assert.True(t, *aggStats.SystemTimeP50 > 0)
assert.True(t, *aggStats.SystemTimeP95 > 0)
assert.True(t, *aggStats.UserTimeP50 > 0)
assert.True(t, *aggStats.UserTimeP95 > 0)
// Write new stats (update) for the same query/hosts
myNewWallTime := uint64(15)
myNewUserTime := uint64(16)
myNewSystemTime := uint64(17)
myNewMemory := uint64(18)
myNewOutputSize := uint64(19)
tracker.stats = append(
tracker.stats, statsToSave{
hostID: myHostID,
Stats: &fleet.Stats{
WallTimeMs: myNewWallTime * 1000,
UserTime: myNewUserTime,
SystemTime: myNewSystemTime,
Memory: myNewMemory,
},
outputSize: myNewOutputSize,
},
)
for i := uint(1); i < statsBatchSize; i++ {
tracker.stats = append(
tracker.stats, statsToSave{
hostID: i,
Stats: &fleet.Stats{
WallTimeMs: rand.Uint64(),
UserTime: rand.Uint64(),
SystemTime: rand.Uint64(),
Memory: rand.Uint64(),
},
outputSize: rand.Uint64(),
},
)
}
tracker.saveStats = true
svc.updateStats(ctx, queryID, svc.logger, &tracker, true)
assert.True(t, tracker.saveStats)
assert.Equal(t, 0, len(tracker.stats))
assert.False(t, tracker.aggregationNeeded)
// Check that stats were updated
currentStats, err = svc.ds.GetLiveQueryStats(ctx, queryID, []uint{myHostID})
assert.NoError(t, err)
require.Equal(t, 1, len(currentStats))
myStat = currentStats[0]
assert.Equal(t, myHostID, myStat.HostID)
assert.Equal(t, uint64(2), myStat.Executions)
assert.Equal(t, myWallTime+myNewWallTime, myStat.WallTime)
assert.Equal(t, myUserTime+myNewUserTime, myStat.UserTime)
assert.Equal(t, mySystemTime+myNewSystemTime, myStat.SystemTime)
assert.Equal(t, (myMemory+myNewMemory)/2, myStat.AverageMemory)
assert.Equal(t, myOutputSize+myNewOutputSize, myStat.OutputSize)
// Check that aggregated stats were updated
aggStats, err = mysql.GetAggregatedStats(ctx, svc.ds.(*mysql.Datastore), fleet.AggregatedStatsTypeScheduledQuery, queryID)
require.NoError(t, err)
assert.Equal(t, statsBatchSize*2, int(*aggStats.TotalExecutions))
// Sanity checks. Complete testing done in aggregated_stats_test.go
assert.True(t, *aggStats.SystemTimeP50 > 0)
assert.True(t, *aggStats.SystemTimeP95 > 0)
assert.True(t, *aggStats.UserTimeP50 > 0)
assert.True(t, *aggStats.UserTimeP95 > 0)
}
func TestCalculateOutputSize(t *testing.T) {
createResult := func() *fleet.DistributedQueryResult {
result := fleet.DistributedQueryResult{}
result.Rows = append(result.Rows, nil)
result.Rows = append(result.Rows, map[string]string{})
result.Rows = append(result.Rows, map[string]string{"a": "b", "a1": "b1"})
result.Rows = append(result.Rows, map[string]string{"c": "d"})
return &result
}
t.Run(
"output size save disabled", func(t *testing.T) {
tracker := statsTracker{saveStats: false}
size := calculateOutputSize(&tracker, createResult())
require.Equal(t, uint64(0), size)
},
)
t.Run(
"output size empty", func(t *testing.T) {
tracker := statsTracker{saveStats: true}
size := calculateOutputSize(&tracker, &fleet.DistributedQueryResult{})
require.Equal(t, uint64(0), size)
},
)
t.Run(
"output size calculate", func(t *testing.T) {
tracker := statsTracker{saveStats: true}
size := calculateOutputSize(&tracker, createResult())
expected := uint64(8) // manually calculated: ("a"+"b")=2 + ("a1"+"b1")=4 + ("c"+"d")=2 bytes of keys and values
require.Equal(t, expected, size)
},
)
}


@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"math"
"time"
"github.com/fleetdm/fleet/v4/server/authz"
@ -26,6 +27,7 @@ type targetTotals struct {
const (
campaignStatusPending = "pending"
campaignStatusFinished = "finished"
statsBatchSize = 1000 // flush per-host live query stats to the DB in batches of this size
)
type campaignStatus struct {
@ -34,6 +36,18 @@ type campaignStatus struct {
Status string `json:"status"`
}
type statsToSave struct {
	hostID uint // host that reported the result
	*fleet.Stats
	outputSize uint64 // total bytes of result-row keys and values
}
type statsTracker struct {
	saveStats         bool // stats are persisted only for saved queries
	aggregationNeeded bool // set when new stats are written; cleared once percentiles are recalculated
	stats             []statsToSave
}
func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Conn, campaignID uint) {
logging.WithExtras(ctx, "campaign_id", campaignID)
logger := log.With(svc.logger, "campaignID", campaignID)
@ -159,6 +173,17 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co
// Push status updates every 5 seconds at most
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
// We process stats along with results as they are sent back to the user.
// We do a batch update of the stats.
// We update aggregated stats once online hosts have reported, and again (if needed) on client disconnect.
perfStatsTracker := statsTracker{}
perfStatsTracker.saveStats, err = svc.ds.IsSavedQuery(ctx, campaign.QueryID)
if err != nil {
level.Error(logger).Log("msg", "error checking saved query", "query.id", campaign.QueryID, "err", err)
perfStatsTracker.saveStats = false
}
// Loop, pushing updates to results and expected totals
for {
// Update the expected hosts total (Should happen before
@ -169,9 +194,20 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co
// Receive a result and push it over the websocket
switch res := res.(type) {
case fleet.DistributedQueryResult:
// Calculate result size for performance stats
outputSize := calculateOutputSize(&perfStatsTracker, &res)
mapHostnameRows(&res)
err = conn.WriteJSONMessage("result", res)
if perfStatsTracker.saveStats && res.Stats != nil {
perfStatsTracker.stats = append(
perfStatsTracker.stats, statsToSave{hostID: res.Host.ID, Stats: res.Stats, outputSize: outputSize},
)
if len(perfStatsTracker.stats) >= statsBatchSize {
svc.updateStats(ctx, campaign.QueryID, logger, &perfStatsTracker, false)
}
}
if ctxerr.Cause(err) == sockjs.ErrSessionNotOpen {
svc.updateStats(ctx, campaign.QueryID, logger, &perfStatsTracker, true)
// return and stop sending the query if the session was closed
// by the client
return
@ -189,6 +225,7 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co
case <-ticker.C:
if conn.GetSessionState() == sockjs.SessionClosed {
svc.updateStats(ctx, campaign.QueryID, logger, &perfStatsTracker, true)
// return and stop sending the query if the session was closed
// by the client
return
@ -196,8 +233,107 @@ func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Co
// Update status
if err := updateStatus(); err != nil {
level.Error(logger).Log("msg", "error updating status", "err", err)
svc.updateStats(ctx, campaign.QueryID, logger, &perfStatsTracker, true)
return
}
if status.ActualResults == status.ExpectedResults {
svc.updateStats(ctx, campaign.QueryID, logger, &perfStatsTracker, true)
}
}
}
}
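
The streaming loop above appends one `statsToSave` per result, flushes to the DB whenever the buffer reaches `statsBatchSize`, and defers percentile aggregation to the disconnect/completion paths. A hypothetical, simplified sketch of that batching shape (no DB, no error handling), using Go generics:

```go
package main

import "fmt"

const batchSize = 1000 // mirrors statsBatchSize in the diff

// batcher collects items and flushes them in groups, like perfStatsTracker.
type batcher[T any] struct {
	pending []T
	flush   func([]T)
}

func (b *batcher[T]) add(item T) {
	b.pending = append(b.pending, item)
	if len(b.pending) >= batchSize {
		b.flushNow()
	}
}

func (b *batcher[T]) flushNow() {
	if len(b.pending) == 0 {
		return
	}
	b.flush(b.pending)
	b.pending = nil
}

func main() {
	writes := 0
	b := batcher[int]{flush: func(items []int) {
		writes++
		fmt.Printf("flush #%d: %d items\n", writes, len(items))
	}}
	for i := 0; i < 2500; i++ {
		b.add(i) // flushes at 1000 and 2000 items
	}
	b.flushNow() // final flush on completion/disconnect, like the aggregate pass
}
```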
func calculateOutputSize(perfStatsTracker *statsTracker, res *fleet.DistributedQueryResult) uint64 {
outputSize := uint64(0)
if perfStatsTracker.saveStats {
for _, row := range res.Rows {
if row == nil {
continue
}
for key, value := range row {
outputSize = outputSize + uint64(len(key)) + uint64(len(value))
}
}
}
return outputSize
}
func (svc Service) updateStats(
ctx context.Context, queryID uint, logger log.Logger, tracker *statsTracker, aggregateStats bool,
) {
// If we are not saving stats
if tracker == nil || !tracker.saveStats ||
// Or there are no stats to save, and we don't need to calculate aggregated stats
(len(tracker.stats) == 0 && (!aggregateStats || !tracker.aggregationNeeded)) {
return
}
if len(tracker.stats) > 0 {
// Get the existing stats from DB
hostIDs := []uint{}
for i := range tracker.stats {
hostIDs = append(hostIDs, tracker.stats[i].hostID)
}
currentStats, err := svc.ds.GetLiveQueryStats(ctx, queryID, hostIDs)
if err != nil {
level.Error(logger).Log("msg", "error getting current live query stats", "err", err)
tracker.saveStats = false
return
}
// Convert current Stats to a map
statsMap := make(map[uint]*fleet.LiveQueryStats)
for i := range currentStats {
statsMap[currentStats[i].HostID] = currentStats[i]
}
// Update stats
for _, gatheredStats := range tracker.stats {
stats, ok := statsMap[gatheredStats.hostID]
// Round milliseconds to the nearest second rather than truncating, for a more accurate cumulative wall time.
wallTime := uint64(math.Floor(float64(gatheredStats.WallTimeMs)/1000 + 0.5))
if !ok {
newStats := fleet.LiveQueryStats{
HostID: gatheredStats.hostID,
Executions: 1,
AverageMemory: gatheredStats.Memory,
SystemTime: gatheredStats.SystemTime,
UserTime: gatheredStats.UserTime,
WallTime: wallTime,
OutputSize: gatheredStats.outputSize,
}
currentStats = append(currentStats, &newStats)
} else {
// Combine old and new stats.
stats.AverageMemory = (stats.AverageMemory*stats.Executions + gatheredStats.Memory) / (stats.Executions + 1)
stats.Executions = stats.Executions + 1
stats.SystemTime = stats.SystemTime + gatheredStats.SystemTime
stats.UserTime = stats.UserTime + gatheredStats.UserTime
stats.WallTime = stats.WallTime + wallTime
stats.OutputSize = stats.OutputSize + gatheredStats.outputSize
}
}
// Insert/overwrite updated stats
err = svc.ds.UpdateLiveQueryStats(ctx, queryID, currentStats)
if err != nil {
level.Error(logger).Log("msg", "error updating live query stats", "err", err)
tracker.saveStats = false
return
}
tracker.aggregationNeeded = true
tracker.stats = nil
}
// Do aggregation
if aggregateStats && tracker.aggregationNeeded {
err := svc.ds.CalculateAggregatedPerfStatsPercentiles(ctx, fleet.AggregatedStatsTypeScheduledQuery, queryID)
if err != nil {
level.Error(logger).Log("msg", "error aggregating performance stats", "err", err)
tracker.saveStats = false
return
}
tracker.aggregationNeeded = false
}
}
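
Concretely, the merge above keeps cumulative sums for CPU time, wall time, and output size, but a running average for memory: avg' = (avg*n + new)/(n+1) in integer math. A small check of the numbers asserted in TestUpdateStats above:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// Running average in integer math, as in updateStats:
	// (8*1 + 18) / (1+1) = 13 — matches (myMemory+myNewMemory)/2 in the test.
	avg, n, next := uint64(8), uint64(1), uint64(18)
	fmt.Println((avg*n + next) / (n + 1)) // 13

	// Wall-time conversion rounds milliseconds to the nearest second:
	for _, ms := range []uint64{10, 5000, 5500} {
		fmt.Println(uint64(math.Floor(float64(ms)/1000 + 0.5))) // 0, 5, 6
	}
}
```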