mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 17:05:18 +00:00
Fix race condition in tests when using global var loginRateLimit (#5197)
This commit is contained in:
parent
3228e1b5dc
commit
f7048b711c
@ -75,8 +75,33 @@ func checkLicenseExpiration(svc fleet.Service) func(context.Context, http.Respon
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type extraHandlerOpts struct {
|
||||||
|
loginRateLimit *throttled.Rate
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtraHandlerOption allows adding extra configuration to the HTTP handler.
|
||||||
|
type ExtraHandlerOption func(*extraHandlerOpts)
|
||||||
|
|
||||||
|
// WithLoginRateLimit configures the rate limit for the login endpoint.
|
||||||
|
func WithLoginRateLimit(r throttled.Rate) ExtraHandlerOption {
|
||||||
|
return func(o *extraHandlerOpts) {
|
||||||
|
o.loginRateLimit = &r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MakeHandler creates an HTTP handler for the Fleet server endpoints.
|
// MakeHandler creates an HTTP handler for the Fleet server endpoints.
|
||||||
func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Logger, limitStore throttled.GCRAStore) http.Handler {
|
func MakeHandler(
|
||||||
|
svc fleet.Service,
|
||||||
|
config config.FleetConfig,
|
||||||
|
logger kitlog.Logger,
|
||||||
|
limitStore throttled.GCRAStore,
|
||||||
|
extra ...ExtraHandlerOption,
|
||||||
|
) http.Handler {
|
||||||
|
var eopts extraHandlerOpts
|
||||||
|
for _, fn := range extra {
|
||||||
|
fn(&eopts)
|
||||||
|
}
|
||||||
|
|
||||||
fleetAPIOptions := []kithttp.ServerOption{
|
fleetAPIOptions := []kithttp.ServerOption{
|
||||||
kithttp.ServerBefore(
|
kithttp.ServerBefore(
|
||||||
kithttp.PopulateRequestContext, // populate the request context with common fields
|
kithttp.PopulateRequestContext, // populate the request context with common fields
|
||||||
@ -98,7 +123,7 @@ func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Log
|
|||||||
|
|
||||||
r.Use(publicIP)
|
r.Use(publicIP)
|
||||||
|
|
||||||
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions)
|
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions, eopts)
|
||||||
|
|
||||||
// Results endpoint is handled different due to websockets use
|
// Results endpoint is handled different due to websockets use
|
||||||
|
|
||||||
@ -203,14 +228,9 @@ func addMetrics(r *mux.Router) {
|
|||||||
r.Walk(walkFn)
|
r.Walk(walkFn)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
// those are conceptually constants, but var so they can be changed in tests
|
|
||||||
forgotPasswordRateLimit = throttled.PerHour(10)
|
|
||||||
loginRateLimit = throttled.PerMin(10)
|
|
||||||
)
|
|
||||||
|
|
||||||
func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetConfig,
|
func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetConfig,
|
||||||
logger kitlog.Logger, limitStore throttled.GCRAStore, opts []kithttp.ServerOption,
|
logger kitlog.Logger, limitStore throttled.GCRAStore, opts []kithttp.ServerOption,
|
||||||
|
extra extraHandlerOpts,
|
||||||
) {
|
) {
|
||||||
apiVersions := []string{"v1", "2022-04"}
|
apiVersions := []string{"v1", "2022-04"}
|
||||||
|
|
||||||
@ -429,11 +449,15 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
|
|||||||
|
|
||||||
limiter := ratelimit.NewMiddleware(limitStore)
|
limiter := ratelimit.NewMiddleware(limitStore)
|
||||||
ne.
|
ne.
|
||||||
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: forgotPasswordRateLimit, MaxBurst: 9})).
|
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})).
|
||||||
POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{})
|
POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{})
|
||||||
|
|
||||||
ne.
|
loginRateLimit := throttled.PerMin(10)
|
||||||
WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})).
|
if extra.loginRateLimit != nil {
|
||||||
|
loginRateLimit = *extra.loginRateLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
ne.WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})).
|
||||||
POST("/api/_version_/fleet/login", loginEndpoint, loginRequest{})
|
POST("/api/_version_/fleet/login", loginEndpoint, loginRequest{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,7 +17,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
"github.com/throttled/throttled/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type withDS struct {
|
type withDS struct {
|
||||||
@ -46,7 +45,6 @@ type withServer struct {
|
|||||||
func (ts *withServer) SetupSuite(dbName string) {
|
func (ts *withServer) SetupSuite(dbName string) {
|
||||||
ts.withDS.SetupSuite(dbName)
|
ts.withDS.SetupSuite(dbName)
|
||||||
|
|
||||||
loginRateLimit = throttled.PerMin(100)
|
|
||||||
rs := pubsub.NewInmemQueryResults()
|
rs := pubsub.NewInmemQueryResults()
|
||||||
users, server := RunServerForTestsWithDS(ts.s.T(), ts.ds, TestServerOpts{Rs: rs})
|
users, server := RunServerForTestsWithDS(ts.s.T(), ts.ds, TestServerOpts{Rs: rs})
|
||||||
ts.server = server
|
ts.server = server
|
||||||
@ -149,7 +147,7 @@ func (ts *withServer) getTestToken(email string, password string) string {
|
|||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
assert.Equal(ts.s.T(), http.StatusOK, resp.StatusCode)
|
assert.Equal(ts.s.T(), http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
var jsn = struct {
|
jsn := struct {
|
||||||
User *fleet.User `json:"user"`
|
User *fleet.User `json:"user"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
Err []map[string]string `json:"errors,omitempty"`
|
Err []map[string]string `json:"errors,omitempty"`
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
kitlog "github.com/go-kit/kit/log"
|
kitlog "github.com/go-kit/kit/log"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/throttled/throttled/v2"
|
||||||
"github.com/throttled/throttled/v2/store/memstore"
|
"github.com/throttled/throttled/v2/store/memstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -177,7 +178,7 @@ func RunServerForTestsWithDS(t *testing.T, ds fleet.Datastore, opts ...TestServe
|
|||||||
}
|
}
|
||||||
|
|
||||||
limitStore, _ := memstore.New(0)
|
limitStore, _ := memstore.New(0)
|
||||||
r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore)
|
r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore, WithLoginRateLimit(throttled.PerMin(100)))
|
||||||
server := httptest.NewServer(r)
|
server := httptest.NewServer(r)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
server.Close()
|
server.Close()
|
||||||
|
Loading…
Reference in New Issue
Block a user