Add rate limits for device authed endpoints (#6529)

* Add rate limits for device authed endpoints

* Fix lint

* Add missing test

* Fix test

* Increase the quota for desktop endpoints

* Add comment about quota
This commit is contained in:
Tomas Touceda 2022-07-11 10:49:05 -03:00 committed by GitHub
parent 6a1724a474
commit af0cf9b703
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 119 additions and 8 deletions

View File

@ -0,0 +1 @@
* Rate limit device authenticated endpoints when there's a failure

View File

@ -393,15 +393,34 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
ue.GET("/api/_version_/fleet/status/result_store", statusResultStoreEndpoint, nil)
ue.GET("/api/_version_/fleet/status/live_query", statusLiveQueryEndpoint, nil)
errorLimiter := ratelimit.NewErrorMiddleware(limitStore)
// device-authenticated endpoints
de := newDeviceAuthenticatedEndpointer(svc, logger, opts, r, apiVersions...)
de.GET("/api/_version_/fleet/device/{token}", getDeviceHostEndpoint, getDeviceHostRequest{})
de.POST("/api/_version_/fleet/device/{token}/refetch", refetchDeviceHostEndpoint, refetchDeviceHostRequest{})
de.GET("/api/_version_/fleet/device/{token}/device_mapping", listDeviceHostDeviceMappingEndpoint, listDeviceHostDeviceMappingRequest{})
de.GET("/api/_version_/fleet/device/{token}/macadmins", getDeviceMacadminsDataEndpoint, getDeviceMacadminsDataRequest{})
de.GET("/api/_version_/fleet/device/{token}/policies", listDevicePoliciesEndpoint, listDevicePoliciesRequest{})
de.GET("/api/_version_/fleet/device/{token}/api_features", deviceAPIFeaturesEndpoint, deviceAPIFeaturesRequest{})
de.GET("/api/_version_/fleet/device/{token}/transparency", transparencyURL, transparencyURLRequest{})
// We allow a quota of 720 because in the onboarding of a Fleet Desktop takes a few tries until it authenticates
// properly
desktopQuota := throttled.RateQuota{MaxRate: throttled.PerHour(720), MaxBurst: 100}
de.WithCustomMiddleware(
errorLimiter.Limit("get_device_host", desktopQuota),
).GET("/api/_version_/fleet/device/{token}", getDeviceHostEndpoint, getDeviceHostRequest{})
de.WithCustomMiddleware(
errorLimiter.Limit("refetch_device_host", desktopQuota),
).POST("/api/_version_/fleet/device/{token}/refetch", refetchDeviceHostEndpoint, refetchDeviceHostRequest{})
de.WithCustomMiddleware(
errorLimiter.Limit("get_device_mapping", desktopQuota),
).GET("/api/_version_/fleet/device/{token}/device_mapping", listDeviceHostDeviceMappingEndpoint, listDeviceHostDeviceMappingRequest{})
de.WithCustomMiddleware(
errorLimiter.Limit("get_device_macadmins", desktopQuota),
).GET("/api/_version_/fleet/device/{token}/macadmins", getDeviceMacadminsDataEndpoint, getDeviceMacadminsDataRequest{})
de.WithCustomMiddleware(
errorLimiter.Limit("get_device_policies", desktopQuota),
).GET("/api/_version_/fleet/device/{token}/policies", listDevicePoliciesEndpoint, listDevicePoliciesRequest{})
de.WithCustomMiddleware(
errorLimiter.Limit("get_device_api_features", desktopQuota),
).GET("/api/_version_/fleet/device/{token}/api_features", deviceAPIFeaturesEndpoint, deviceAPIFeaturesRequest{})
de.WithCustomMiddleware(
errorLimiter.Limit("get_device_transparency", desktopQuota),
).GET("/api/_version_/fleet/device/{token}/transparency", transparencyURL, transparencyURLRequest{})
// host-authenticated endpoints
he := newHostAuthenticatedEndpointer(svc, logger, opts, r, apiVersions...)
@ -453,9 +472,10 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
// the handler.
ne.UsePathPrefix().PathHandler("GET", "/api/_version_/fleet/results/", makeStreamDistributedQueryCampaignResultsHandler(svc, logger))
quota := throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 90}
limiter := ratelimit.NewMiddleware(limitStore)
ne.
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})).
WithCustomMiddleware(limiter.Limit("forgot_password", quota)).
POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{})
loginRateLimit := throttled.PerMin(10)

View File

@ -7,6 +7,7 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/go-kit/kit/endpoint"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/throttled/throttled/v2"
)
@ -47,6 +48,53 @@ func (m *Middleware) Limit(keyName string, quota throttled.RateQuota) endpoint.M
}
}
// ErrorMiddleware is a rate limiter that performs limits only when there is an error in the request
type ErrorMiddleware struct {
store throttled.GCRAStore
}
// NewErrorMiddleware creates a new instance of ErrorMiddleware
func NewErrorMiddleware(store throttled.GCRAStore) *ErrorMiddleware {
if store == nil {
panic("nil store")
}
return &ErrorMiddleware{store: store}
}
// Limit returns a new middleware function enforcing the provided quota only when errors occur in the next middleware
func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
limiter, err := throttled.NewGCRARateLimiter(m.store, quota)
if err != nil {
panic(err)
}
return func(ctx context.Context, req interface{}) (response interface{}, err error) {
xForwardedFor, _ := ctx.Value(kithttp.ContextKeyRequestXForwardedFor).(string)
ipKeyName := fmt.Sprintf("%s-%s", keyName, xForwardedFor)
// RateLimit with quantity 0 will never get limited=true, so we check result.Remaining instead
_, result, err := limiter.RateLimit(ipKeyName, 0)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "check rate limit")
}
if result.Remaining == 0 {
return nil, ctxerr.Wrap(ctx, &ratelimitError{result: result})
}
resp, err := next(ctx, req)
if err != nil {
_, _, rateErr := limiter.RateLimit(ipKeyName, 1)
if rateErr != nil {
return nil, ctxerr.Wrap(ctx, err, "check rate limit")
}
}
return resp, err
}
}
}
// Error is the interface for rate limiting errors.
type Error interface {
error

View File

@ -33,3 +33,45 @@ func TestLimit(t *testing.T) {
var rle Error
assert.True(t, errors.As(err, &rle))
}
func TestNewErrorMiddlewarePanics(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()
NewErrorMiddleware(nil)
}
func TestLimitOnlyWhenError(t *testing.T) {
t.Parallel()
store, _ := memstore.New(1)
limiter := NewErrorMiddleware(store)
endpoint := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
wrapped := limiter.Limit(
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
)(endpoint)
// Does NOT hit any rate limits because the endpoint doesn't fail
_, err := wrapped(context.Background(), struct{}{})
assert.NoError(t, err)
_, err = wrapped(context.Background(), struct{}{})
assert.NoError(t, err)
expectedError := errors.New("error")
failingEndpoint := func(context.Context, interface{}) (interface{}, error) { return nil, expectedError }
wrappedFailer := limiter.Limit(
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
)(failingEndpoint)
_, err = wrappedFailer(context.Background(), struct{}{})
assert.ErrorIs(t, err, expectedError)
// Hits rate limit now that it fails
_, err = wrappedFailer(context.Background(), struct{}{})
assert.Error(t, err)
var rle Error
assert.True(t, errors.As(err, &rle))
}