fleet/cmd/osquery-perf/agent.go
2021-11-19 08:50:25 -03:00

451 lines
11 KiB
Go

package main
import (
"bytes"
"crypto/tls"
"embed"
"encoding/json"
"flag"
"fmt"
"log"
"math/rand"
"net/http"
"os"
"strings"
"sync"
"text/template"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/service"
"github.com/google/uuid"
"github.com/valyala/fasthttp"
)
//go:embed *.tmpl
var templatesFS embed.FS
type Stats struct {
errors int
enrollments int
distributedwrites int
l sync.Mutex
}
func (s *Stats) RecordStats(errors int, enrollments int, distributedwrites int) {
s.l.Lock()
defer s.l.Unlock()
s.errors += errors
s.enrollments += enrollments
s.distributedwrites += distributedwrites
}
func (s *Stats) Log() {
s.l.Lock()
defer s.l.Unlock()
fmt.Printf(
"%s :: error rate: %.2f \t enrollments: %d \t writes: %d\n",
time.Now().String(),
float64(s.errors)/float64(s.enrollments),
s.enrollments,
s.distributedwrites,
)
}
func (s *Stats) runLoop() {
ticker := time.Tick(10 * time.Second)
for range ticker {
s.Log()
}
}
type nodeKeyManager struct {
filepath string
l sync.Mutex
nodekeys []string
}
func (n *nodeKeyManager) LoadKeys() {
if n.filepath == "" {
return
}
n.l.Lock()
defer n.l.Unlock()
data, err := os.ReadFile(n.filepath)
if err != nil {
fmt.Println("WARNING (ignore if creating a new node key file): error loading nodekey file:", err)
return
}
n.nodekeys = strings.Split(string(data), "\n")
fmt.Printf("loaded %d node keys\n", len(n.nodekeys))
}
func (n *nodeKeyManager) Get(i int) string {
n.l.Lock()
defer n.l.Unlock()
if len(n.nodekeys) > i {
return n.nodekeys[i]
}
return ""
}
func (n *nodeKeyManager) Add(nodekey string) {
if n.filepath == "" {
return
}
// we lock just to make sure we write one at a time
n.l.Lock()
defer n.l.Unlock()
f, err := os.OpenFile(n.filepath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
fmt.Println("error opening nodekey file:", err.Error())
return
}
defer f.Close()
if _, err := f.WriteString(nodekey + "\n"); err != nil {
fmt.Println("error writing nodekey file:", err)
}
}
type agent struct {
ServerAddress string
EnrollSecret string
NodeKey string
UUID string
FastClient fasthttp.Client
ConfigInterval time.Duration
QueryInterval time.Duration
Templates *template.Template
Stats *Stats
SoftwareCount int
NodeKeyManager *nodeKeyManager
policyPassProb float64
strings map[string]string
}
func newAgent(
serverAddress, enrollSecret string, templates *template.Template,
configInterval, queryInterval time.Duration, softwareCount int,
policyPassProb float64,
) *agent {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
transport.DisableCompression = true
return &agent{
ServerAddress: serverAddress,
EnrollSecret: enrollSecret,
Templates: templates,
ConfigInterval: configInterval,
QueryInterval: queryInterval,
UUID: uuid.New().String(),
FastClient: fasthttp.Client{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
},
SoftwareCount: softwareCount,
strings: make(map[string]string),
policyPassProb: policyPassProb,
}
}
type enrollResponse struct {
NodeKey string `json:"node_key"`
}
type distributedReadResponse struct {
Queries map[string]string `json:"queries"`
}
func (a *agent) runLoop(i int, onlyAlreadyEnrolled bool) {
if err := a.enroll(i, onlyAlreadyEnrolled); err != nil {
return
}
a.config()
resp, err := a.DistributedRead()
if err != nil {
log.Println(err)
} else {
if len(resp.Queries) > 0 {
a.DistributedWrite(resp.Queries)
}
}
configTicker := time.Tick(a.ConfigInterval)
liveQueryTicker := time.Tick(a.QueryInterval)
for {
select {
case <-configTicker:
a.config()
case <-liveQueryTicker:
resp, err := a.DistributedRead()
if err != nil {
log.Println(err)
} else {
if len(resp.Queries) > 0 {
a.DistributedWrite(resp.Queries)
}
}
}
}
}
func (a *agent) waitingDo(req *fasthttp.Request, res *fasthttp.Response) {
err := a.FastClient.Do(req, res)
for err != nil || res.StatusCode() != http.StatusOK {
fmt.Println(err, res.StatusCode())
a.Stats.RecordStats(1, 0, 0)
<-time.Tick(time.Duration(rand.Intn(120)+1) * time.Second)
err = a.FastClient.Do(req, res)
}
}
func (a *agent) enroll(i int, onlyAlreadyEnrolled bool) error {
a.NodeKey = a.NodeKeyManager.Get(i)
if a.NodeKey != "" {
a.Stats.RecordStats(0, 1, 0)
return nil
}
if onlyAlreadyEnrolled {
return fmt.Errorf("not enrolled")
}
var body bytes.Buffer
if err := a.Templates.ExecuteTemplate(&body, "enroll", a); err != nil {
log.Println("execute template:", err)
return err
}
req := fasthttp.AcquireRequest()
req.SetBody(body.Bytes())
req.Header.SetMethod("POST")
req.Header.SetContentType("application/json")
req.Header.Add("User-Agent", "osquery/4.6.0")
req.SetRequestURI(a.ServerAddress + "/api/v1/osquery/enroll")
res := fasthttp.AcquireResponse()
a.waitingDo(req, res)
fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(res)
if res.StatusCode() != http.StatusOK {
log.Println("enroll status:", res.StatusCode())
return fmt.Errorf("status code: %d", res.StatusCode())
}
var parsedResp enrollResponse
if err := json.Unmarshal(res.Body(), &parsedResp); err != nil {
log.Println("json parse:", err)
return err
}
a.NodeKey = parsedResp.NodeKey
a.Stats.RecordStats(0, 1, 0)
a.NodeKeyManager.Add(a.NodeKey)
return nil
}
func (a *agent) config() {
body := bytes.NewBufferString(`{"node_key": "` + a.NodeKey + `"}`)
req := fasthttp.AcquireRequest()
req.SetBody(body.Bytes())
req.Header.SetMethod("POST")
req.Header.SetContentType("application/json")
req.Header.Add("User-Agent", "osquery/4.6.0")
req.SetRequestURI(a.ServerAddress + "/api/v1/osquery/config")
res := fasthttp.AcquireResponse()
a.waitingDo(req, res)
fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(res)
if res.StatusCode() != http.StatusOK {
log.Println("config status:", res.StatusCode())
return
}
// No need to read the config body
}
const stringVals = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_."
func (a *agent) randomString(n int) string {
sb := strings.Builder{}
sb.Grow(n)
for i := 0; i < n; i++ {
sb.WriteByte(stringVals[rand.Int63()%int64(len(stringVals))])
}
return sb.String()
}
func (a *agent) CachedString(key string) string {
if val, ok := a.strings[key]; ok {
return val
}
val := a.randomString(12)
a.strings[key] = val
return val
}
func (a *agent) SoftwareMacOS() []fleet.Software {
software := make([]fleet.Software, a.SoftwareCount)
for i := 0; i < len(software); i++ {
software[i] = fleet.Software{
Name: "Placeholder_Software",
Version: "0.0.1",
BundleIdentifier: "com.fleetdm.osquery-perf",
Source: "osquery-perf",
}
}
return software
}
func (a *agent) DistributedRead() (*distributedReadResponse, error) {
req := fasthttp.AcquireRequest()
req.SetBody([]byte(`{"node_key": "` + a.NodeKey + `"}`))
req.Header.SetMethod("POST")
req.Header.SetContentType("application/json")
req.Header.Add("User-Agent", "osquery/4.6.0")
req.SetRequestURI(a.ServerAddress + "/api/v1/osquery/distributed/read")
res := fasthttp.AcquireResponse()
a.waitingDo(req, res)
fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(res)
var parsedResp distributedReadResponse
if err := json.Unmarshal(res.Body(), &parsedResp); err != nil {
log.Println("json parse:", err)
return nil, err
}
return &parsedResp, nil
}
var defaultQueryResult = []map[string]string{
{"foo": "bar"},
}
func (a *agent) runPolicy(query string) []map[string]string {
if rand.Float64() <= a.policyPassProb {
return []map[string]string{
{"1": "1"},
}
}
return nil
}
func (a *agent) DistributedWrite(queries map[string]string) {
r := service.SubmitDistributedQueryResultsRequest{
Results: make(fleet.OsqueryDistributedQueryResults),
Statuses: make(map[string]fleet.OsqueryStatus),
}
r.NodeKey = a.NodeKey
const hostPolicyQueryPrefix = "fleet_policy_query_"
for name := range queries {
r.Results[name] = defaultQueryResult
r.Statuses[name] = fleet.StatusOK
if strings.HasPrefix(name, hostPolicyQueryPrefix) {
r.Results[name] = a.runPolicy(queries[name])
continue
}
if t := a.Templates.Lookup(name); t == nil {
continue
}
var ni bytes.Buffer
err := a.Templates.ExecuteTemplate(&ni, name, a)
if err != nil {
panic(err)
}
var m []map[string]string
err = json.Unmarshal(ni.Bytes(), &m)
if err != nil {
panic(err)
}
r.Results[name] = m
}
body, err := json.Marshal(r)
if err != nil {
panic(err)
}
req := fasthttp.AcquireRequest()
req.SetBody(body)
req.Header.SetMethod("POST")
req.Header.SetContentType("application/json")
req.Header.Add("User-Agent", "osquery/4.6.0")
req.SetRequestURI(a.ServerAddress + "/api/v1/osquery/distributed/write")
res := fasthttp.AcquireResponse()
a.waitingDo(req, res)
fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(res)
a.Stats.RecordStats(0, 0, 1)
// No need to read the distributed write body
}
func main() {
serverURL := flag.String("server_url", "https://localhost:8080", "URL (with protocol and port of osquery server)")
enrollSecret := flag.String("enroll_secret", "", "Enroll secret to authenticate enrollment")
hostCount := flag.Int("host_count", 10, "Number of hosts to start (default 10)")
randSeed := flag.Int64("seed", time.Now().UnixNano(), "Seed for random generator (default current time)")
startPeriod := flag.Duration("start_period", 10*time.Second, "Duration to spread start of hosts over")
configInterval := flag.Duration("config_interval", 1*time.Minute, "Interval for config requests")
queryInterval := flag.Duration("query_interval", 10*time.Second, "Interval for live query requests")
onlyAlreadyEnrolled := flag.Bool("only_already_enrolled", false, "Only start agents that are already enrolled")
nodeKeyFile := flag.String("node_key_file", "", "File with node keys to use")
softwareCount := flag.Int("software_count", 10, "Number of installed applications reported to fleet")
policyPassProb := flag.Float64("policy_pass_prob", 1.0, "Probability of policies to pass [0, 1].")
flag.Parse()
rand.Seed(*randSeed)
// Currently all hosts will be macOS.
tmpl, err := template.ParseFS(templatesFS, "mac10.14.6.tmpl")
if err != nil {
log.Fatal("parse templates: ", err)
}
// Spread starts over the interval to prevent thundering herd
sleepTime := *startPeriod / time.Duration(*hostCount)
stats := &Stats{}
go stats.runLoop()
nodeKeyManager := &nodeKeyManager{}
if nodeKeyFile != nil {
nodeKeyManager.filepath = *nodeKeyFile
nodeKeyManager.LoadKeys()
}
for i := 0; i < *hostCount; i++ {
a := newAgent(*serverURL, *enrollSecret, tmpl, *configInterval, *queryInterval, *softwareCount, *policyPassProb)
a.Stats = stats
a.NodeKeyManager = nodeKeyManager
go a.runLoop(i, onlyAlreadyEnrolled != nil && *onlyAlreadyEnrolled)
time.Sleep(sleepTime)
}
fmt.Println("Agents running. Kill with C-c.")
<-make(chan struct{})
}