diff --git a/changes/13926-fix-performance-regression-in-hosts-report b/changes/13926-fix-performance-regression-in-hosts-report new file mode 100644 index 000000000..9b9e7eedd --- /dev/null +++ b/changes/13926-fix-performance-regression-in-hosts-report @@ -0,0 +1,4 @@ +* Fix v4.37.0 performance regressions in the following API endpoints: + * `/api/v1/fleet/hosts/report` + * `/api/v1/fleet/hosts` when using `per_page=0` or a large number for `per_page` (in the thousands) +* Allow specifying `disable_failing_policies` on the `/api/v1/fleet/hosts/report` API endpoint for increased performance. This is useful if the user is not interested on counting failed policies (`issues` column). diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 0533b5fca..0438dac71 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -776,6 +776,24 @@ func amountEnrolledHostsByOSDB(ctx context.Context, db sqlx.QueryerContext) (byO return byOS, totalCount, nil } +// HostFailingPoliciesCountOptimPageSizeThreshold is the value of the page size that determines whether +// to run an optimized version of the hosts queries when pagination is used. +// +// If the page size is under this value then the queries will be optimized assuming a low number of hosts. +// If the page size is 0 or higher than this value then the queries will be optimized assuming a high number of hosts. +// +// IMPORTANT: The UI currently always uses PerPage=50 to list hosts. For better performance, +// HostFailingPoliciesCountOptimPageSizeThreshold should always be higher than what the UI uses. +// +// The optimization consists on calculating the failing policy count (which involves querying a large table, `policy_membership`) +// differently depending on the page size: +// - When the page size is short (lower than or equal to this value) then hosts are queried and filtered first, and +// then the failure policy count is calculated on such hosts only (with an IN clause). +// - When the page size is large (higher than this value) or ALL hosts are being retrieved then the hosts are +// filtered and their failing policy count are calculated on the same query (the IN clause performs worse +// than a LEFT JOIN when the number of rows is high). +var HostFailingPoliciesCountOptimPageSizeThreshold = 100 + func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) { sql := `SELECT h.id, @@ -833,6 +851,16 @@ func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt ` } + // See definition of HostFailingPoliciesCountOptimPageSizeThreshold for more details. + useHostPaginationOptim := opt.PerPage != 0 && opt.PerPage <= uint(HostFailingPoliciesCountOptimPageSizeThreshold) + + if !opt.DisableFailingPolicies && !useHostPaginationOptim { + sql += `, + COALESCE(failing_policies.count, 0) AS failing_policies_count, + COALESCE(failing_policies.count, 0) AS total_issues_count + ` + } + var params []interface{} // Only include "additional" if filter provided. @@ -855,14 +883,15 @@ func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt ` } - sql, params = ds.applyHostFilters(opt, sql, filter, params) + leftJoinFailingPolicies := !useHostPaginationOptim + sql, params = ds.applyHostFilters(opt, sql, filter, params, leftJoinFailingPolicies) hosts := []*fleet.Host{} if err := sqlx.SelectContext(ctx, ds.reader(ctx), &hosts, sql, params...); err != nil { return nil, ctxerr.Wrap(ctx, err, "list hosts") } - if !opt.DisableFailingPolicies { + if !opt.DisableFailingPolicies && useHostPaginationOptim { var err error hosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, hosts) if err != nil { @@ -874,7 +903,7 @@ func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt } // TODO(Sarah): Do we need to reconcile mutually exclusive filters? -func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, filter fleet.TeamFilter, params []interface{}) (string, []interface{}) { +func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, filter fleet.TeamFilter, params []interface{}, leftJoinFailingPolicies bool) (string, []interface{}) { opt.OrderKey = defaultHostColumnTableAlias(opt.OrderKey) deviceMappingJoin := `LEFT JOIN ( @@ -902,6 +931,14 @@ func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, fil params = append(params, opt.SoftwareIDFilter) } + failingPoliciesJoin := "" + if !opt.DisableFailingPolicies && leftJoinFailingPolicies { + failingPoliciesJoin = `LEFT JOIN ( + SELECT host_id, count(*) as count FROM policy_membership WHERE passes = 0 + GROUP BY host_id + ) as failing_policies ON (h.id=failing_policies.host_id)` + } + operatingSystemJoin := "" if opt.OSIDFilter != nil || (opt.OSNameFilter != nil && opt.OSVersionFilter != nil) { operatingSystemJoin = `JOIN host_operating_system hos ON h.id = hos.host_id` @@ -936,6 +973,7 @@ func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, fil %s %s %s + %s %s WHERE TRUE AND %s AND %s AND %s AND %s `, @@ -944,6 +982,7 @@ func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, fil hostMDMJoin, deviceMappingJoin, policyMembershipJoin, + failingPoliciesJoin, operatingSystemJoin, munkiJoin, displayNameJoin, @@ -1161,12 +1200,14 @@ func filterHostsByMDMBootstrapPackageStatus(sql string, opt fleet.HostListOption func (ds *Datastore) CountHosts(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) (int, error) { sql := `SELECT count(*) ` - // ignore pagination in count + // Ignore pagination in count. opt.Page = 0 opt.PerPage = 0 + // We don't need the failing policy counts of each host for counting hosts. + leftJoinFailingPolicies := false var params []interface{} - sql, params = ds.applyHostFilters(opt, sql, filter, params) + sql, params = ds.applyHostFilters(opt, sql, filter, params, leftJoinFailingPolicies) var count int if err := sqlx.GetContext(ctx, ds.reader(ctx), &count, sql, params...); err != nil { diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index 7e442f03c..0cc6846c2 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -153,6 +153,7 @@ func TestHosts(t *testing.T) { {"GetMatchingHostSerials", testGetMatchingHostSerials}, {"ListHostsLiteByIDs", testHostsListHostsLiteByIDs}, {"HostScriptResult", testHostScriptResult}, + {"ListHostsWithPagination", testListHostsWithPagination}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -7431,3 +7432,76 @@ func testHostScriptResult(t *testing.T, ds *Datastore) { require.NoError(t, err) require.Equal(t, expectedOutput, script.Output) } + +func testListHostsWithPagination(t *testing.T, ds *Datastore) { + ctx := context.Background() + + newHostFunc := func(name string) *fleet.Host { + host, err := ds.NewHost(ctx, &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: ptr.String(name), + UUID: name, + Hostname: "foo.local." + name, + }) + require.NoError(t, err) + require.NotNil(t, host) + return host + } + + filter := fleet.TeamFilter{User: test.UserAdmin} + + hostCount := int(float64(HostFailingPoliciesCountOptimPageSizeThreshold) * 1.5) + hosts := make([]*fleet.Host, 0, hostCount) + for i := 0; i < hostCount; i++ { + hosts = append(hosts, newHostFunc(fmt.Sprintf("h%d", i))) + } + + // List all hosts with PerPage=0 which should not use the failing policies optimization. + perPage0 := 0 + hosts0, err := ds.ListHosts(ctx, filter, fleet.HostListOptions{ + ListOptions: fleet.ListOptions{ + PerPage: uint(perPage0), + }, + }) + require.NoError(t, err) + require.Len(t, hosts0, hostCount) + for i, host := range hosts0 { + require.Equal(t, host.ID, hosts[i].ID) + } + + // List hosts with number of hosts per page equal to the failing policies optimization threshold, to + // (thus using the optimization). + perPage1 := HostFailingPoliciesCountOptimPageSizeThreshold + hosts1, err := ds.ListHosts(ctx, filter, fleet.HostListOptions{ + ListOptions: fleet.ListOptions{ + PerPage: uint(perPage1), + }, + }) + require.NoError(t, err) + require.Len(t, hosts1, perPage1) + for i, host := range hosts1 { + require.Equal(t, host.ID, hosts[i].ID) + } + + // List hosts with number of hosts per page higher to the failing policies optimization threshold + // (thus not using the optimization) + perPage2 := int(float64(HostFailingPoliciesCountOptimPageSizeThreshold) * 1.2) + hosts2, err := ds.ListHosts(ctx, filter, fleet.HostListOptions{ + ListOptions: fleet.ListOptions{ + PerPage: uint(perPage2), + }, + }) + require.NoError(t, err) + require.Len(t, hosts2, perPage2) + for i, host := range hosts2 { + require.Equal(t, host.ID, hosts[i].ID) + } + + // Count hosts doesn't do failing policies count or pagination. + count, err := ds.CountHosts(ctx, filter, fleet.HostListOptions{}) + require.NoError(t, err) + require.Equal(t, hostCount, count) +} diff --git a/server/service/hosts.go b/server/service/hosts.go index 85eeb010b..76e926bb4 100644 --- a/server/service/hosts.go +++ b/server/service/hosts.go @@ -1410,7 +1410,6 @@ func hostsReportEndpoint(ctx context.Context, request interface{}, svc fleet.Ser return hostsReportResponse{Err: err}, nil } - req.Opts.DisableFailingPolicies = false req.Opts.AdditionalFilters = nil req.Opts.Page = 0 req.Opts.PerPage = 0 // explicitly disable any limit, we want all matching hosts diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index de724804b..569a1652b 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -6103,7 +6103,6 @@ func (s *integrationTestSuite) TestHostsReportDownload() { require.NoError(t, err) require.Len(t, rows, len(hosts)+1) // all hosts + header row require.Len(t, rows[0], 48) // total number of cols - t.Log(rows[0]) const ( idCol = 3 @@ -6198,7 +6197,6 @@ func (s *integrationTestSuite) TestHostsReportDownload() { require.Equal(t, []string{"0", "TestIntegrations/TestHostsReportDownloadfoo.local1"}, rows[2][:2]) require.Len(t, rows[3], 3) require.Equal(t, []string{"0", "TestIntegrations/TestHostsReportDownloadfoo.local0"}, rows[3][:2]) - t.Log(rows) } func (s *integrationTestSuite) TestSSODisabled() { @@ -7611,3 +7609,172 @@ func (s *integrationTestSuite) TestDirectIngestSoftwareWithInvalidFields() { }) require.NotZero(t, wiresharkSoftware.ID) } + +func (s *integrationTestSuite) TestHostsReportWithPolicyResults() { + t := s.T() + ctx := context.Background() + + newHostFunc := func(name string) *fleet.Host { + host, err := s.ds.NewHost(ctx, &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: ptr.String(name), + UUID: name, + Hostname: "foo.local." + name, + }) + require.NoError(t, err) + require.NotNil(t, host) + return host + } + + hostCount := 10 + hosts := make([]*fleet.Host, 0, hostCount) + for i := 0; i < hostCount; i++ { + hosts = append(hosts, newHostFunc(fmt.Sprintf("h%d", i))) + } + + globalPolicy0, err := s.ds.NewGlobalPolicy(ctx, &test.UserAdmin.ID, fleet.PolicyPayload{ + Name: "foobar0", + Query: "SELECT 0;", + }) + require.NoError(t, err) + globalPolicy1, err := s.ds.NewGlobalPolicy(ctx, &test.UserAdmin.ID, fleet.PolicyPayload{ + Name: "foobar1", + Query: "SELECT 1;", + }) + require.NoError(t, err) + globalPolicy2, err := s.ds.NewGlobalPolicy(ctx, &test.UserAdmin.ID, fleet.PolicyPayload{ + Name: "foobar2", + Query: "SELECT 2;", + }) + require.NoError(t, err) + + for i, host := range hosts { + // All hosts pass the globalPolicy0 + err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy0.ID: ptr.Bool(true)}, time.Now(), false) + require.NoError(t, err) + + if i%2 == 0 { + // Half of the hosts pass the globalPolicy1 and fail the globalPolicy2 + err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(true)}, time.Now(), false) + require.NoError(t, err) + err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: ptr.Bool(false)}, time.Now(), false) + require.NoError(t, err) + } else { + // Half of the hosts pass the globalPolicy2 and fail the globalPolicy1 + err := s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy1.ID: ptr.Bool(false)}, time.Now(), false) + require.NoError(t, err) + err = s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{globalPolicy2.ID: ptr.Bool(true)}, time.Now(), false) + require.NoError(t, err) + } + } + + // The hosts/report endpoint uses svc.ds.ListHosts with page=0, per_page=0, thus we are + // testing the non optimized for pagination queries for failing policies calculation. + res := s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, "format", "csv") + rows1, err := csv.NewReader(res.Body).ReadAll() + res.Body.Close() + require.NoError(t, err) + require.Len(t, rows1, len(hosts)+1) // all hosts + header row + require.Len(t, rows1[0], 48) // total number of cols + + var ( + idIdx int + issuesIdx int + ) + for colIdx, column := range rows1[0] { + switch column { + case "issues": + issuesIdx = colIdx + case "id": + idIdx = colIdx + } + } + + for i := 1; i < len(hosts)+1; i++ { + row := rows1[i] + require.Equal(t, row[issuesIdx], "1") + } + + // Running with disable_failing_policies=true disable the counting of failed policies for a host. + // Thus, all "issues" values should be 0. + res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, "format", "csv", "disable_failing_policies", "true") + rows2, err := csv.NewReader(res.Body).ReadAll() + res.Body.Close() + require.NoError(t, err) + require.Len(t, rows2, len(hosts)+1) // all hosts + header row + require.Len(t, rows2[0], 48) // total number of cols + + // Check that all hosts have 0 issues and that they match the previous call to `/hosts/report`. + for i := 1; i < len(hosts)+1; i++ { + row := rows2[i] + require.Equal(t, row[issuesIdx], "0") + row1 := rows1[i] + require.Equal(t, row[idIdx], row1[idIdx]) + } + + for _, tc := range []struct { + name string + args []string + checkRows func(t *testing.T, rows [][]string) + }{ + { + name: "get hosts that fail globalPolicy0", + args: []string{"policy_id", fmt.Sprint(globalPolicy0.ID), "policy_response", "failure"}, + checkRows: func(t *testing.T, rows [][]string) { + require.Len(t, rows, 1) // just header row, all hosts pass such policy. + }, + }, + { + name: "get hosts that pass globalPolicy0", + args: []string{"policy_id", fmt.Sprint(globalPolicy0.ID), "policy_response", "passing"}, + checkRows: func(t *testing.T, rows [][]string) { + require.Len(t, rows, len(hosts)+1) // all hosts + header row, all hosts pass such policy. + }, + }, + { + name: "get hosts that fail globalPolicy1", + args: []string{"policy_id", fmt.Sprint(globalPolicy1.ID), "policy_response", "failing"}, + checkRows: func(t *testing.T, rows [][]string) { + require.Len(t, rows, len(hosts)/2+1) // half of hosts + header row. + }, + }, + { + name: "get hosts that pass globalPolicy1", + args: []string{"policy_id", fmt.Sprint(globalPolicy1.ID), "policy_response", "passing"}, + checkRows: func(t *testing.T, rows [][]string) { + require.Len(t, rows, len(hosts)/2+1) // half of hosts + header row. + }, + }, + { + name: "get hosts that fail globalPolicy2", + args: []string{"policy_id", fmt.Sprint(globalPolicy2.ID), "policy_response", "failing"}, + checkRows: func(t *testing.T, rows [][]string) { + require.Len(t, rows, len(hosts)/2+1) // half of hosts + header row. + }, + }, + { + name: "get hosts that pass globalPolicy2", + args: []string{"policy_id", fmt.Sprint(globalPolicy2.ID), "policy_response", "passing"}, + checkRows: func(t *testing.T, rows [][]string) { + require.Len(t, rows, len(hosts)/2+1) // half of hosts + header row. + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, append(tc.args, "format", "csv")...) + rows, err := csv.NewReader(res.Body).ReadAll() + res.Body.Close() + require.NoError(t, err) + tc.checkRows(t, rows) + // Test the same with "disable_failing_policies=true" which should not change the result. + res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, append(tc.args, "format", "csv", "disable_failing_policies", "true")...) + rows, err = csv.NewReader(res.Body).ReadAll() + res.Body.Close() + require.NoError(t, err) + tc.checkRows(t, rows) + }) + } +}