Fix race condition in tests when using global var loginRateLimit (#5197)

This commit is contained in:
Lucas Manuel Rodriguez 2022-04-19 10:35:53 -03:00 committed by GitHub
parent 3228e1b5dc
commit f7048b711c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 17 deletions

View File

@ -75,8 +75,33 @@ func checkLicenseExpiration(svc fleet.Service) func(context.Context, http.Respon
}
}
type extraHandlerOpts struct {
loginRateLimit *throttled.Rate
}
// ExtraHandlerOption allows adding extra configuration to the HTTP handler.
type ExtraHandlerOption func(*extraHandlerOpts)
// WithLoginRateLimit configures the rate limit for the login endpoint.
func WithLoginRateLimit(r throttled.Rate) ExtraHandlerOption {
return func(o *extraHandlerOpts) {
o.loginRateLimit = &r
}
}
// MakeHandler creates an HTTP handler for the Fleet server endpoints.
func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Logger, limitStore throttled.GCRAStore) http.Handler {
func MakeHandler(
svc fleet.Service,
config config.FleetConfig,
logger kitlog.Logger,
limitStore throttled.GCRAStore,
extra ...ExtraHandlerOption,
) http.Handler {
var eopts extraHandlerOpts
for _, fn := range extra {
fn(&eopts)
}
fleetAPIOptions := []kithttp.ServerOption{
kithttp.ServerBefore(
kithttp.PopulateRequestContext, // populate the request context with common fields
@ -98,7 +123,7 @@ func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Log
r.Use(publicIP)
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions)
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions, eopts)
// Results endpoint is handled different due to websockets use
@ -203,14 +228,9 @@ func addMetrics(r *mux.Router) {
r.Walk(walkFn)
}
var (
// those are conceptually constants, but var so they can be changed in tests
forgotPasswordRateLimit = throttled.PerHour(10)
loginRateLimit = throttled.PerMin(10)
)
func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetConfig,
logger kitlog.Logger, limitStore throttled.GCRAStore, opts []kithttp.ServerOption,
extra extraHandlerOpts,
) {
apiVersions := []string{"v1", "2022-04"}
@ -413,8 +433,8 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
ne.WithAltPaths("/api/v1/osquery/enroll").
POST("/api/osquery/enroll", enrollAgentEndpoint, enrollAgentRequest{})
// For some reason osquery does not provide a node key with the block data.
// Instead the carve session ID should be verified in the service method.
// For some reason osquery does not provide a node key with the block data.
// Instead the carve session ID should be verified in the service method.
ne.WithAltPaths("/api/v1/osquery/carve/block").
POST("/api/osquery/carve/block", carveBlockEndpoint, carveBlockRequest{})
@ -429,11 +449,15 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
limiter := ratelimit.NewMiddleware(limitStore)
ne.
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: forgotPasswordRateLimit, MaxBurst: 9})).
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})).
POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{})
ne.
WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})).
loginRateLimit := throttled.PerMin(10)
if extra.loginRateLimit != nil {
loginRateLimit = *extra.loginRateLimit
}
ne.WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})).
POST("/api/_version_/fleet/login", loginEndpoint, loginRequest{})
}

View File

@ -17,7 +17,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/throttled/throttled/v2"
)
type withDS struct {
@ -46,7 +45,6 @@ type withServer struct {
func (ts *withServer) SetupSuite(dbName string) {
ts.withDS.SetupSuite(dbName)
loginRateLimit = throttled.PerMin(100)
rs := pubsub.NewInmemQueryResults()
users, server := RunServerForTestsWithDS(ts.s.T(), ts.ds, TestServerOpts{Rs: rs})
ts.server = server
@ -149,7 +147,7 @@ func (ts *withServer) getTestToken(email string, password string) string {
defer resp.Body.Close()
assert.Equal(ts.s.T(), http.StatusOK, resp.StatusCode)
var jsn = struct {
jsn := struct {
User *fleet.User `json:"user"`
Token string `json:"token"`
Err []map[string]string `json:"errors,omitempty"`

View File

@ -22,6 +22,7 @@ import (
kitlog "github.com/go-kit/kit/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/throttled/throttled/v2"
"github.com/throttled/throttled/v2/store/memstore"
)
@ -177,7 +178,7 @@ func RunServerForTestsWithDS(t *testing.T, ds fleet.Datastore, opts ...TestServe
}
limitStore, _ := memstore.New(0)
r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore)
r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore, WithLoginRateLimit(throttled.PerMin(100)))
server := httptest.NewServer(r)
t.Cleanup(func() {
server.Close()