Fix race condition in tests when using global var loginRateLimit (#5197)

This commit is contained in:
Lucas Manuel Rodriguez 2022-04-19 10:35:53 -03:00 committed by GitHub
parent 3228e1b5dc
commit f7048b711c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 17 deletions

View File

@ -75,8 +75,33 @@ func checkLicenseExpiration(svc fleet.Service) func(context.Context, http.Respon
} }
} }
type extraHandlerOpts struct {
loginRateLimit *throttled.Rate
}
// ExtraHandlerOption allows adding extra configuration to the HTTP handler.
type ExtraHandlerOption func(*extraHandlerOpts)
// WithLoginRateLimit configures the rate limit for the login endpoint.
func WithLoginRateLimit(r throttled.Rate) ExtraHandlerOption {
return func(o *extraHandlerOpts) {
o.loginRateLimit = &r
}
}
// MakeHandler creates an HTTP handler for the Fleet server endpoints. // MakeHandler creates an HTTP handler for the Fleet server endpoints.
func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Logger, limitStore throttled.GCRAStore) http.Handler { func MakeHandler(
svc fleet.Service,
config config.FleetConfig,
logger kitlog.Logger,
limitStore throttled.GCRAStore,
extra ...ExtraHandlerOption,
) http.Handler {
var eopts extraHandlerOpts
for _, fn := range extra {
fn(&eopts)
}
fleetAPIOptions := []kithttp.ServerOption{ fleetAPIOptions := []kithttp.ServerOption{
kithttp.ServerBefore( kithttp.ServerBefore(
kithttp.PopulateRequestContext, // populate the request context with common fields kithttp.PopulateRequestContext, // populate the request context with common fields
@ -98,7 +123,7 @@ func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Log
r.Use(publicIP) r.Use(publicIP)
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions) attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions, eopts)
// Results endpoint is handled different due to websockets use // Results endpoint is handled different due to websockets use
@ -203,14 +228,9 @@ func addMetrics(r *mux.Router) {
r.Walk(walkFn) r.Walk(walkFn)
} }
var (
// those are conceptually constants, but var so they can be changed in tests
forgotPasswordRateLimit = throttled.PerHour(10)
loginRateLimit = throttled.PerMin(10)
)
func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetConfig, func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetConfig,
logger kitlog.Logger, limitStore throttled.GCRAStore, opts []kithttp.ServerOption, logger kitlog.Logger, limitStore throttled.GCRAStore, opts []kithttp.ServerOption,
extra extraHandlerOpts,
) { ) {
apiVersions := []string{"v1", "2022-04"} apiVersions := []string{"v1", "2022-04"}
@ -413,8 +433,8 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
ne.WithAltPaths("/api/v1/osquery/enroll"). ne.WithAltPaths("/api/v1/osquery/enroll").
POST("/api/osquery/enroll", enrollAgentEndpoint, enrollAgentRequest{}) POST("/api/osquery/enroll", enrollAgentEndpoint, enrollAgentRequest{})
// For some reason osquery does not provide a node key with the block data. // For some reason osquery does not provide a node key with the block data.
// Instead the carve session ID should be verified in the service method. // Instead the carve session ID should be verified in the service method.
ne.WithAltPaths("/api/v1/osquery/carve/block"). ne.WithAltPaths("/api/v1/osquery/carve/block").
POST("/api/osquery/carve/block", carveBlockEndpoint, carveBlockRequest{}) POST("/api/osquery/carve/block", carveBlockEndpoint, carveBlockRequest{})
@ -429,11 +449,15 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
limiter := ratelimit.NewMiddleware(limitStore) limiter := ratelimit.NewMiddleware(limitStore)
ne. ne.
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: forgotPasswordRateLimit, MaxBurst: 9})). WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})).
POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{}) POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{})
ne. loginRateLimit := throttled.PerMin(10)
WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})). if extra.loginRateLimit != nil {
loginRateLimit = *extra.loginRateLimit
}
ne.WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})).
POST("/api/_version_/fleet/login", loginEndpoint, loginRequest{}) POST("/api/_version_/fleet/login", loginEndpoint, loginRequest{})
} }

View File

@ -17,7 +17,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/throttled/throttled/v2"
) )
type withDS struct { type withDS struct {
@ -46,7 +45,6 @@ type withServer struct {
func (ts *withServer) SetupSuite(dbName string) { func (ts *withServer) SetupSuite(dbName string) {
ts.withDS.SetupSuite(dbName) ts.withDS.SetupSuite(dbName)
loginRateLimit = throttled.PerMin(100)
rs := pubsub.NewInmemQueryResults() rs := pubsub.NewInmemQueryResults()
users, server := RunServerForTestsWithDS(ts.s.T(), ts.ds, TestServerOpts{Rs: rs}) users, server := RunServerForTestsWithDS(ts.s.T(), ts.ds, TestServerOpts{Rs: rs})
ts.server = server ts.server = server
@ -149,7 +147,7 @@ func (ts *withServer) getTestToken(email string, password string) string {
defer resp.Body.Close() defer resp.Body.Close()
assert.Equal(ts.s.T(), http.StatusOK, resp.StatusCode) assert.Equal(ts.s.T(), http.StatusOK, resp.StatusCode)
var jsn = struct { jsn := struct {
User *fleet.User `json:"user"` User *fleet.User `json:"user"`
Token string `json:"token"` Token string `json:"token"`
Err []map[string]string `json:"errors,omitempty"` Err []map[string]string `json:"errors,omitempty"`

View File

@ -22,6 +22,7 @@ import (
kitlog "github.com/go-kit/kit/log" kitlog "github.com/go-kit/kit/log"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/throttled/throttled/v2"
"github.com/throttled/throttled/v2/store/memstore" "github.com/throttled/throttled/v2/store/memstore"
) )
@ -177,7 +178,7 @@ func RunServerForTestsWithDS(t *testing.T, ds fleet.Datastore, opts ...TestServe
} }
limitStore, _ := memstore.New(0) limitStore, _ := memstore.New(0)
r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore) r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore, WithLoginRateLimit(throttled.PerMin(100)))
server := httptest.NewServer(r) server := httptest.NewServer(r)
t.Cleanup(func() { t.Cleanup(func() {
server.Close() server.Close()