mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 17:05:18 +00:00
Fix race condition in tests when using global var loginRateLimit (#5197)
This commit is contained in:
parent
3228e1b5dc
commit
f7048b711c
@ -75,8 +75,33 @@ func checkLicenseExpiration(svc fleet.Service) func(context.Context, http.Respon
|
||||
}
|
||||
}
|
||||
|
||||
type extraHandlerOpts struct {
|
||||
loginRateLimit *throttled.Rate
|
||||
}
|
||||
|
||||
// ExtraHandlerOption allows adding extra configuration to the HTTP handler.
|
||||
type ExtraHandlerOption func(*extraHandlerOpts)
|
||||
|
||||
// WithLoginRateLimit configures the rate limit for the login endpoint.
|
||||
func WithLoginRateLimit(r throttled.Rate) ExtraHandlerOption {
|
||||
return func(o *extraHandlerOpts) {
|
||||
o.loginRateLimit = &r
|
||||
}
|
||||
}
|
||||
|
||||
// MakeHandler creates an HTTP handler for the Fleet server endpoints.
|
||||
func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Logger, limitStore throttled.GCRAStore) http.Handler {
|
||||
func MakeHandler(
|
||||
svc fleet.Service,
|
||||
config config.FleetConfig,
|
||||
logger kitlog.Logger,
|
||||
limitStore throttled.GCRAStore,
|
||||
extra ...ExtraHandlerOption,
|
||||
) http.Handler {
|
||||
var eopts extraHandlerOpts
|
||||
for _, fn := range extra {
|
||||
fn(&eopts)
|
||||
}
|
||||
|
||||
fleetAPIOptions := []kithttp.ServerOption{
|
||||
kithttp.ServerBefore(
|
||||
kithttp.PopulateRequestContext, // populate the request context with common fields
|
||||
@ -98,7 +123,7 @@ func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Log
|
||||
|
||||
r.Use(publicIP)
|
||||
|
||||
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions)
|
||||
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions, eopts)
|
||||
|
||||
// Results endpoint is handled different due to websockets use
|
||||
|
||||
@ -203,14 +228,9 @@ func addMetrics(r *mux.Router) {
|
||||
r.Walk(walkFn)
|
||||
}
|
||||
|
||||
var (
|
||||
// those are conceptually constants, but var so they can be changed in tests
|
||||
forgotPasswordRateLimit = throttled.PerHour(10)
|
||||
loginRateLimit = throttled.PerMin(10)
|
||||
)
|
||||
|
||||
func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetConfig,
|
||||
logger kitlog.Logger, limitStore throttled.GCRAStore, opts []kithttp.ServerOption,
|
||||
extra extraHandlerOpts,
|
||||
) {
|
||||
apiVersions := []string{"v1", "2022-04"}
|
||||
|
||||
@ -413,8 +433,8 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
|
||||
ne.WithAltPaths("/api/v1/osquery/enroll").
|
||||
POST("/api/osquery/enroll", enrollAgentEndpoint, enrollAgentRequest{})
|
||||
|
||||
// For some reason osquery does not provide a node key with the block data.
|
||||
// Instead the carve session ID should be verified in the service method.
|
||||
// For some reason osquery does not provide a node key with the block data.
|
||||
// Instead the carve session ID should be verified in the service method.
|
||||
ne.WithAltPaths("/api/v1/osquery/carve/block").
|
||||
POST("/api/osquery/carve/block", carveBlockEndpoint, carveBlockRequest{})
|
||||
|
||||
@ -429,11 +449,15 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
|
||||
|
||||
limiter := ratelimit.NewMiddleware(limitStore)
|
||||
ne.
|
||||
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: forgotPasswordRateLimit, MaxBurst: 9})).
|
||||
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})).
|
||||
POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{})
|
||||
|
||||
ne.
|
||||
WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})).
|
||||
loginRateLimit := throttled.PerMin(10)
|
||||
if extra.loginRateLimit != nil {
|
||||
loginRateLimit = *extra.loginRateLimit
|
||||
}
|
||||
|
||||
ne.WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})).
|
||||
POST("/api/_version_/fleet/login", loginEndpoint, loginRequest{})
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/throttled/throttled/v2"
|
||||
)
|
||||
|
||||
type withDS struct {
|
||||
@ -46,7 +45,6 @@ type withServer struct {
|
||||
func (ts *withServer) SetupSuite(dbName string) {
|
||||
ts.withDS.SetupSuite(dbName)
|
||||
|
||||
loginRateLimit = throttled.PerMin(100)
|
||||
rs := pubsub.NewInmemQueryResults()
|
||||
users, server := RunServerForTestsWithDS(ts.s.T(), ts.ds, TestServerOpts{Rs: rs})
|
||||
ts.server = server
|
||||
@ -149,7 +147,7 @@ func (ts *withServer) getTestToken(email string, password string) string {
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(ts.s.T(), http.StatusOK, resp.StatusCode)
|
||||
|
||||
var jsn = struct {
|
||||
jsn := struct {
|
||||
User *fleet.User `json:"user"`
|
||||
Token string `json:"token"`
|
||||
Err []map[string]string `json:"errors,omitempty"`
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/throttled/throttled/v2"
|
||||
"github.com/throttled/throttled/v2/store/memstore"
|
||||
)
|
||||
|
||||
@ -177,7 +178,7 @@ func RunServerForTestsWithDS(t *testing.T, ds fleet.Datastore, opts ...TestServe
|
||||
}
|
||||
|
||||
limitStore, _ := memstore.New(0)
|
||||
r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore)
|
||||
r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore, WithLoginRateLimit(throttled.PerMin(100)))
|
||||
server := httptest.NewServer(r)
|
||||
t.Cleanup(func() {
|
||||
server.Close()
|
||||
|
Loading…
Reference in New Issue
Block a user