Do not leak ticks and fix race condition (#14667)

This PR introduces the following fixes:
- Fixes a race condition by protecting the `scheduledQueries` and
`scheduledQueryData` fields with a `sync.Mutex` (*).
- Adds some more information about request counts for /log, /config and
/distributed/read requests (and uptime).
- Fixes the resource leaks around creating time.Ticks on every for loop
iteration.

(*) Sample of the race condition when running with `-race`:
```
==================
WARNING: DATA RACE
Read at 0x00c000604800 by goroutine 20:
  main.(*agent).runLoop()
      /Users/luk/fleetdm/git/fleet/cmd/osquery-perf/agent.go:525 +0x71b
  main.main.func2()
      /Users/luk/fleetdm/git/fleet/cmd/osquery-perf/agent.go:1737 +0x4e

Previous write at 0x00c000604800 by goroutine 40:
  main.(*agent).config()
      /Users/luk/fleetdm/git/fleet/cmd/osquery-perf/agent.go:915 +0xb30
  main.(*agent).runLoop.func2()
      /Users/luk/fleetdm/git/fleet/cmd/osquery-perf/agent.go:512 +0x37
```
This commit is contained in:
Lucas Manuel Rodriguez 2023-10-20 10:29:59 -03:00 committed by GitHub
parent 2f589ff37c
commit 5d7ee58a87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -106,12 +106,16 @@ func init() {
}
type Stats struct {
startTime time.Time
errors int
osqueryEnrollments int
orbitEnrollments int
mdmEnrollments int
distributedWrites int
mdmCommandsReceived int
distributedReads int
configRequests int
resultLogRequests int
orbitErrors int
mdmErrors int
desktopErrors int
@ -155,6 +159,24 @@ func (s *Stats) IncrementMDMCommandsReceived() {
s.mdmCommandsReceived++
}
func (s *Stats) IncrementDistributedReads() {
s.l.Lock()
defer s.l.Unlock()
s.distributedReads++
}
func (s *Stats) IncrementConfigRequests() {
s.l.Lock()
defer s.l.Unlock()
s.configRequests++
}
func (s *Stats) IncrementResultLogRequests() {
s.l.Lock()
defer s.l.Unlock()
s.resultLogRequests++
}
func (s *Stats) IncrementOrbitErrors() {
s.l.Lock()
defer s.l.Unlock()
@ -178,13 +200,17 @@ func (s *Stats) Log() {
defer s.l.Unlock()
fmt.Printf(
"%s :: error rate: %.2f, osquery enrolls: %d, orbit enrolls: %d, mdm enrolls: %d, distributed/writes: %d, mdm commands received: %d, orbit errors: %d, desktop errors: %d, mdm errors: %d\n",
"%s :: uptime: %s, error rate: %.2f, osquery enrolls: %d, orbit enrolls: %d, mdm enrolls: %d, distributed/reads: %d, distributed/writes: %d, config requests: %d, result log requests: %d, mdm commands received: %d, orbit errors: %d, desktop errors: %d, mdm errors: %d\n",
time.Now().Format("2006-01-02T15:04:05Z"),
time.Since(s.startTime).Round(time.Second),
float64(s.errors)/float64(s.osqueryEnrollments),
s.osqueryEnrollments,
s.orbitEnrollments,
s.mdmEnrollments,
s.distributedReads,
s.distributedWrites,
s.configRequests,
s.resultLogRequests,
s.mdmCommandsReceived,
s.orbitErrors,
s.desktopErrors,
@ -278,8 +304,6 @@ type agent struct {
deviceAuthToken *string
orbitNodeKey *string
scheduledQueries []string
// mdmClient simulates a device running the MDM protocol (client side).
mdmClient *mdmtest.TestMDMClient
// isEnrolledToMDM is true when the mdmDevice has enrolled.
@ -304,7 +328,10 @@ type agent struct {
QueryInterval time.Duration
MDMCheckInInterval time.Duration
DiskEncryptionEnabled bool
scheduledQueryData []scheduledQuery
scheduledQueriesMu sync.Mutex
scheduledQueries []string
scheduledQueryData []scheduledQuery
}
type entityCount struct {
@ -463,51 +490,49 @@ func (a *agent) runLoop(i int, onlyAlreadyEnrolled bool) {
//
// Thus we try to simulate that as much as we can.
// distributed thread:
// (1) distributed thread:
go func() {
liveQueryTicker := time.Tick(a.QueryInterval)
for {
select {
case <-liveQueryTicker:
resp, err := a.DistributedRead()
if err != nil {
log.Println(err)
} else if len(resp.Queries) > 0 {
a.DistributedWrite(resp.Queries)
}
liveQueryTicker := time.NewTicker(a.QueryInterval)
defer liveQueryTicker.Stop()
for range liveQueryTicker.C {
resp, err := a.DistributedRead()
if err != nil {
log.Println(err)
} else if len(resp.Queries) > 0 {
a.DistributedWrite(resp.Queries)
}
}
}()
// config thread:
// (2) config thread:
go func() {
for {
configTicker := time.Tick(a.ConfigInterval)
select {
case <-configTicker:
a.config()
}
configTicker := time.NewTicker(a.ConfigInterval)
defer configTicker.Stop()
for range configTicker.C {
a.config()
}
}()
// logger thread:
for {
logTicker := time.Tick(a.LogInterval)
select {
case <-logTicker:
// check if we have any scheduled queries that should be returning results
var results []json.RawMessage
now := time.Now().Unix()
for i, query := range a.scheduledQueryData {
if query.NextRun == 0 || now >= int64(query.NextRun) {
results = append(results, a.scheduledQueryResults(query.PackName, query.Name, int(query.NumResults)))
a.scheduledQueryData[i].NextRun = float64(now + int64(query.ScheduleInterval))
}
}
if len(results) > 0 {
a.SubmitLogs(results)
// (3) logger thread:
logTicker := time.NewTicker(a.LogInterval)
defer logTicker.Stop()
for range logTicker.C {
// check if we have any scheduled queries that should be returning results
var results []json.RawMessage
now := time.Now().Unix()
a.scheduledQueriesMu.Lock()
for i, query := range a.scheduledQueryData {
if query.NextRun == 0 || now >= int64(query.NextRun) {
results = append(results, a.scheduledQueryResults(query.PackName, query.Name, int(query.NumResults)))
a.scheduledQueryData[i].NextRun = float64(now + int64(query.ScheduleInterval))
}
}
a.scheduledQueriesMu.Unlock()
if len(results) > 0 {
a.SubmitLogs(results)
}
}
}
@ -842,6 +867,8 @@ func (a *agent) config() {
fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(res)
a.stats.IncrementConfigRequests()
if res.StatusCode() != http.StatusOK {
log.Println("config status:", res.StatusCode())
return
@ -885,8 +912,11 @@ func (a *agent) config() {
}
}
a.scheduledQueriesMu.Lock()
a.scheduledQueries = scheduledQueries
a.scheduledQueryData = scheduledQueryData
a.scheduledQueriesMu.Unlock()
}
const stringVals = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_."
@ -1021,6 +1051,8 @@ func (a *agent) DistributedRead() (*distributedReadResponse, error) {
fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(res)
a.stats.IncrementDistributedReads()
var parsedResp distributedReadResponse
if err := json.Unmarshal(res.Body(), &parsedResp); err != nil {
log.Println("json parse:", err)
@ -1056,6 +1088,9 @@ func (a *agent) runPolicy(query string) []map[string]string {
}
func (a *agent) randomQueryStats() []map[string]string {
a.scheduledQueriesMu.Lock()
defer a.scheduledQueriesMu.Unlock()
var stats []map[string]string
for _, scheduledQuery := range a.scheduledQueries {
stats = append(stats, map[string]string{
@ -1500,6 +1535,8 @@ func (a *agent) SubmitLogs(results []json.RawMessage) {
fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(res)
a.stats.IncrementResultLogRequests()
}
// Creates a set of results for use in tests for Query Results.
@ -1636,7 +1673,9 @@ func main() {
// Spread starts over the interval to prevent thundering herd
sleepTime := *startPeriod / time.Duration(*hostCount)
stats := &Stats{}
stats := &Stats{
startTime: time.Now(),
}
go stats.runLoop()
nodeKeyManager := &nodeKeyManager{}