package async

import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
	"github.com/fleetdm/fleet/v4/server/datastore/redis"
	"github.com/fleetdm/fleet/v4/server/fleet"
	"github.com/fleetdm/fleet/v4/server/ptr"
	redigo "github.com/gomodule/redigo/redis"
)

const (
	policyPassHostIDsKey  = "policy_pass:active_host_ids"
	policyPassHostKey     = "policy_pass:{%d}"
	policyPassReportedKey = "policy_pass_reported:{%d}"
	policyPassKeysMinTTL  = 7 * 24 * time.Hour // 1 week
)

var (
	// redis list will be LTRIM'd if there are more policy IDs than this.
	maxRedisPolicyResultsPerHost = 1000
)
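
// RecordPolicyQueryExecutions records the policy query results reported by the
// given host. When async processing is disabled it writes them directly to the
// datastore; otherwise it pushes them to per-host Redis keys to be persisted
// later by the collector.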
func (t *Task) RecordPolicyQueryExecutions(ctx context.Context, host *fleet.Host, results map[uint]*bool, ts time.Time, deferred bool) error {
	if !t.AsyncEnabled {
		host.PolicyUpdatedAt = ts
		return t.Datastore.RecordPolicyQueryExecutions(ctx, host, results, ts, deferred)
	}

	keyList := fmt.Sprintf(policyPassHostKey, host.ID)
	keyTs := fmt.Sprintf(policyPassReportedKey, host.ID)

	// set an expiration on both keys (list and ts), ensuring that a deleted host
	// (eventually) does not use any redis space. Ensure that the TTL is reasonably
	// big to avoid deleting information that hasn't been collected yet - 1 week
	// or 10 * the collector interval, whichever is biggest.
	//
	// This means that the keys only expire if the host hasn't reported policies
	// during that (TTL) time (each time it does report, the TTL is reset), and
	// the collector will have plenty of time to run (multiple times) to try to
	// persist all the data in mysql.
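	//
	// For example, assuming a collector interval of one hour, 10h is less than
	// one week so the TTL stays at the one-week minimum; a hypothetical one-day
	// interval would raise it to 10 days.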
	ttl := policyPassKeysMinTTL
	if maxTTL := 10 * t.CollectorInterval; maxTTL > ttl {
		ttl = maxTTL
	}

	// KEYS[1]: keyList (policyPassHostKey)
	// KEYS[2]: keyTs (policyPassReportedKey)
	// ARGV[1]: timestamp for "reported at"
	// ARGV[2]: max policy results to keep per host (list is trimmed to that size)
	// ARGV[3]: ttl for both keys
	// ARGV[4..]: policy_id=pass entries to LPUSH to the list
	script := redigo.NewScript(2, `
    redis.call('LPUSH', KEYS[1], unpack(ARGV, 4))
    redis.call('LTRIM', KEYS[1], 0, ARGV[2])
    redis.call('EXPIRE', KEYS[1], ARGV[3])
    redis.call('SET', KEYS[2], ARGV[1])
    return redis.call('EXPIRE', KEYS[2], ARGV[3])
  `)

	// convert results to LPUSH arguments, store as policy_id=1 for pass,
	// policy_id=-1 for fail, policy_id=0 for null result.
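	// For example, results of {12: true, 13: false, 14: nil} would be pushed as
	// the entries "12=1", "13=-1", and "14=0".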
	args := make(redigo.Args, 0, 5+len(results))
	args = args.Add(keyList, keyTs, ts.Unix(), maxRedisPolicyResultsPerHost, int(ttl.Seconds()))
	for k, v := range results {
		pass := 0
		if v != nil {
			if *v {
				pass = 1
			} else {
				pass = -1
			}
		}
		args = args.Add(fmt.Sprintf("%d=%d", k, pass))
	}
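
	// get a connection and bind it to the Redis node that serves the host's keys;
	// both keys use the {%d} hash tag with the host ID, so in cluster mode they
	// hash to the same slot.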
	conn := t.Pool.Get()
	defer conn.Close()
	if err := redis.BindConn(t.Pool, conn, keyList, keyTs); err != nil {
		return ctxerr.Wrap(ctx, err, "bind redis connection")
	}

	if _, err := script.Do(conn, args...); err != nil {
		return ctxerr.Wrap(ctx, err, "run redis script")
	}

	// Storing the host id in the set of active host IDs for policy membership
	// is done outside of the redis script because in Redis Cluster mode the key
	// may not live on the same node as the host's keys. At the same time, purge
	// any entry in the set that is older than now - TTL.
	if _, err := storePurgeActiveHostID(t.Pool, policyPassHostIDsKey, host.ID, ts, ts.Add(-ttl)); err != nil {
		return ctxerr.Wrap(ctx, err, "store active host id")
	}
	return nil
}
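
// collectPolicyQueryExecutions drains the per-host policy results accumulated
// in Redis and persists them to the datastore in batches, recording its
// activity in stats. It is meant to be invoked by the async collector.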
func (t *Task) collectPolicyQueryExecutions(ctx context.Context, ds fleet.Datastore, pool fleet.RedisPool, stats *collectorExecStats) error {
	hosts, err := loadActiveHostIDs(pool, policyPassHostIDsKey, t.RedisScanKeysCount)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "load active host ids")
	}
	stats.Keys = len(hosts)

	// need to use a script as the RPOP command only supports a COUNT argument
	// since Redis 6.2. Because we use LTRIM when inserting, we know the total
	// number of results is at most maxRedisPolicyResultsPerHost, so it is capped
	// and can be returned in one go.
	script := redigo.NewScript(1, `
    local res = redis.call('LRANGE', KEYS[1], 0, -1)
    redis.call('DEL', KEYS[1])
    return res
  `)

	getKeyTuples := func(hostID uint) (inserts []fleet.PolicyMembershipResult, err error) {
		keyList := fmt.Sprintf(policyPassHostKey, hostID)
		conn := redis.ConfigureDoer(pool, pool.Get())
		defer conn.Close()

		stats.RedisCmds++
		res, err := redigo.Strings(script.Do(conn, keyList))
		if err != nil {
			return nil, ctxerr.Wrap(ctx, err, "redis LRANGE script")
		}

		inserts = make([]fleet.PolicyMembershipResult, 0, len(res))
		stats.Items += len(res)
		for _, item := range res {
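			// each entry has the form "<policyID>=<1|-1|0>", as built by
			// RecordPolicyQueryExecutions.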
			parts := strings.Split(item, "=")
			if len(parts) != 2 {
				continue
			}

			var tup fleet.PolicyMembershipResult
			if id, _ := strconv.ParseUint(parts[0], 10, 32); id > 0 {
				tup.HostID = hostID
				tup.PolicyID = uint(id)
				switch parts[1] {
				case "1":
					tup.Passes = ptr.Bool(true)
				case "-1":
					tup.Passes = ptr.Bool(false)
				case "0":
					tup.Passes = nil
				default:
					continue
				}
				inserts = append(inserts, tup)
			}
		}
		return inserts, nil
	}

	runInsertBatch := func(batch []fleet.PolicyMembershipResult) error {
		stats.Inserts++
		return ds.AsyncBatchInsertPolicyMembership(ctx, batch)
	}

	runUpdateBatch := func(ids []uint, ts time.Time) error {
		stats.Updates++
		return ds.AsyncBatchUpdatePolicyTimestamp(ctx, ids, ts)
	}

	insertBatch := make([]fleet.PolicyMembershipResult, 0, t.InsertBatch)
	for _, host := range hosts {
		hid := host.HostID
		ins, err := getKeyTuples(hid)
		if err != nil {
			return err
		}
		insertBatch = append(insertBatch, ins...)

		if len(insertBatch) >= t.InsertBatch {
			if err := runInsertBatch(insertBatch); err != nil {
				return err
			}
			insertBatch = insertBatch[:0]
		}
	}

	// process any remaining batch that did not reach the insert batch size limit
	// in the loop.
	if len(insertBatch) > 0 {
		if err := runInsertBatch(insertBatch); err != nil {
			return err
		}
	}
	if len(hosts) > 0 {
		hostIDs := make([]uint, len(hosts))
		for i, host := range hosts {
			hostIDs[i] = host.HostID
		}

		ts := time.Now()
		updateBatch := make([]uint, t.UpdateBatch)
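		// send the host IDs in chunks of at most UpdateBatch entries; copy
		// returns 0 once hostIDs has been fully consumed.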
		for {
			n := copy(updateBatch, hostIDs)
			if n == 0 {
				break
			}
			if err := runUpdateBatch(updateBatch[:n], ts); err != nil {
				return err
			}
			hostIDs = hostIDs[n:]
		}

		// batch-remove any host ID from the active set that still has its score
		// set to the initial value, so that the active set does not keep all
		// (potentially 100K+) host IDs to process at all times - only those with
		// reported results to process.
		if _, err := removeProcessedHostIDs(pool, policyPassHostIDsKey, hosts); err != nil {
			return ctxerr.Wrap(ctx, err, "remove processed host ids")
		}
	}

	return nil
}
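
// GetHostPolicyReportedAt returns the latest time at which the host reported
// policy results. When async processing is enabled and Redis holds a more
// recent "reported at" timestamp for the host, that timestamp is returned;
// otherwise host.PolicyUpdatedAt is used.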
func (t *Task) GetHostPolicyReportedAt(ctx context.Context, host *fleet.Host) time.Time {
	if t.AsyncEnabled {
		conn := redis.ConfigureDoer(t.Pool, t.Pool.Get())
		defer conn.Close()

		key := fmt.Sprintf(policyPassReportedKey, host.ID)
		epoch, err := redigo.Int64(conn.Do("GET", key))
		if err == nil {
			if reported := time.Unix(epoch, 0); reported.After(host.PolicyUpdatedAt) {
				return reported
			}
		}
	}
	return host.PolicyUpdatedAt
}