Use a dedicated table to store hosts_count and fix pagination with vulns (#4104)

This commit is contained in:
Martin Angers 2022-02-09 10:16:50 -05:00 committed by GitHub
parent 11887f87f7
commit cc1cf69a0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 207 additions and 239 deletions

View File

@ -0,0 +1,34 @@
package tables
import (
"database/sql"
"github.com/pkg/errors"
)
// init hands this migration's Up and Down functions to
// MigrationClient.AddMigration when the package is loaded.
func init() {
MigrationClient.AddMigration(Up_20220208144830, Down_20220208144830)
}
// Up_20220208144830 creates the software_host_counts table if it does not
// already exist. The table is keyed by software_id and carries two secondary
// indexes, one leading with hosts_count and one leading with updated_at.
//
// It returns the error from executing the statement on tx, wrapped with a
// short description, or nil on success.
func Up_20220208144830(tx *sql.Tx) error {
	const createStmt = `
CREATE TABLE IF NOT EXISTS software_host_counts (
software_id bigint(20) unsigned NOT NULL,
hosts_count int(10) unsigned NOT NULL,
created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (software_id),
INDEX idx_software_host_counts_host_count_software_id (hosts_count, software_id),
INDEX idx_software_host_counts_updated_at_software_id (updated_at, software_id)
);
`

	_, err := tx.Exec(createStmt)
	if err != nil {
		return errors.Wrap(err, "create software_host_counts table")
	}
	return nil
}
// Down_20220208144830 is a no-op: it does not touch the transaction and
// always returns nil.
func Down_20220208144830(_ *sql.Tx) error {
	return nil
}

File diff suppressed because one or more lines are too long

View File

@ -222,7 +222,7 @@ func listSoftwareDB(
var result []fleet.Software
if err := sqlx.SelectContext(ctx, q, &result, sql, args...); err != nil {
return nil, ctxerr.Wrap(ctx, err, "load host software")
return nil, ctxerr.Wrap(ctx, err, "select host software")
}
if opts.SkipLoadingCVEs {
@ -311,38 +311,18 @@ func selectSoftwareSQL(hostID *uint, opts fleet.SoftwareListOptions) (string, []
)
}
topLevelListOpts := opts.ListOptions
if opts.WithHostCounts {
subSelectCounts := dialect.From(goqu.I("aggregated_stats").As("shc")).Select(
"shc.id", goqu.I("shc.json_value").As("hosts_count"), goqu.I("shc.updated_at").As("counts_updated_at"),
).Where(goqu.I("shc.type").Eq("software_hosts_count"), goqu.I("shc.json_value").Gt(0))
subSelectListOpts := opts.ListOptions
switch subSelectListOpts.OrderKey {
case "hosts_count", "counts_updated_at":
// all good, known columns, so we sort
subSelectCounts = appendListOptionsToSelect(subSelectCounts, opts.ListOptions)
// since the aggregated_stats table will be properly LIMITed and OFFSET, then
// we must not LIMIT and OFFSET the top-level query again (it can't return
// more rows than this internal query, and it must not be offset as the sub-query
// is already offset, so offsetting the top-level query IN ADDITION would offset
// it past any result.
topLevelListOpts.Page, topLevelListOpts.PerPage = 0, 0
default:
// we don't sort if it's not a column from this table
}
ds = ds.Join(
subSelectCounts.As("shc"),
goqu.On(
goqu.I("s.id").Eq(goqu.I("shc.id")),
),
).SelectAppend(
goqu.I("shc.hosts_count"),
goqu.I("shc.counts_updated_at"),
)
goqu.I("software_host_counts").As("shc"),
goqu.On(goqu.I("s.id").Eq(goqu.I("shc.software_id"))),
).
Where(goqu.I("shc.hosts_count").Gt(0)).
SelectAppend(
goqu.I("shc.hosts_count"),
goqu.I("shc.updated_at").As("counts_updated_at"),
)
}
ds = appendListOptionsToSelect(ds, topLevelListOpts)
ds = appendListOptionsToSelect(ds, opts.ListOptions)
return ds.ToSQL()
}
@ -563,18 +543,12 @@ func (ds *Datastore) SoftwareByID(ctx context.Context, id uint) (*fleet.Software
}
// CalculateHostsPerSoftware calculates the number of hosts having each
// software installed and stores that information in the aggregated_stats
// software installed and stores that information in the software_host_counts
// table.
func (ds *Datastore) CalculateHostsPerSoftware(ctx context.Context, updatedAt time.Time) error {
// NOTE(mna): for reference, on my laptop I get ~1.5ms for 10_000 hosts / 100 software each,
// ~1.5s for 10_000 hosts / 1_000 software each (this is with an otherwise empty
// aggregated_stats table, but these are still reasonable numbers given that this runs as a
// cron task in the background).
resetStmt := `
UPDATE aggregated_stats
SET json_value = CAST(0 AS json)
WHERE type = "software_hosts_count"`
UPDATE software_host_counts
SET hosts_count = 0, updated_at = ?`
queryStmt := `
SELECT count(*), software_id
@ -583,18 +557,18 @@ func (ds *Datastore) CalculateHostsPerSoftware(ctx context.Context, updatedAt ti
GROUP BY software_id`
insertStmt := `
INSERT INTO aggregated_stats
(id, type, json_value, updated_at)
INSERT INTO software_host_counts
(software_id, hosts_count, updated_at)
VALUES
%s
ON DUPLICATE KEY UPDATE
json_value = VALUES(json_value),
hosts_count = VALUES(hosts_count),
updated_at = VALUES(updated_at)`
valuesPart := `(?, "software_hosts_count", CAST(? AS json), ?),`
valuesPart := `(?, ?, ?),`
// first, reset all counts to 0
if _, err := ds.writer.ExecContext(ctx, resetStmt); err != nil {
return ctxerr.Wrap(ctx, err, "reset all software_hosts_count to 0 in aggregated_stats")
if _, err := ds.writer.ExecContext(ctx, resetStmt, updatedAt); err != nil {
return ctxerr.Wrap(ctx, err, "reset all software_host_counts to 0")
}
// next get a cursor for the counts for each software
@ -623,7 +597,7 @@ func (ds *Datastore) CalculateHostsPerSoftware(ctx context.Context, updatedAt ti
if batchCount == batchSize {
values := strings.TrimSuffix(strings.Repeat(valuesPart, batchCount), ",")
if _, err := ds.writer.ExecContext(ctx, fmt.Sprintf(insertStmt, values), args...); err != nil {
return ctxerr.Wrap(ctx, err, "insert batch into aggregated_stats")
return ctxerr.Wrap(ctx, err, "insert batch into software_host_counts")
}
args = args[:0]
@ -633,7 +607,7 @@ func (ds *Datastore) CalculateHostsPerSoftware(ctx context.Context, updatedAt ti
if batchCount > 0 {
values := strings.TrimSuffix(strings.Repeat(valuesPart, batchCount), ",")
if _, err := ds.writer.ExecContext(ctx, fmt.Sprintf(insertStmt, values), args...); err != nil {
return ctxerr.Wrap(ctx, err, "insert batch into aggregated_stats")
return ctxerr.Wrap(ctx, err, "insert last batch into software_host_counts")
}
}
if err := rows.Err(); err != nil {
@ -649,11 +623,10 @@ func (ds *Datastore) CalculateHostsPerSoftware(ctx context.Context, updatedAt ti
NOT EXISTS (
SELECT 1
FROM
aggregated_stats shc
software_host_counts shc
WHERE
software.id = shc.id AND
shc.type = "software_hosts_count" AND
json_value > 0)`
software.id = shc.software_id AND
shc.hosts_count > 0)`
if _, err := ds.writer.ExecContext(ctx, cleanupStmt); err != nil {
return ctxerr.Wrap(ctx, err, "delete unused software")
}

View File

@ -1,185 +0,0 @@
package mysql
import (
"context"
"fmt"
"strconv"
"strings"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server"
"github.com/stretchr/testify/require"
)
// BenchmarkCalculateHostsPerSoftware times two ways of filling the software
// host counts, each over the same set of host/software scenarios:
//
//   - "resetUpdate/singleSelectGroupByInsertBatch100AggStats" runs the local
//     helpers resetUpdateAllZeroAgg and singleSelectGroupByInsertBatchAgg
//     with a batch size of 100;
//   - "CalculateHostsPerSoftware" runs the Datastore method of that name.
//
// Every scenario gets a fresh datastore, and data generation happens before
// the timer is reset so it is excluded from the measurement. After the timed
// loop the stored counts are verified with checkCountsAgg.
func BenchmarkCalculateHostsPerSoftware(b *testing.B) {
	now := time.Now()

	// a scenario is a number of hosts and a number of software entries
	type scenario struct{ hosts, software int }
	scenarios := []scenario{
		{hosts: 1, software: 1},
		{hosts: 10, software: 10},
		{hosts: 100, software: 100},
		{hosts: 1_000, software: 100},
		{hosts: 10_000, software: 100},
		{hosts: 10_000, software: 1_000},
	}
	label := func(sc scenario) string {
		return fmt.Sprintf("%d:%d", sc.hosts, sc.software)
	}

	b.Run("resetUpdate", func(b *testing.B) {
		b.Run("singleSelectGroupByInsertBatch100AggStats", func(b *testing.B) {
			for _, sc := range scenarios {
				b.Run(label(sc), func(b *testing.B) {
					ds := CreateMySQLDS(b)
					generateHostsWithSoftware(b, ds, sc.hosts, sc.software)

					b.ResetTimer()
					for n := 0; n < b.N; n++ {
						resetUpdateAllZeroAgg(b, ds)
						singleSelectGroupByInsertBatchAgg(b, ds, now, 100)
					}
					checkCountsAgg(b, ds, sc.hosts, sc.software)
				})
			}
		})
	})

	b.Run("CalculateHostsPerSoftware", func(b *testing.B) {
		for _, sc := range scenarios {
			b.Run(label(sc), func(b *testing.B) {
				ctx := context.Background()
				ds := CreateMySQLDS(b)
				generateHostsWithSoftware(b, ds, sc.hosts, sc.software)

				b.ResetTimer()
				for n := 0; n < b.N; n++ {
					require.NoError(b, ds.CalculateHostsPerSoftware(ctx, now))
				}
				checkCountsAgg(b, ds, sc.hosts, sc.software)
			})
		}
	})
}
// checkCountsAgg verifies the "software_hosts_count" rows of aggregated_stats:
// there must be exactly sws of them, and none may hold a json_value other
// than hs. Any query error or mismatch fails the benchmark immediately.
func checkCountsAgg(b *testing.B, ds *Datastore, hs, sws int) {
	ctx := context.Background()

	const totalStmt = `SELECT COUNT(*) FROM aggregated_stats WHERE type = "software_hosts_count"`
	var total int
	require.NoError(b, ds.writer.GetContext(ctx, &total, totalStmt))
	require.Equal(b, sws, total)

	const mismatchStmt = `SELECT COUNT(*) FROM aggregated_stats WHERE type = "software_hosts_count" AND json_value != CAST(? AS json)`
	var mismatched int
	require.NoError(b, ds.writer.GetContext(ctx, &mismatched, mismatchStmt, hs))
	require.Equal(b, 0, mismatched)
}
// generateHostsWithSoftware seeds the database for a benchmark run: it
// inserts hs rows into hosts and sws rows into software (each with a single
// multi-row INSERT), then fills host_software with the cartesian product of
// the two tables so that every host has every software.
//
// Hosts are named "host<i>" with a derived node key and uuid, and a random
// osquery host id. Software is named "software<i>" with version "<i>.0.0" and
// source "testing". Any failure, including a failure to generate the random
// host id, fails the benchmark immediately.
func generateHostsWithSoftware(b *testing.B, ds *Datastore, hs, sws int) {
	hostInsert := `
INSERT INTO hosts (
osquery_host_id,
node_key,
hostname,
uuid
)
VALUES `
	hostValuePart := `(?, ?, ?, ?),`

	var sb strings.Builder
	sb.WriteString(hostInsert)
	args := make([]interface{}, 0, hs*4)
	for i := 0; i < hs; i++ {
		// do not ignore this error: an empty or duplicate id would only show
		// up later as a confusing INSERT failure.
		osqueryHostID, err := server.GenerateRandomText(10)
		require.NoError(b, err)

		name := "host" + strconv.Itoa(i)
		args = append(args, osqueryHostID, name+"key", name, name+"uuid")
		sb.WriteString(hostValuePart)
	}
	// each value group ends with a comma, drop the final one
	stmt := strings.TrimSuffix(sb.String(), ",")
	_, err := ds.writer.ExecContext(context.Background(), stmt, args...)
	require.NoError(b, err)

	swInsert := `
INSERT INTO software (
name,
version,
source
) VALUES `
	swValuePart := `(?, ?, ?),`

	sb.Reset()
	sb.WriteString(swInsert)
	args = make([]interface{}, 0, sws*3)
	for i := 0; i < sws; i++ {
		name := "software" + strconv.Itoa(i)
		args = append(args, name, strconv.Itoa(i)+".0.0", "testing")
		sb.WriteString(swValuePart)
	}
	stmt = strings.TrimSuffix(sb.String(), ",")
	_, err = ds.writer.ExecContext(context.Background(), stmt, args...)
	require.NoError(b, err)

	// cartesian product of hosts and software tables
	hostSwInsert := `
INSERT INTO host_software (host_id, software_id)
SELECT
h.id,
sw.id
FROM
hosts h,
software sw`

	_, err = ds.writer.ExecContext(context.Background(), hostSwInsert)
	require.NoError(b, err)
}
// resetUpdateAllZeroAgg sets json_value to 0 on every "software_hosts_count"
// row of aggregated_stats, failing the benchmark if the update errors.
func resetUpdateAllZeroAgg(b *testing.B, ds *Datastore) {
	const zeroAllStmt = `UPDATE aggregated_stats SET json_value = CAST(0 AS json) WHERE type = "software_hosts_count"`

	_, err := ds.writer.ExecContext(context.Background(), zeroAllStmt)
	require.NoError(b, err)
}
// singleSelectGroupByInsertBatchAgg reads the number of hosts per software_id
// from host_software with one GROUP BY query on the reader, and upserts those
// counts into aggregated_stats (type "software_hosts_count") through the
// writer, batchSize rows per INSERT. Each written row gets updatedAt as its
// updated_at value. A final, smaller batch is written for any leftover rows.
// Any error fails the benchmark immediately.
func singleSelectGroupByInsertBatchAgg(b *testing.B, ds *Datastore, updatedAt time.Time, batchSize int) {
	const (
		countsStmt = `
SELECT count(*), software_id
FROM host_software
GROUP BY software_id`

		upsertTmpl = `
INSERT INTO aggregated_stats
(id, type, json_value, updated_at)
VALUES
%s
ON DUPLICATE KEY UPDATE
json_value = VALUES(json_value),
updated_at = VALUES(updated_at)`

		rowPlaceholder = `(?, "software_hosts_count", CAST(? AS json), ?),`
	)

	ctx := context.Background()

	rows, err := ds.reader.QueryContext(ctx, countsStmt)
	require.NoError(b, err)
	defer rows.Close()

	pending := 0
	params := make([]interface{}, 0, batchSize*3)

	// flush writes the rows accumulated so far in a single statement
	flush := func() {
		placeholders := strings.TrimSuffix(strings.Repeat(rowPlaceholder, pending), ",")
		_, err := ds.writer.ExecContext(ctx, fmt.Sprintf(upsertTmpl, placeholders), params...)
		require.NoError(b, err)
	}

	for rows.Next() {
		var (
			hostCount  int
			softwareID uint
		)
		require.NoError(b, rows.Scan(&hostCount, &softwareID))

		params = append(params, softwareID, hostCount, updatedAt)
		pending++
		if pending != batchSize {
			continue
		}

		flush()
		// keep the backing array, start a new batch
		params = params[:0]
		pending = 0
	}

	if pending > 0 {
		flush()
	}
	require.NoError(b, rows.Err())
}

View File

@ -508,7 +508,7 @@ func testSoftwareList(t *testing.T, ds *Datastore) {
})
t.Run("hosts count", func(t *testing.T) {
defer TruncateTables(t, ds, "aggregated_stats")
defer TruncateTables(t, ds, "software_host_counts")
listSoftwareCheckCount(t, ds, 0, 0, fleet.SoftwareListOptions{WithHostCounts: true}, false)
// create the counts for those software and re-run

View File

@ -47,11 +47,15 @@ func (s *integrationTestSuite) TearDownTest() {
var ids []uint
for _, host := range hosts {
ids = append(ids, host.ID)
require.NoError(t, s.ds.UpdateHostSoftware(context.Background(), host.ID, nil))
}
if len(ids) > 0 {
require.NoError(t, s.ds.DeleteHosts(ctx, ids))
}
// recalculating the software counts will remove the software entries
require.NoError(t, s.ds.CalculateHostsPerSoftware(context.Background(), time.Now()))
lbls, err := s.ds.ListLabels(ctx, fleet.TeamFilter{}, fleet.ListOptions{})
require.NoError(t, err)
for _, lbl := range lbls {
@ -2615,6 +2619,136 @@ func (s *integrationTestSuite) TestQuerySpecs() {
assert.Equal(t, uint(3), delBatchResp.Deleted)
}
// TestPaginateListSoftware exercises GET /api/v1/fleet/software with
// pagination, ordering on hosts_count and the vulnerable filter.
//
// It creates 20 hosts and 20 software entries, installed so that sws[i] ends
// up on i+1 hosts, and attaches a CVE to the 10 least-used software. It then
// checks that nothing is listed before the host counts are calculated, and
// that each requested page holds the expected software, counts and
// counts-updated timestamp afterwards.
func (s *integrationTestSuite) TestPaginateListSoftware() {
	t := s.T()
	ctx := context.Background()

	const (
		total        = 20
		softwarePath = "/api/v1/fleet/software"
	)

	// create a few hosts specific to this test
	hosts := make([]*fleet.Host, 0, total)
	for i := 0; i < total; i++ {
		idx := strconv.Itoa(i)
		host, err := s.ds.NewHost(ctx, &fleet.Host{
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			PolicyUpdatedAt: time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         t.Name() + idx,
			OsqueryHostID:   t.Name() + idx,
			UUID:            t.Name() + idx,
			Hostname:        t.Name() + "foo" + idx + ".local",
			PrimaryIP:       "192.168.1." + idx,
			PrimaryMac:      fmt.Sprintf("30-65-EC-6F-C4-%02d", i),
		})
		require.NoError(t, err)
		require.NotNil(t, host)
		hosts = append(hosts, host)
	}

	// create a bunch of software
	sws := make([]fleet.Software, 0, total)
	for i := 0; i < total; i++ {
		idx := strconv.Itoa(i)
		sws = append(sws, fleet.Software{Name: "sw" + idx, Version: "0.0." + idx, Source: "apps"})
	}

	// mark them as installed on the hosts, with host at index 0 having all 20,
	// at index 1 having 19, index 2 = 18, etc. until index 19 = 1. So software
	// sws[0] is only used by 1 host, while sws[19] is used by all.
	for i, h := range hosts {
		require.NoError(t, s.ds.UpdateHostSoftware(ctx, h.ID, sws[i:]))
		require.NoError(t, s.ds.LoadHostSoftware(ctx, h))
		if i == 0 {
			// this host has all software, refresh the list so we have the software.ID filled
			sws = h.Software
		}
	}

	for i, sw := range sws {
		cpe := "somecpe" + strconv.Itoa(i)
		require.NoError(t, s.ds.AddCPEForSoftware(ctx, sw, cpe))
		if i >= 10 {
			continue
		}
		// add CVEs for the first 10 software, which are the least used (lower hosts_count)
		_, err := s.ds.InsertCVEForCPE(ctx, fmt.Sprintf("cve-123-123-%03d", i), []string{cpe})
		require.NoError(t, err)
	}

	// pick returns the software at the given indices of sws, in that order.
	pick := func(indices ...int) []fleet.Software {
		picked := make([]fleet.Software, 0, len(indices))
		for _, ix := range indices {
			picked = append(picked, sws[ix])
		}
		return picked
	}

	// checkPage compares a response against the wanted software IDs and host
	// counts (position by position), and against the wanted counts timestamp;
	// a zero ts means the response must not carry a timestamp at all.
	checkPage := func(desc string, resp listSoftwareResponse, want []fleet.Software, counts []int, ts time.Time) {
		require.Len(t, resp.Software, len(want), desc)
		for i, got := range resp.Software {
			assert.Equal(t, want[i].ID, got.ID, desc)
			assert.Equal(t, counts[i], got.HostsCount, desc)
		}
		if ts.IsZero() {
			assert.Nil(t, resp.CountsUpdatedAt, desc)
			return
		}
		require.NotNil(t, resp.CountsUpdatedAt, desc)
		assert.WithinDuration(t, ts, *resp.CountsUpdatedAt, time.Second, desc)
	}

	// no software host counts have been calculated yet, so this returns nothing
	var firstResp listSoftwareResponse
	s.DoJSON("GET", softwarePath, nil, http.StatusOK, &firstResp, "order_key", "hosts_count", "order_direction", "desc")
	checkPage("before counts are calculated", firstResp, nil, nil, time.Time{})

	// calculate hosts counts
	hostsCountTs := time.Now().UTC()
	require.NoError(t, s.ds.CalculateHostsPerSoftware(ctx, hostsCountTs))

	// now the list software endpoint returns the software
	pages := []struct {
		desc   string
		query  []string
		want   []fleet.Software
		counts []int
		ts     time.Time
	}{
		{
			desc:   "first page without vulns",
			query:  []string{"per_page", "5", "page", "0", "order_key", "hosts_count", "order_direction", "desc"},
			want:   pick(19, 18, 17, 16, 15),
			counts: []int{20, 19, 18, 17, 16},
			ts:     hostsCountTs,
		},
		{
			desc:   "second page (page=1)",
			query:  []string{"per_page", "5", "page", "1", "order_key", "hosts_count", "order_direction", "desc"},
			want:   pick(14, 13, 12, 11, 10),
			counts: []int{15, 14, 13, 12, 11},
			ts:     hostsCountTs,
		},
		{
			desc:   "third page (page=2)",
			query:  []string{"per_page", "5", "page", "2", "order_key", "hosts_count", "order_direction", "desc"},
			want:   pick(9, 8, 7, 6, 5),
			counts: []int{10, 9, 8, 7, 6},
			ts:     hostsCountTs,
		},
		{
			desc:   "last page (page=3)",
			query:  []string{"per_page", "5", "page", "3", "order_key", "hosts_count", "order_direction", "desc"},
			want:   pick(4, 3, 2, 1, 0),
			counts: []int{5, 4, 3, 2, 1},
			ts:     hostsCountTs,
		},
		{
			desc:  "past the end",
			query: []string{"per_page", "5", "page", "4", "order_key", "hosts_count", "order_direction", "desc"},
		},
		{
			desc:   "no explicit sort order, defaults to hosts_count DESC",
			query:  []string{"per_page", "2", "page", "0"},
			want:   pick(19, 18),
			counts: []int{20, 19},
			ts:     hostsCountTs,
		},
		{
			desc:   "hosts_count ascending",
			query:  []string{"per_page", "3", "page", "0", "order_key", "hosts_count", "order_direction", "asc"},
			want:   pick(0, 1, 2),
			counts: []int{1, 2, 3},
			ts:     hostsCountTs,
		},
		{
			desc:   "vulnerable software only",
			query:  []string{"vulnerable", "true", "per_page", "5", "page", "0", "order_key", "hosts_count", "order_direction", "desc"},
			want:   pick(9, 8, 7, 6, 5),
			counts: []int{10, 9, 8, 7, 6},
			ts:     hostsCountTs,
		},
		{
			desc:   "vulnerable software only, next page",
			query:  []string{"vulnerable", "true", "per_page", "5", "page", "1", "order_key", "hosts_count", "order_direction", "desc"},
			want:   pick(4, 3, 2, 1, 0),
			counts: []int{5, 4, 3, 2, 1},
			ts:     hostsCountTs,
		},
		{
			desc:  "vulnerable software only, past last page",
			query: []string{"vulnerable", "true", "per_page", "5", "page", "2", "order_key", "hosts_count", "order_direction", "desc"},
		},
	}
	for _, pg := range pages {
		// decode into a fresh value so nothing leaks from the previous page
		var resp listSoftwareResponse
		s.DoJSON("GET", softwarePath, nil, http.StatusOK, &resp, pg.query...)
		checkPage(pg.desc, resp, pg.want, pg.counts, pg.ts)
	}
}
// creates a session and returns it, its key is to be passed as authorization header.
func createSession(t *testing.T, uid uint, ds fleet.Datastore) *fleet.Session {
key := make([]byte, 64)