// fleet/server/service/async/async_policy.go

package async

import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
	"github.com/fleetdm/fleet/v4/server/datastore/redis"
	"github.com/fleetdm/fleet/v4/server/fleet"
	"github.com/fleetdm/fleet/v4/server/ptr"

	redigo "github.com/gomodule/redigo/redis"
)
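
// The per-host key patterns embed the host ID inside {braces}: in Redis
// Cluster mode that hash tag makes a host's list key and "reported at" key
// map to the same slot, so a single script can touch both.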
const (
	policyPassHostIDsKey  = "policy_pass:active_host_ids"
	policyPassHostKey     = "policy_pass:{%d}"
	policyPassReportedKey = "policy_pass_reported:{%d}"
	policyPassKeysMinTTL  = 7 * 24 * time.Hour // 1 week
)

var (
	// the redis list is LTRIM'd if it holds more policy results than this.
	maxRedisPolicyResultsPerHost = 1000
)
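
// RecordPolicyQueryExecutions records the results of a host's policy queries.
// If async processing is disabled, the results are written directly to the
// datastore; otherwise they are queued in Redis to be persisted by a later
// collector run. A hypothetical caller (the variable names below are
// illustrative only) might use it as:
//
//	pass := true
//	results := map[uint]*bool{policyID: &pass}
//	err := task.RecordPolicyQueryExecutions(ctx, host, results, time.Now(), true)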
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
	if !t.AsyncEnabled {
		host.PolicyUpdatedAt = ts
		return t.Datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred)
	}

	keyList := fmt.Sprintf(policyPassHostKey, host.ID)
	keyTs := fmt.Sprintf(policyPassReportedKey, host.ID)

	// Set an expiration on both keys (list and ts) so that a deleted host
	// (eventually) stops using any redis space. The TTL must be reasonably
	// big to avoid deleting information that hasn't been collected yet - 1
	// week or 10 * the collector interval, whichever is bigger.
	//
	// This means the keys only expire if the host hasn't reported policies
	// during that (TTL) time (each report resets the TTL), and the collector
	// has plenty of time to run (multiple times) to persist all the data in
	// mysql.
	ttl := policyPassKeysMinTTL
	if maxTTL := 10 * t.CollectorInterval; maxTTL > ttl {
		ttl = maxTTL
	}

	// KEYS[1]: keyList (policyPassHostKey)
	// KEYS[2]: keyTs (policyPassReportedKey)
	// ARGV[1]: timestamp for "reported at"
	// ARGV[2]: max policy results to keep per host (the list is trimmed to that size)
	// ARGV[3]: ttl for both keys
	// ARGV[4..]: policy_id=pass entries to LPUSH to the list
	script := redigo.NewScript(2, `
    redis.call('LPUSH', KEYS[1], unpack(ARGV, 4))
    redis.call('LTRIM', KEYS[1], 0, ARGV[2])
    redis.call('EXPIRE', KEYS[1], ARGV[3])
    redis.call('SET', KEYS[2], ARGV[1])
    return redis.call('EXPIRE', KEYS[2], ARGV[3])
`)

	// Convert results to LPUSH arguments: each entry is encoded as
	// "<policy_id>=1" for pass, "<policy_id>=-1" for fail, and
	// "<policy_id>=0" for a null (no) result.
	args := make(redigo.Args, 0, 5+len(results))
	args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds()))
	for k, v := range results {
		pass := 0
		if v != nil {
			if *v {
				pass = 1
			} else {
				pass = -1
			}
		}
		args = args.Add(fmt.Sprintf("%d=%d", k, pass))
	}
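
	// bind the connection to the redis node that holds both keys, so the
	// script runs where the keys live (relevant in Redis Cluster mode).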
	conn := t.Pool.Get()
	defer conn.Close()
	if err := redis.BindConn(t.Pool, conn, keyList, keyTs); err != nil {
		return ctxerr.Wrap(ctx, err, "bind redis connection")
	}

	if _, err := script.Do(conn, args...); err != nil {
		return ctxerr.Wrap(ctx, err, "run redis script")
	}

	// Store the host ID in the set of active host IDs for policy membership
	// outside of the redis script, because in Redis Cluster mode that key may
	// not live on the same node as the host's keys. At the same time, purge
	// any entry in the set that is older than now - TTL.
	if _, err := storePurgeActiveHostID(t.Pool, policyPassHostIDsKey, host.ID, ts, ts.Add(-ttl)); err != nil {
		return ctxerr.Wrap(ctx, err, "store active host id")
	}
	return nil
}
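
// collectPolicyQueryExecutions is the collector counterpart of
// RecordPolicyQueryExecutions: it drains the queued policy results of every
// active host from Redis and persists them to the datastore in batches.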
func (t *Task) collectPolicyQueryExecutions(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, stats *collectorExecStats) error {
	hosts, err := loadActiveHostIDs(pool, policyPassHostIDsKey, t.RedisScanKeysCount)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "load active host ids")
	}
	stats.Keys = len(hosts)

	// A script is needed because the RPOP command only supports a COUNT
	// argument since redis 6.2. Because LTRIM is applied when inserting, the
	// total number of results is at most maxRedisPolicyResultsPerHost, so it
	// is capped and can be returned in one go.
	script := redigo.NewScript(1, `
    local res = redis.call('LRANGE', KEYS[1], 0, -1)
    redis.call('DEL', KEYS[1])
    return res
`)
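
	// getKeyTuples reads and deletes a host's queued results in a single
	// (atomic) script call, then parses each "<policy_id>=<result>" entry
	// into a fleet.PolicyMembershipResult.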
	getKeyTuples := func(hostID uint) (inserts []fleet.PolicyMembershipResult, err error) {
		keyList := fmt.Sprintf(policyPassHostKey, hostID)
		conn := redis.ConfigureDoer(pool, pool.Get())
		defer conn.Close()

		stats.RedisCmds++
		res, err := redigo.Strings(script.Do(conn, keyList))
		if err != nil {
			return nil, ctxerr.Wrap(ctx, err, "redis LRANGE script")
		}

		inserts = make([]fleet.PolicyMembershipResult, 0, len(res))
		stats.Items += len(res)
		for _, item := range res {
			parts := strings.Split(item, "=")
			if len(parts) != 2 {
				// skip malformed entries
				continue
			}

			var tup fleet.PolicyMembershipResult
			if id, _ := strconv.ParseUint(parts[0], 10, 32); id > 0 {
				tup.HostID = hostID
				tup.PolicyID = uint(id)
				switch parts[1] {
				case "1":
					tup.Passes = ptr.Bool(true)
				case "-1":
					tup.Passes = ptr.Bool(false)
				case "0":
					tup.Passes = nil
				default:
					// skip unknown result values
					continue
				}
				inserts = append(inserts, tup)
			}
		}
		return inserts, nil
	}

	runInsertBatch := func(batch []fleet.PolicyMembershipResult) error {
		stats.Inserts++
		return ds.AsyncBatchInsertPolicyMembership(ctx, batch)
	}
	runUpdateBatch := func(ids []uint, ts time.Time) error {
		stats.Updates++
		return ds.AsyncBatchUpdatePolicyTimestamp(ctx, ids, ts)
	}
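
	// drain each active host's list, accumulating results and flushing them
	// to the datastore whenever a full batch is reached.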
	insertBatch := make([]fleet.PolicyMembershipResult, 0, t.InsertBatch)
	for _, host := range hosts {
		hid := host.HostID
		ins, err := getKeyTuples(hid)
		if err != nil {
			return err
		}
		insertBatch = append(insertBatch, ins...)

		if len(insertBatch) >= t.InsertBatch {
			if err := runInsertBatch(insertBatch); err != nil {
				return err
			}
			insertBatch = insertBatch[:0]
		}
	}

	// process any remaining results that did not fill a complete batch in the
	// loop.
	if len(insertBatch) > 0 {
		if err := runInsertBatch(insertBatch); err != nil {
			return err
		}
	}
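
	// all queued results are now collected; update the hosts' policy
	// timestamps in chunks of t.UpdateBatch IDs, then shrink the active set.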
	if len(hosts) > 0 {
		hostIDs := make([]uint, len(hosts))
		for i, host := range hosts {
			hostIDs[i] = host.HostID
		}

		ts := time.Now()
		updateBatch := make([]uint, t.UpdateBatch)
		for {
			n := copy(updateBatch, hostIDs)
			if n == 0 {
				break
			}
			if err := runUpdateBatch(updateBatch[:n], ts); err != nil {
				return err
			}
			hostIDs = hostIDs[n:]
		}

		// Batch-remove from the active set any host ID whose score is still
		// the value it had when it was loaded, so that the set does not keep
		// all (potentially 100K+) host IDs at all times - only those with
		// reported results left to process.
		if _, err := removeProcessedHostIDs(pool, policyPassHostIDsKey, hosts); err != nil {
			return ctxerr.Wrap(ctx, err, "remove processed host ids")
		}
	}
	return nil
}
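
// GetHostPolicyReportedAt returns the last time the host reported policy
// results. When async processing is enabled it prefers the Redis-tracked
// "reported at" timestamp if it is more recent than the datastore's
// host.PolicyUpdatedAt, which lags until the collector runs.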
func (t *Task) GetHostPolicyReportedAt(ctx context.Context, host *fleet.Host) time.Time {
	if t.AsyncEnabled {
		conn := redis.ConfigureDoer(t.Pool, t.Pool.Get())
		defer conn.Close()

		key := fmt.Sprintf(policyPassReportedKey, host.ID)
		epoch, err := redigo.Int64(conn.Do("GET", key))
		if err == nil {
			if reported := time.Unix(epoch, 0); reported.After(host.PolicyUpdatedAt) {
				return reported
			}
		}
	}
	return host.PolicyUpdatedAt
}