Fix performance regressions in hosts APIs (#14036)

#13926 

I re-added the old queries, and we now use them when pagination is not used or when the page size is larger than 100 (the UI always uses `per_page=50`).
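
The selection logic boils down to the check below (a minimal sketch; `HostFailingPoliciesCountOptimPageSizeThreshold` and `useHostPaginationOptim` are the names used in the diff, everything else is illustrative):

```go
// Sketch only; not the full query-building code from the diff below.
package main

import "fmt"

// Page sizes at or under this threshold use the pagination-optimized counting path.
const HostFailingPoliciesCountOptimPageSizeThreshold = 100

// useHostPaginationOptim mirrors the check added in ListHosts: the optimized path
// (count failing policies only for the hosts on the returned page) is taken for
// small, explicit page sizes; per_page=0 or larger page sizes keep the old
// LEFT JOIN on policy_membership.
func useHostPaginationOptim(perPage uint) bool {
	return perPage != 0 && perPage <= HostFailingPoliciesCountOptimPageSizeThreshold
}

func main() {
	fmt.Println(useHostPaginationOptim(50))   // true: the UI's per_page=50
	fmt.Println(useHostPaginationOptim(0))    // false: per_page=0 (all hosts)
	fmt.Println(useHostPaginationOptim(5000)) // false: large page size
}
```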

TODO: We need to run actual load tests with a high number of hosts (50k hosts with 140 policies each).

- [X] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- [X] Documented any API changes (docs/Using-Fleet/REST-API.md or
docs/Contributing/API-for-contributors.md)
- ~[ ] Documented any permissions changes (docs/Using-Fleet/manage-access.md)~
- ~[ ] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)~
- ~[ ] Added support on fleet's osquery simulator `cmd/osquery-perf` for
new osquery data ingestion features.~
- [X] Added/updated tests
- [X] Manual QA for all new/changed functionality
  - ~For Orbit and Fleet Desktop changes:~
- ~[ ] Manual QA must be performed in the three main OSs, macOS, Windows
and Linux.~
- ~[ ] Auto-update manual QA, from released version of component to new
version (see [tools/tuf/test](../tools/tuf/test/README.md)).~
Commit eb8349567c (parent 2daebb41b1) by Lucas Manuel Rodriguez, 2023-09-22 05:19:02 -03:00, committed by GitHub.
5 changed files with 293 additions and 8 deletions

View File

@ -0,0 +1,4 @@
* Fix v4.37.0 performance regressions in the following API endpoints:
* `/api/v1/fleet/hosts/report`
* `/api/v1/fleet/hosts` when using `per_page=0` or a large number for `per_page` (in the thousands)
* Allow specifying `disable_failing_policies` on the `/api/v1/fleet/hosts/report` API endpoint for increased performance. This is useful when the user is not interested in counting failed policies (the `issues` column).
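
For illustration only (not part of the committed changes file), a client call exercising the new parameter might look like the sketch below; the base URL and API token are placeholders:

```go
// Illustrative only: fetch the hosts report as CSV without failing-policies counts.
// The base URL and API token are placeholders, not values from this PR.
package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	const baseURL = "https://fleet.example.com"
	const apiToken = "<api-token>"

	req, err := http.NewRequest(http.MethodGet,
		baseURL+"/api/v1/fleet/hosts/report?format=csv&disable_failing_policies=true", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer "+apiToken)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	csvReport, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(csvReport)) // all matching hosts, with the "issues" column set to 0
}
```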

View File

@ -776,6 +776,24 @@ func amountEnrolledHostsByOSDB(ctx context.Context, db sqlx.QueryerContext) (byO
return byOS, totalCount, nil
}
// HostFailingPoliciesCountOptimPageSizeThreshold is the page-size value that determines whether
// to run an optimized version of the hosts queries when pagination is used.
//
// If the page size is at or under this value, then the queries are optimized assuming a low number of hosts.
// If the page size is 0 or higher than this value, then the queries are optimized assuming a high number of hosts.
//
// IMPORTANT: The UI currently always uses PerPage=50 to list hosts. For better performance,
// HostFailingPoliciesCountOptimPageSizeThreshold should always be higher than what the UI uses.
//
// The optimization consists of calculating the failing policies count (which involves querying a large table, `policy_membership`)
// differently depending on the page size:
//   - When the page size is small (lower than or equal to this value), hosts are queried and filtered first, and
//     then the failing policies count is calculated for those hosts only (with an IN clause).
//   - When the page size is large (higher than this value) or ALL hosts are being retrieved, the hosts are
//     filtered and their failing policies counts are calculated in the same query (the IN clause performs worse
//     than a LEFT JOIN when the number of rows is high).
var HostFailingPoliciesCountOptimPageSizeThreshold = 100
func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
sql := `SELECT
h.id,
@ -833,6 +851,16 @@ func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt
`
}
// See definition of HostFailingPoliciesCountOptimPageSizeThreshold for more details.
useHostPaginationOptim := opt.PerPage != 0 && opt.PerPage <= uint(HostFailingPoliciesCountOptimPageSizeThreshold)
if !opt.DisableFailingPolicies && !useHostPaginationOptim {
sql += `,
COALESCE(failing_policies.count, 0) AS failing_policies_count,
COALESCE(failing_policies.count, 0) AS total_issues_count
`
}
var params []interface{}
// Only include "additional" if filter provided.
@ -855,14 +883,15 @@ func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt
`
}
sql, params = ds.applyHostFilters(opt, sql, filter, params)
leftJoinFailingPolicies := !useHostPaginationOptim
sql, params = ds.applyHostFilters(opt, sql, filter, params, leftJoinFailingPolicies)
hosts := []*fleet.Host{}
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &hosts, sql, params...); err != nil {
return nil, ctxerr.Wrap(ctx, err, "list hosts")
}
if !opt.DisableFailingPolicies {
if !opt.DisableFailingPolicies && useHostPaginationOptim {
var err error
hosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, hosts)
if err != nil {
@ -874,7 +903,7 @@ func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt
}
// TODO(Sarah): Do we need to reconcile mutually exclusive filters?
func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, filter fleet.TeamFilter, params []interface{}) (string, []interface{}) {
func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, filter fleet.TeamFilter, params []interface{}, leftJoinFailingPolicies bool) (string, []interface{}) {
opt.OrderKey = defaultHostColumnTableAlias(opt.OrderKey)
deviceMappingJoin := `LEFT JOIN (
@ -902,6 +931,14 @@ func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, fil
params = append(params, opt.SoftwareIDFilter)
}
failingPoliciesJoin := ""
if !opt.DisableFailingPolicies && leftJoinFailingPolicies {
failingPoliciesJoin = `LEFT JOIN (
SELECT host_id, count(*) as count FROM policy_membership WHERE passes = 0
GROUP BY host_id
) as failing_policies ON (h.id=failing_policies.host_id)`
}
operatingSystemJoin := ""
if opt.OSIDFilter != nil || (opt.OSNameFilter != nil && opt.OSVersionFilter != nil) {
operatingSystemJoin = `JOIN host_operating_system hos ON h.id = hos.host_id`
@ -936,6 +973,7 @@ func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, fil
%s
%s
%s
%s
%s
WHERE TRUE AND %s AND %s AND %s AND %s
`,
@ -944,6 +982,7 @@ func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, fil
hostMDMJoin,
deviceMappingJoin,
policyMembershipJoin,
failingPoliciesJoin,
operatingSystemJoin,
munkiJoin,
displayNameJoin,
@ -1161,12 +1200,14 @@ func filterHostsByMDMBootstrapPackageStatus(sql string, opt fleet.HostListOption
func (ds *Datastore) CountHosts(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) (int, error) {
sql := `SELECT count(*) `
// ignore pagination in count
// Ignore pagination in count.
opt.Page = 0
opt.PerPage = 0
// We don't need the failing policy counts of each host for counting hosts.
leftJoinFailingPolicies := false
var params []interface{}
sql, params = ds.applyHostFilters(opt, sql, filter, params)
sql, params = ds.applyHostFilters(opt, sql, filter, params, leftJoinFailingPolicies)
var count int
if err := sqlx.GetContext(ctx, ds.reader(ctx), &count, sql, params...); err != nil {

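To make the trade-off described in the `HostFailingPoliciesCountOptimPageSizeThreshold` comment concrete, the two counting strategies roughly correspond to the query shapes below. The LEFT JOIN fragment mirrors the one added in `applyHostFilters` above; the IN-based query is an assumption about what `UpdatePolicyFailureCountsForHosts` runs, simplified for illustration:

```go
// Illustrative only: simplified shapes of the two failing-policies counting strategies.
package main

// per_page=0 or page sizes above the threshold: the count is joined into the hosts
// query itself (this fragment mirrors the one added in applyHostFilters above).
const failingPoliciesLeftJoin = `
LEFT JOIN (
    SELECT host_id, count(*) AS count
    FROM policy_membership
    WHERE passes = 0
    GROUP BY host_id
) AS failing_policies ON (h.id = failing_policies.host_id)`

// Page sizes at or under the threshold: the page of hosts is fetched first, then
// counts are computed only for those hosts. This IN-based query is an assumption
// about what UpdatePolicyFailureCountsForHosts runs; the real query may differ.
const failingPoliciesForPageOfHosts = `
SELECT host_id, count(*) AS count
FROM policy_membership
WHERE passes = 0 AND host_id IN (?)
GROUP BY host_id`
```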
View File

@ -153,6 +153,7 @@ func TestHosts(t *testing.T) {
{"GetMatchingHostSerials", testGetMatchingHostSerials},
{"ListHostsLiteByIDs", testHostsListHostsLiteByIDs},
{"HostScriptResult", testHostScriptResult},
{"ListHostsWithPagination", testListHostsWithPagination},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@ -7431,3 +7432,76 @@ func testHostScriptResult(t *testing.T, ds *Datastore) {
require.NoError(t, err)
require.Equal(t, expectedOutput, script.Output)
}
func testListHostsWithPagination(t *testing.T, ds *Datastore) {
ctx := context.Background()
newHostFunc := func(name string) *fleet.Host {
host, err := ds.NewHost(ctx, &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: ptr.String(name),
UUID: name,
Hostname: "foo.local." + name,
})
require.NoError(t, err)
require.NotNil(t, host)
return host
}
filter := fleet.TeamFilter{User: test.UserAdmin}
hostCount := int(float64(HostFailingPoliciesCountOptimPageSizeThreshold) * 1.5)
hosts := make([]*fleet.Host, 0, hostCount)
for i := 0; i < hostCount; i++ {
hosts = append(hosts, newHostFunc(fmt.Sprintf("h%d", i)))
}
// List all hosts with PerPage=0 which should not use the failing policies optimization.
perPage0 := 0
hosts0, err := ds.ListHosts(ctx, filter, fleet.HostListOptions{
ListOptions: fleet.ListOptions{
PerPage: uint(perPage0),
},
})
require.NoError(t, err)
require.Len(t, hosts0, hostCount)
for i, host := range hosts0 {
require.Equal(t, host.ID, hosts[i].ID)
}
// List hosts with a per-page value equal to the failing policies optimization threshold
// (thus using the optimization).
perPage1 := HostFailingPoliciesCountOptimPageSizeThreshold
hosts1, err := ds.ListHosts(ctx, filter, fleet.HostListOptions{
ListOptions: fleet.ListOptions{
PerPage: uint(perPage1),
},
})
require.NoError(t, err)
require.Len(t, hosts1, perPage1)
for i, host := range hosts1 {
require.Equal(t, host.ID, hosts[i].ID)
}
// List hosts with a per-page value higher than the failing policies optimization threshold
// (thus not using the optimization).
perPage2 := int(float64(HostFailingPoliciesCountOptimPageSizeThreshold) * 1.2)
hosts2, err := ds.ListHosts(ctx, filter, fleet.HostListOptions{
ListOptions: fleet.ListOptions{
PerPage: uint(perPage2),
},
})
require.NoError(t, err)
require.Len(t, hosts2, perPage2)
for i, host := range hosts2 {
require.Equal(t, host.ID, hosts[i].ID)
}
// CountHosts doesn't compute the failing policies count or apply pagination.
count, err := ds.CountHosts(ctx, filter, fleet.HostListOptions{})
require.NoError(t, err)
require.Equal(t, hostCount, count)
}

View File

@ -1410,7 +1410,6 @@ func hostsReportEndpoint(ctx context.Context, request interface{}, svc fleet.Ser
return hostsReportResponse{Err: err}, nil
}
req.Opts.DisableFailingPolicies = false
req.Opts.AdditionalFilters = nil
req.Opts.Page = 0
req.Opts.PerPage = 0 // explicitly disable any limit, we want all matching hosts

View File

@ -6103,7 +6103,6 @@ func (s *integrationTestSuite) TestHostsReportDownload() {
require.NoError(t, err)
require.Len(t, rows, len(hosts)+1) // all hosts + header row
require.Len(t, rows[0], 48) // total number of cols
t.Log(rows[0])
const (
idCol = 3
@ -6198,7 +6197,6 @@ func (s *integrationTestSuite) TestHostsReportDownload() {
require.Equal(t, []string{"0", "TestIntegrations/TestHostsReportDownloadfoo.local1"}, rows[2][:2])
require.Len(t, rows[3], 3)
require.Equal(t, []string{"0", "TestIntegrations/TestHostsReportDownloadfoo.local0"}, rows[3][:2])
t.Log(rows)
}
func (s *integrationTestSuite) TestSSODisabled() {
@ -7611,3 +7609,172 @@ func (s *integrationTestSuite) TestDirectIngestSoftwareWithInvalidFields() {
})
require.NotZero(t, wiresharkSoftware.ID)
}
func (s *integrationTestSuite) TestHostsReportWithPolicyResults() {
t := s.T()
ctx := context.Background()
newHostFunc := func(name string) *fleet.Host {
host, err := s.ds.NewHost(ctx, &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: ptr.String(name),
UUID: name,
Hostname: "foo.local." + name,
})
require.NoError(t, err)
require.NotNil(t, host)
return host
}
hostCount := 10
hosts := make([]*fleet.Host, 0, hostCount)
for i := 0; i < hostCount; i++ {
hosts = append(hosts, newHostFunc(fmt.Sprintf("h%d", i)))
}
globalPolicy0, err := s.ds.NewGlobalPolicy(ctx, &test.UserAdmin.ID, fleet.PolicyPayload{
Name: "foobar0",
Query: "SELECT 0;",
})
require.NoError(t, err)
globalPolicy1, err := s.ds.NewGlobalPolicy(ctx, &test.UserAdmin.ID, fleet.PolicyPayload{
Name: "foobar1",
Query: "SELECT 1;",
})
require.NoError(t, err)
globalPolicy2, err := s.ds.NewGlobalPolicy(ctx, &test.UserAdmin.ID, fleet.PolicyPayload{
Name: "foobar2",
Query: "SELECT 2;",
})
require.NoError(t, err)
for i, host := range hosts {
// All hosts pass the globalPolicy0
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy0.ID: ptr.Bool(true)}, time.Now(), false)
require.NoError(t, err)
if i%2 == 0 {
// Half of the hosts pass the globalPolicy1 and fail the globalPolicy2
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(true)}, time.Now(), false)
require.NoError(t, err)
err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: ptr.Bool(false)}, time.Now(), false)
require.NoError(t, err)
} else {
// Half of the hosts pass the globalPolicy2 and fail the globalPolicy1
err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(false)}, time.Now(), false)
require.NoError(t, err)
err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: ptr.Bool(true)}, time.Now(), false)
require.NoError(t, err)
}
}
// The hosts/report endpoint uses svc.ds.ListHosts with page=0, per_page=0, thus we are
// testing the failing policies calculation with the queries that are not optimized for pagination.
res := s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, "format", "csv")
rows1, err := csv.NewReader(res.Body).ReadAll()
res.Body.Close()
require.NoError(t, err)
require.Len(t, rows1, len(hosts)+1) // all hosts + header row
require.Len(t, rows1[0], 48) // total number of cols
var (
idIdx int
issuesIdx int
)
for colIdx, column := range rows1[0] {
switch column {
case "issues":
issuesIdx = colIdx
case "id":
idIdx = colIdx
}
}
for i := 1; i < len(hosts)+1; i++ {
row := rows1[i]
require.Equal(t, row[issuesIdx], "1")
}
// Running with disable_failing_policies=true disables the counting of failed policies for a host.
// Thus, all "issues" values should be 0.
res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, "format", "csv", "disable_failing_policies", "true")
rows2, err := csv.NewReader(res.Body).ReadAll()
res.Body.Close()
require.NoError(t, err)
require.Len(t, rows2, len(hosts)+1) // all hosts + header row
require.Len(t, rows2[0], 48) // total number of cols
// Check that all hosts have 0 issues and that they match the previous call to `/hosts/report`.
for i := 1; i < len(hosts)+1; i++ {
row := rows2[i]
require.Equal(t, row[issuesIdx], "0")
row1 := rows1[i]
require.Equal(t, row[idIdx], row1[idIdx])
}
for _, tc := range []struct {
name string
args []string
checkRows func(t *testing.T, rows [][]string)
}{
{
name: "get hosts that fail globalPolicy0",
args: []string{"policy_id", fmt.Sprint(globalPolicy0.ID), "policy_response", "failure"},
checkRows: func(t *testing.T, rows [][]string) {
require.Len(t, rows, 1) // just header row, all hosts pass such policy.
},
},
{
name: "get hosts that pass globalPolicy0",
args: []string{"policy_id", fmt.Sprint(globalPolicy0.ID), "policy_response", "passing"},
checkRows: func(t *testing.T, rows [][]string) {
require.Len(t, rows, len(hosts)+1) // all hosts + header row, all hosts pass such policy.
},
},
{
name: "get hosts that fail globalPolicy1",
args: []string{"policy_id", fmt.Sprint(globalPolicy1.ID), "policy_response", "failing"},
checkRows: func(t *testing.T, rows [][]string) {
require.Len(t, rows, len(hosts)/2+1) // half of hosts + header row.
},
},
{
name: "get hosts that pass globalPolicy1",
args: []string{"policy_id", fmt.Sprint(globalPolicy1.ID), "policy_response", "passing"},
checkRows: func(t *testing.T, rows [][]string) {
require.Len(t, rows, len(hosts)/2+1) // half of hosts + header row.
},
},
{
name: "get hosts that fail globalPolicy2",
args: []string{"policy_id", fmt.Sprint(globalPolicy2.ID), "policy_response", "failing"},
checkRows: func(t *testing.T, rows [][]string) {
require.Len(t, rows, len(hosts)/2+1) // half of hosts + header row.
},
},
{
name: "get hosts that pass globalPolicy2",
args: []string{"policy_id", fmt.Sprint(globalPolicy2.ID), "policy_response", "passing"},
checkRows: func(t *testing.T, rows [][]string) {
require.Len(t, rows, len(hosts)/2+1) // half of hosts + header row.
},
},
} {
t.Run(tc.name, func(t *testing.T) {
res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, append(tc.args, "format", "csv")...)
rows, err := csv.NewReader(res.Body).ReadAll()
res.Body.Close()
require.NoError(t, err)
tc.checkRows(t, rows)
// Test the same with "disable_failing_policies=true" which should not change the result.
res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, append(tc.args, "format", "csv", "disable_failing_policies", "true")...)
rows, err = csv.NewReader(res.Body).ReadAll()
res.Body.Close()
require.NoError(t, err)
tc.checkRows(t, rows)
})
}
}