mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Add rate limits for device authed endpoints (#6529)
* Add rate limits for device authed endpoints * Fix lint * Add missing test * Fix test * Increase the quota for desktop endpoints * Add comment about quota
This commit is contained in:
parent
6a1724a474
commit
af0cf9b703
1
changes/rate-limit-device-authed
Normal file
1
changes/rate-limit-device-authed
Normal file
@ -0,0 +1 @@
|
||||
* Rate limit device authenticated endpoints when there's a failure
|
@ -393,15 +393,34 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
|
||||
ue.GET("/api/_version_/fleet/status/result_store", statusResultStoreEndpoint, nil)
|
||||
ue.GET("/api/_version_/fleet/status/live_query", statusLiveQueryEndpoint, nil)
|
||||
|
||||
errorLimiter := ratelimit.NewErrorMiddleware(limitStore)
|
||||
|
||||
// device-authenticated endpoints
|
||||
de := newDeviceAuthenticatedEndpointer(svc, logger, opts, r, apiVersions...)
|
||||
de.GET("/api/_version_/fleet/device/{token}", getDeviceHostEndpoint, getDeviceHostRequest{})
|
||||
de.POST("/api/_version_/fleet/device/{token}/refetch", refetchDeviceHostEndpoint, refetchDeviceHostRequest{})
|
||||
de.GET("/api/_version_/fleet/device/{token}/device_mapping", listDeviceHostDeviceMappingEndpoint, listDeviceHostDeviceMappingRequest{})
|
||||
de.GET("/api/_version_/fleet/device/{token}/macadmins", getDeviceMacadminsDataEndpoint, getDeviceMacadminsDataRequest{})
|
||||
de.GET("/api/_version_/fleet/device/{token}/policies", listDevicePoliciesEndpoint, listDevicePoliciesRequest{})
|
||||
de.GET("/api/_version_/fleet/device/{token}/api_features", deviceAPIFeaturesEndpoint, deviceAPIFeaturesRequest{})
|
||||
de.GET("/api/_version_/fleet/device/{token}/transparency", transparencyURL, transparencyURLRequest{})
|
||||
// We allow a quota of 720 because in the onboarding of a Fleet Desktop takes a few tries until it authenticates
|
||||
// properly
|
||||
desktopQuota := throttled.RateQuota{MaxRate: throttled.PerHour(720), MaxBurst: 100}
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_host", desktopQuota),
|
||||
).GET("/api/_version_/fleet/device/{token}", getDeviceHostEndpoint, getDeviceHostRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("refetch_device_host", desktopQuota),
|
||||
).POST("/api/_version_/fleet/device/{token}/refetch", refetchDeviceHostEndpoint, refetchDeviceHostRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_mapping", desktopQuota),
|
||||
).GET("/api/_version_/fleet/device/{token}/device_mapping", listDeviceHostDeviceMappingEndpoint, listDeviceHostDeviceMappingRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_macadmins", desktopQuota),
|
||||
).GET("/api/_version_/fleet/device/{token}/macadmins", getDeviceMacadminsDataEndpoint, getDeviceMacadminsDataRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_policies", desktopQuota),
|
||||
).GET("/api/_version_/fleet/device/{token}/policies", listDevicePoliciesEndpoint, listDevicePoliciesRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_api_features", desktopQuota),
|
||||
).GET("/api/_version_/fleet/device/{token}/api_features", deviceAPIFeaturesEndpoint, deviceAPIFeaturesRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_transparency", desktopQuota),
|
||||
).GET("/api/_version_/fleet/device/{token}/transparency", transparencyURL, transparencyURLRequest{})
|
||||
|
||||
// host-authenticated endpoints
|
||||
he := newHostAuthenticatedEndpointer(svc, logger, opts, r, apiVersions...)
|
||||
@ -453,9 +472,10 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
|
||||
// the handler.
|
||||
ne.UsePathPrefix().PathHandler("GET", "/api/_version_/fleet/results/", makeStreamDistributedQueryCampaignResultsHandler(svc, logger))
|
||||
|
||||
quota := throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 90}
|
||||
limiter := ratelimit.NewMiddleware(limitStore)
|
||||
ne.
|
||||
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})).
|
||||
WithCustomMiddleware(limiter.Limit("forgot_password", quota)).
|
||||
POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{})
|
||||
|
||||
loginRateLimit := throttled.PerMin(10)
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
"github.com/throttled/throttled/v2"
|
||||
)
|
||||
|
||||
@ -47,6 +48,53 @@ func (m *Middleware) Limit(keyName string, quota throttled.RateQuota) endpoint.M
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorMiddleware is a rate limiter that performs limits only when there is an error in the request
|
||||
type ErrorMiddleware struct {
|
||||
store throttled.GCRAStore
|
||||
}
|
||||
|
||||
// NewErrorMiddleware creates a new instance of ErrorMiddleware
|
||||
func NewErrorMiddleware(store throttled.GCRAStore) *ErrorMiddleware {
|
||||
if store == nil {
|
||||
panic("nil store")
|
||||
}
|
||||
|
||||
return &ErrorMiddleware{store: store}
|
||||
}
|
||||
|
||||
// Limit returns a new middleware function enforcing the provided quota only when errors occur in the next middleware
|
||||
func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota) endpoint.Middleware {
|
||||
return func(next endpoint.Endpoint) endpoint.Endpoint {
|
||||
limiter, err := throttled.NewGCRARateLimiter(m.store, quota)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return func(ctx context.Context, req interface{}) (response interface{}, err error) {
|
||||
xForwardedFor, _ := ctx.Value(kithttp.ContextKeyRequestXForwardedFor).(string)
|
||||
ipKeyName := fmt.Sprintf("%s-%s", keyName, xForwardedFor)
|
||||
|
||||
// RateLimit with quantity 0 will never get limited=true, so we check result.Remaining instead
|
||||
_, result, err := limiter.RateLimit(ipKeyName, 0)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "check rate limit")
|
||||
}
|
||||
if result.Remaining == 0 {
|
||||
return nil, ctxerr.Wrap(ctx, &ratelimitError{result: result})
|
||||
}
|
||||
|
||||
resp, err := next(ctx, req)
|
||||
if err != nil {
|
||||
_, _, rateErr := limiter.RateLimit(ipKeyName, 1)
|
||||
if rateErr != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "check rate limit")
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Error is the interface for rate limiting errors.
|
||||
type Error interface {
|
||||
error
|
||||
|
@ -33,3 +33,45 @@ func TestLimit(t *testing.T) {
|
||||
var rle Error
|
||||
assert.True(t, errors.As(err, &rle))
|
||||
}
|
||||
|
||||
func TestNewErrorMiddlewarePanics(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("The code did not panic")
|
||||
}
|
||||
}()
|
||||
|
||||
NewErrorMiddleware(nil)
|
||||
}
|
||||
|
||||
func TestLimitOnlyWhenError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, _ := memstore.New(1)
|
||||
limiter := NewErrorMiddleware(store)
|
||||
endpoint := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
|
||||
wrapped := limiter.Limit(
|
||||
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
|
||||
)(endpoint)
|
||||
|
||||
// Does NOT hit any rate limits because the endpoint doesn't fail
|
||||
_, err := wrapped(context.Background(), struct{}{})
|
||||
assert.NoError(t, err)
|
||||
_, err = wrapped(context.Background(), struct{}{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedError := errors.New("error")
|
||||
failingEndpoint := func(context.Context, interface{}) (interface{}, error) { return nil, expectedError }
|
||||
wrappedFailer := limiter.Limit(
|
||||
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
|
||||
)(failingEndpoint)
|
||||
|
||||
_, err = wrappedFailer(context.Background(), struct{}{})
|
||||
assert.ErrorIs(t, err, expectedError)
|
||||
|
||||
// Hits rate limit now that it fails
|
||||
_, err = wrappedFailer(context.Background(), struct{}{})
|
||||
assert.Error(t, err)
|
||||
var rle Error
|
||||
assert.True(t, errors.As(err, &rle))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user