diff --git a/changes/rate-limit-device-authed b/changes/rate-limit-device-authed new file mode 100644 index 000000000..87bc17ccd --- /dev/null +++ b/changes/rate-limit-device-authed @@ -0,0 +1 @@ +* Rate limit device authenticated endpoints when there's a failure diff --git a/server/service/handler.go b/server/service/handler.go index aac7dcb33..4f4cd120b 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -393,15 +393,34 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC ue.GET("/api/_version_/fleet/status/result_store", statusResultStoreEndpoint, nil) ue.GET("/api/_version_/fleet/status/live_query", statusLiveQueryEndpoint, nil) + errorLimiter := ratelimit.NewErrorMiddleware(limitStore) + // device-authenticated endpoints de := newDeviceAuthenticatedEndpointer(svc, logger, opts, r, apiVersions...) - de.GET("/api/_version_/fleet/device/{token}", getDeviceHostEndpoint, getDeviceHostRequest{}) - de.POST("/api/_version_/fleet/device/{token}/refetch", refetchDeviceHostEndpoint, refetchDeviceHostRequest{}) - de.GET("/api/_version_/fleet/device/{token}/device_mapping", listDeviceHostDeviceMappingEndpoint, listDeviceHostDeviceMappingRequest{}) - de.GET("/api/_version_/fleet/device/{token}/macadmins", getDeviceMacadminsDataEndpoint, getDeviceMacadminsDataRequest{}) - de.GET("/api/_version_/fleet/device/{token}/policies", listDevicePoliciesEndpoint, listDevicePoliciesRequest{}) - de.GET("/api/_version_/fleet/device/{token}/api_features", deviceAPIFeaturesEndpoint, deviceAPIFeaturesRequest{}) - de.GET("/api/_version_/fleet/device/{token}/transparency", transparencyURL, transparencyURLRequest{}) + // We allow a quota of 720 because in the onboarding of a Fleet Desktop takes a few tries until it authenticates + // properly + desktopQuota := throttled.RateQuota{MaxRate: throttled.PerHour(720), MaxBurst: 100} + de.WithCustomMiddleware( + errorLimiter.Limit("get_device_host", desktopQuota), + ).GET("/api/_version_/fleet/device/{token}", getDeviceHostEndpoint, getDeviceHostRequest{}) + de.WithCustomMiddleware( + errorLimiter.Limit("refetch_device_host", desktopQuota), + ).POST("/api/_version_/fleet/device/{token}/refetch", refetchDeviceHostEndpoint, refetchDeviceHostRequest{}) + de.WithCustomMiddleware( + errorLimiter.Limit("get_device_mapping", desktopQuota), + ).GET("/api/_version_/fleet/device/{token}/device_mapping", listDeviceHostDeviceMappingEndpoint, listDeviceHostDeviceMappingRequest{}) + de.WithCustomMiddleware( + errorLimiter.Limit("get_device_macadmins", desktopQuota), + ).GET("/api/_version_/fleet/device/{token}/macadmins", getDeviceMacadminsDataEndpoint, getDeviceMacadminsDataRequest{}) + de.WithCustomMiddleware( + errorLimiter.Limit("get_device_policies", desktopQuota), + ).GET("/api/_version_/fleet/device/{token}/policies", listDevicePoliciesEndpoint, listDevicePoliciesRequest{}) + de.WithCustomMiddleware( + errorLimiter.Limit("get_device_api_features", desktopQuota), + ).GET("/api/_version_/fleet/device/{token}/api_features", deviceAPIFeaturesEndpoint, deviceAPIFeaturesRequest{}) + de.WithCustomMiddleware( + errorLimiter.Limit("get_device_transparency", desktopQuota), + ).GET("/api/_version_/fleet/device/{token}/transparency", transparencyURL, transparencyURLRequest{}) // host-authenticated endpoints he := newHostAuthenticatedEndpointer(svc, logger, opts, r, apiVersions...) @@ -453,9 +472,10 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC // the handler. ne.UsePathPrefix().PathHandler("GET", "/api/_version_/fleet/results/", makeStreamDistributedQueryCampaignResultsHandler(svc, logger)) + quota := throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 90} limiter := ratelimit.NewMiddleware(limitStore) ne. - WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})). + WithCustomMiddleware(limiter.Limit("forgot_password", quota)). POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{}) loginRateLimit := throttled.PerMin(10) diff --git a/server/service/middleware/ratelimit/ratelimit.go b/server/service/middleware/ratelimit/ratelimit.go index 3fb3971ff..4c50a8db5 100644 --- a/server/service/middleware/ratelimit/ratelimit.go +++ b/server/service/middleware/ratelimit/ratelimit.go @@ -7,6 +7,7 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/go-kit/kit/endpoint" + kithttp "github.com/go-kit/kit/transport/http" "github.com/throttled/throttled/v2" ) @@ -47,6 +48,53 @@ func (m *Middleware) Limit(keyName string, quota throttled.RateQuota) endpoint.M } } +// ErrorMiddleware is a rate limiter that performs limits only when there is an error in the request +type ErrorMiddleware struct { + store throttled.GCRAStore +} + +// NewErrorMiddleware creates a new instance of ErrorMiddleware +func NewErrorMiddleware(store throttled.GCRAStore) *ErrorMiddleware { + if store == nil { + panic("nil store") + } + + return &ErrorMiddleware{store: store} +} + +// Limit returns a new middleware function enforcing the provided quota only when errors occur in the next middleware +func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota) endpoint.Middleware { + return func(next endpoint.Endpoint) endpoint.Endpoint { + limiter, err := throttled.NewGCRARateLimiter(m.store, quota) + if err != nil { + panic(err) + } + + return func(ctx context.Context, req interface{}) (response interface{}, err error) { + xForwardedFor, _ := ctx.Value(kithttp.ContextKeyRequestXForwardedFor).(string) + ipKeyName := fmt.Sprintf("%s-%s", keyName, xForwardedFor) + + // RateLimit with quantity 0 will never get limited=true, so we check result.Remaining instead + _, result, err := limiter.RateLimit(ipKeyName, 0) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "check rate limit") + } + if result.Remaining == 0 { + return nil, ctxerr.Wrap(ctx, &ratelimitError{result: result}) + } + + resp, err := next(ctx, req) + if err != nil { + _, _, rateErr := limiter.RateLimit(ipKeyName, 1) + if rateErr != nil { + return nil, ctxerr.Wrap(ctx, err, "check rate limit") + } + } + return resp, err + } + } +} + // Error is the interface for rate limiting errors. type Error interface { error diff --git a/server/service/middleware/ratelimit/ratelimit_test.go b/server/service/middleware/ratelimit/ratelimit_test.go index 8e3338840..281a2d924 100644 --- a/server/service/middleware/ratelimit/ratelimit_test.go +++ b/server/service/middleware/ratelimit/ratelimit_test.go @@ -33,3 +33,45 @@ func TestLimit(t *testing.T) { var rle Error assert.True(t, errors.As(err, &rle)) } + +func TestNewErrorMiddlewarePanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + + NewErrorMiddleware(nil) +} + +func TestLimitOnlyWhenError(t *testing.T) { + t.Parallel() + + store, _ := memstore.New(1) + limiter := NewErrorMiddleware(store) + endpoint := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } + wrapped := limiter.Limit( + "test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, + )(endpoint) + + // Does NOT hit any rate limits because the endpoint doesn't fail + _, err := wrapped(context.Background(), struct{}{}) + assert.NoError(t, err) + _, err = wrapped(context.Background(), struct{}{}) + assert.NoError(t, err) + + expectedError := errors.New("error") + failingEndpoint := func(context.Context, interface{}) (interface{}, error) { return nil, expectedError } + wrappedFailer := limiter.Limit( + "test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, + )(failingEndpoint) + + _, err = wrappedFailer(context.Background(), struct{}{}) + assert.ErrorIs(t, err, expectedError) + + // Hits rate limit now that it fails + _, err = wrappedFailer(context.Background(), struct{}{}) + assert.Error(t, err) + var rle Error + assert.True(t, errors.As(err, &rle)) +}