Fix redis locking receive (#1655)
* Receive redis data with timeout to not hold the connection
* Address review comments
This commit is contained in: parent e85996c291, commit 96c0244c04
changes/fix-redis-locking-receive (new file, 1 line)
@@ -0,0 +1 @@
* When a connection from a live query websocket is closed, Fleet now times out the receive and handles the different cases correctly so it does not hold the connection to Redis.
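For orientation before the diff: a minimal, self-contained sketch of the receive-with-timeout pattern this entry describes, assuming redigo's redis.PubSubConn. The function and package names here are illustrative, not part of this commit.

package pubsubsketch

import (
	"net"
	"time"

	"github.com/gomodule/redigo/redis"
)

// drain polls a pub/sub connection with a timeout so the loop can notice a
// closed connection or an unsubscribe instead of blocking forever on Receive.
func drain(psc redis.PubSubConn) {
	for {
		switch msg := psc.ReceiveWithTimeout(5 * time.Second).(type) {
		case redis.Message:
			_ = msg.Data // forward the payload to the consumer here
		case redis.Subscription:
			if msg.Count == 0 {
				return // unsubscribed; nothing left to receive
			}
		case error:
			if netErr, ok := msg.(net.Error); ok && netErr.Timeout() {
				continue // idle channel; the timeout only bounds the wait
			}
			return // connection closed or another fatal error
		}
	}
}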
go.mod (2 changes)
@@ -30,7 +30,7 @@ require (
	github.com/go-logfmt/logfmt v0.5.0 // indirect
	github.com/go-sql-driver/mysql v1.5.0
	github.com/golang-jwt/jwt/v4 v4.0.0
-	github.com/gomodule/redigo v1.8.4
+	github.com/gomodule/redigo v1.8.5
	github.com/google/go-cmp v0.5.6
	github.com/google/go-github/v37 v37.0.0
	github.com/google/uuid v1.1.2
go.sum (2 changes)
@@ -350,6 +350,8 @@ github.com/golangci/unconvert v0.0.0-20180507085042-28b1c447d1f4 h1:zwtduBRr5SSW
github.com/golangci/unconvert v0.0.0-20180507085042-28b1c447d1f4/go.mod h1:Izgrg8RkN3rCIMLGE9CyYmU9pY2Jer6DgANEnZ/L/cQ=
github.com/gomodule/redigo v1.8.4 h1:Z5JUg94HMTR1XpwBaSH4vq3+PNSIykBLxMdglbw10gg=
github.com/gomodule/redigo v1.8.4/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0=
+github.com/gomodule/redigo v1.8.5 h1:nRAxCa+SVsyjSBrtZmG/cqb6VbTmuRzpg/PoTFlpumc=
+github.com/gomodule/redigo v1.8.5/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
@@ -8,7 +8,6 @@ import (

	"github.com/fleetdm/fleet/v4/server/fleet"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// waitTimeout waits for the waitgroup for the specified max timeout.
@@ -27,31 +26,8 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
	}
}

-func setupRedis(t *testing.T) (store *redisQueryResults, teardown func()) {
-	var (
-		addr       = "127.0.0.1:6379"
-		password   = ""
-		database   = 0
-		useTLS     = false
-		dupResults = false
-	)
-
-	pool, err := NewRedisPool(addr, password, database, useTLS)
-	require.NoError(t, err)
-	store = NewRedisQueryResults(pool, dupResults)
-
-	_, err = store.pool.Get().Do("PING")
-	require.Nil(t, err)
-
-	teardown = func() {
-		store.pool.Close()
-	}
-
-	return store, teardown
-}
-
 func TestQueryResultsStoreErrors(t *testing.T) {
-	store, teardown := setupRedis(t)
+	store, teardown := SetupRedisForTest(t)
	defer teardown()

	// Write with no subscriber
@@ -78,7 +54,7 @@ func TestQueryResultsStoreErrors(t *testing.T) {
}

func TestQueryResultsStore(t *testing.T) {
-	store, teardown := setupRedis(t)
+	store, teardown := SetupRedisForTest(t)
	defer teardown()

	// Test handling results for two campaigns in parallel
@@ -4,6 +4,7 @@ import (
	"context"
	"encoding/json"
	"fmt"
+	"net"
	"strings"
	"time"

@@ -94,6 +95,10 @@ func pubSubForID(id uint) string {
	return fmt.Sprintf("results_%d", id)
}

+func (r *redisQueryResults) Pool() *redisc.Cluster {
+	return r.pool
+}
+
 func (r *redisQueryResults) WriteResult(result fleet.DistributedQueryResult) error {
	conn := r.pool.Get()
	defer conn.Close()
@@ -121,6 +126,17 @@ func (r *redisQueryResults) WriteResult(result fleet.DistributedQueryResult) err
	return nil
}

+// writeOrDone tries to write the item into the channel taking into account context.Done(). If context is done, returns
+// true, otherwise false
+func writeOrDone(ctx context.Context, ch chan<- interface{}, item interface{}) bool {
+	select {
+	case ch <- item:
+	case <-ctx.Done():
+		return true
+	}
+	return false
+}
+
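A small usage sketch for the helper just added; the produce function and its arguments are hypothetical, and it relies on writeOrDone above and the file's existing context import. The real call sites appear in receiveMessages and ReadChannel further down.

// produce pushes items to out but gives up as soon as ctx is cancelled, so a
// producer goroutine never blocks forever on a channel nobody reads anymore.
func produce(ctx context.Context, out chan<- interface{}, items []interface{}) {
	for _, item := range items {
		if writeOrDone(ctx, out, item) {
			return // context cancelled; the consumer is gone
		}
	}
}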

 // receiveMessages runs in a goroutine, forwarding messages from the Pub/Sub
 // connection over the provided channel. This effectively allows a select
 // statement to run on conn.Receive() (by running on the channel that is being
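The comment above describes the core trick; a generic, self-contained sketch of it follows (forward and recv are illustrative names, not part of this file).

package pubsubsketch

import "context"

// forward runs a goroutine that pushes values from a blocking receive
// function into a channel, so callers can select over the channel and
// ctx.Done() at the same time.
func forward(ctx context.Context, recv func() interface{}) <-chan interface{} {
	out := make(chan interface{})
	go func() {
		defer close(out)
		for {
			msg := recv() // stands in for conn.ReceiveWithTimeout(...)
			select {
			case out <- msg:
			case <-ctx.Done():
				return
			}
		}
	}()
	return out
}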
@@ -131,8 +147,8 @@ func receiveMessages(ctx context.Context, pool *redisc.Cluster, query fleet.Dist

	pubSubName := pubSubForID(query.ID)
	err := conn.Subscribe(pubSubName)
-	if err != nil {
-		outChan <- errors.Wrap(err, "subscribe to channel")
+	if err != nil && writeOrDone(ctx, outChan, errors.Wrap(err, "subscribe to channel")) {
		return
	}
	defer conn.Unsubscribe(pubSubName)
@@ -141,24 +157,28 @@
	}()

	for {
-		msg := conn.Receive()
+		// This Receive needs to be with timeout, otherwise we might block on it forever
+		msg := conn.ReceiveWithTimeout(5 * time.Second)

		select {
		case outChan <- msg:
			switch msg := msg.(type) {
			case error:
-				// If an error occurred (i.e. connection was closed),
-				// then we should exit
-				return
+				if err, ok := msg.(net.Error); ok && err.Timeout() {
+					// We ignore timeouts, we just want them there to make sure we don't hang on Receiving
+					continue
+				} else {
+					// If an error occurred (i.e. connection was closed), then we should exit
+					return
+				}
			case redis.Subscription:
-				// If the subscription count is 0, the ReadChannel call
-				// that invoked this goroutine has unsubscribed, and we
-				// can exit
+				// If the subscription count is 0, the ReadChannel call that invoked this goroutine has unsubscribed,
+				// and we can exit
				if msg.Count == 0 {
					return
				}
			}
		case <-ctx.Done():
			conn.Unsubscribe(pubSubName)
			return
		}
	}
@@ -176,8 +196,7 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib
	defer close(outChannel)

	for {
-		// Loop reading messages from conn.Receive() (via
-		// msgChannel) until the context is cancelled.
+		// Loop reading messages from conn.Receive() (via msgChannel) until the context is cancelled.
		select {
		case msg, ok := <-msgChannel:
			if !ok {
@@ -188,11 +207,17 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib
				var res fleet.DistributedQueryResult
				err := json.Unmarshal(msg.Data, &res)
				if err != nil {
-					outChannel <- err
+					if writeOrDone(ctx, outChannel, err) {
+						return
+					}
				}
-				outChannel <- res
+				if writeOrDone(ctx, outChannel, res) {
+					return
+				}
			case error:
-				outChannel <- errors.Wrap(msg, "reading from redis")
+				if writeOrDone(ctx, outChannel, errors.Wrap(msg, "reading from redis")) {
+					return
+				}
			}
		case <-ctx.Done():
			return
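For context, a sketch of how a caller might drain ReadChannel, assuming it returns (<-chan interface{}, error) and that the channel carries fleet.DistributedQueryResult values and errors; the interface and function names below are illustrative, not the commit's code.

package servicesketch

import (
	"context"

	"github.com/fleetdm/fleet/v4/server/fleet"
)

// resultReader is the assumed shape of the query results store used above.
type resultReader interface {
	ReadChannel(ctx context.Context, query fleet.DistributedQueryCampaign) (<-chan interface{}, error)
}

// streamResults forwards campaign results until the channel closes, which
// happens when the campaign finishes or the context is cancelled.
func streamResults(ctx context.Context, store resultReader, campaign fleet.DistributedQueryCampaign, send func(fleet.DistributedQueryResult)) error {
	readChan, err := store.ReadChannel(ctx, campaign)
	if err != nil {
		return err
	}
	for msg := range readChan {
		switch v := msg.(type) {
		case fleet.DistributedQueryResult:
			send(v) // e.g. write the result to the websocket client
		case error:
			return v
		}
	}
	return nil
}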
server/pubsub/testing_utils.go (new file, 30 lines)
@@ -0,0 +1,30 @@
package pubsub

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func SetupRedisForTest(t *testing.T) (store *redisQueryResults, teardown func()) {
	var (
		addr       = "127.0.0.1:6379"
		password   = ""
		database   = 0
		useTLS     = false
		dupResults = false
	)

	pool, err := NewRedisPool(addr, password, database, useTLS)
	require.NoError(t, err)
	store = NewRedisQueryResults(pool, dupResults)

	_, err = store.pool.Get().Do("PING")
	require.Nil(t, err)

	teardown = func() {
		store.pool.Close()
	}

	return store, teardown
}
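A minimal usage sketch for the new helper from another package, as the campaign test below does. It assumes a Redis server is reachable at 127.0.0.1:6379, which is the address the helper dials; the test name here is illustrative.

package service

import (
	"testing"

	"github.com/fleetdm/fleet/v4/server/pubsub"
)

func TestWithRedisSketch(t *testing.T) {
	// SetupRedisForTest pings 127.0.0.1:6379 and fails the test if Redis is unreachable.
	store, teardown := pubsub.SetupRedisForTest(t)
	defer teardown()

	_ = store // exercise the query results store here
}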
server/service/service_campaign_test.go (new file, 163 lines)
@@ -0,0 +1,163 @@
package service

import (
	"context"
	"crypto/tls"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/WatchBeam/clock"
	"github.com/fleetdm/fleet/v4/server/contexts/viewer"
	"github.com/fleetdm/fleet/v4/server/fleet"
	"github.com/fleetdm/fleet/v4/server/live_query"
	"github.com/fleetdm/fleet/v4/server/mock"
	"github.com/fleetdm/fleet/v4/server/ptr"
	"github.com/fleetdm/fleet/v4/server/pubsub"
	ws "github.com/fleetdm/fleet/v4/server/websocket"
	kitlog "github.com/go-kit/kit/log"
	"github.com/gorilla/websocket"
	"github.com/stretchr/testify/require"
)

func TestStreamCampaignResultsClosesReditOnWSClose(t *testing.T) {
	store, teardown := pubsub.SetupRedisForTest(t)
	defer teardown()

	mockClock := clock.NewMockClock()
	ds := new(mock.Store)
	lq := new(live_query.MockLiveQuery)
	svc := newTestServiceWithClock(ds, store, lq, mockClock)

	campaign := &fleet.DistributedQueryCampaign{ID: 42}

	ds.LabelQueriesForHostFunc = func(host *fleet.Host, cutoff time.Time) (map[string]string, error) {
		return map[string]string{}, nil
	}
	ds.SaveHostFunc = func(host *fleet.Host) error {
		return nil
	}
	ds.AppConfigFunc = func() (*fleet.AppConfig, error) {
		return &fleet.AppConfig{EnableHostUsers: true}, nil
	}
	ds.NewQueryFunc = func(query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) {
		return query, nil
	}
	ds.NewDistributedQueryCampaignFunc = func(camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) {
		return camp, nil
	}
	ds.NewDistributedQueryCampaignTargetFunc = func(target *fleet.DistributedQueryCampaignTarget) (*fleet.DistributedQueryCampaignTarget, error) {
		return target, nil
	}
	ds.HostIDsInTargetsFunc = func(filter fleet.TeamFilter, targets fleet.HostTargets) ([]uint, error) {
		return []uint{1}, nil
	}
	ds.CountHostsInTargetsFunc = func(filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) {
		return fleet.TargetMetrics{TotalHosts: 1}, nil
	}
	ds.NewActivityFunc = func(user *fleet.User, activityType string, details *map[string]interface{}) error {
		return nil
	}
	ds.SessionByKeyFunc = func(key string) (*fleet.Session, error) {
		return &fleet.Session{
			CreateTimestamp: fleet.CreateTimestamp{CreatedAt: time.Now()},
			ID:              42,
			AccessedAt:      time.Now(),
			UserID:          999,
			Key:             "asd",
		}, nil
	}

	host := &fleet.Host{ID: 1, Platform: "windows"}

	lq.On("QueriesForHost", uint(1)).Return(
		map[string]string{
			strconv.Itoa(int(campaign.ID)): "select * from time",
		},
		nil,
	)
	lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(nil)
	lq.On("RunQuery", "0", "select year, month, day, hour, minutes, seconds from time", []uint{1}).Return(nil)
	viewerCtx := viewer.NewContext(context.Background(), viewer.Viewer{
		User: &fleet.User{
			ID:         0,
			GlobalRole: ptr.String(fleet.RoleAdmin),
		},
	})
	q := "select year, month, day, hour, minutes, seconds from time"
	_, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}})
	require.NoError(t, err)

	s := httptest.NewServer(makeStreamDistributedQueryCampaignResultsHandler(svc, kitlog.NewNopLogger()))
	defer s.Close()
	// Convert http://127.0.0.1 to ws://127.0.0.1
	u := "ws" + strings.TrimPrefix(s.URL, "http") + "/api/v1/fleet/results/websocket"

	// Connect to the server
	dialer := &websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: 45 * time.Second,
		TLSClientConfig:  &tls.Config{InsecureSkipVerify: true},
	}

	conn, _, err := dialer.Dial(u, nil)
	if err != nil {
		t.Fatalf("%v", err)
	}
	defer conn.Close()

	err = conn.WriteJSON(ws.JSONMessage{
		Type: "auth",
		Data: map[string]interface{}{"token": "asd"},
	})
	require.NoError(t, err)

	err = conn.WriteJSON(ws.JSONMessage{
		Type: "select_campaign",
		Data: map[string]interface{}{"campaign_id": campaign.ID},
	})
	require.NoError(t, err)

	ds.MarkSessionAccessedFunc = func(*fleet.Session) error {
		return nil
	}
	ds.UserByIDFunc = func(id uint) (*fleet.User, error) {
		return &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}, nil
	}
	ds.DistributedQueryCampaignFunc = func(id uint) (*fleet.DistributedQueryCampaign, error) {
		return campaign, nil
	}
	ds.SaveDistributedQueryCampaignFunc = func(camp *fleet.DistributedQueryCampaign) error {
		return nil
	}
	ds.DistributedQueryCampaignTargetIDsFunc = func(id uint) (targets *fleet.HostTargets, err error) {
		return &fleet.HostTargets{HostIDs: []uint{1}}, nil
	}
	ds.QueryFunc = func(id uint) (*fleet.Query, error) {
		return &fleet.Query{}, nil
	}

	/*****************************************************************************************/
	/* THE ACTUAL TEST BEGINS HERE                                                           */
	/*****************************************************************************************/
	prevActiveConn := 0
	for prevActiveConn < 3 {
		time.Sleep(2 * time.Second)

		for _, s := range store.Pool().Stats() {
			prevActiveConn = s.ActiveCount
		}
	}

	conn.Close()
	time.Sleep(10 * time.Second)

	newActiveConn := prevActiveConn
	for _, s := range store.Pool().Stats() {
		newActiveConn = s.ActiveCount
	}
	require.Equal(t, prevActiveConn-1, newActiveConn)
}