mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
55d56ba2db
Return proper HTTP status code if endpoint is rate limited.
84 lines
2.3 KiB
Go
84 lines
2.3 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
|
|
authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/throttled/throttled/v2"
|
|
"github.com/throttled/throttled/v2/store/memstore"
|
|
)
|
|
|
|
// The intent is to test the middleware functionality only. We rely on Throttled's
// own tests to verify that the rate-limiting algorithm itself works properly.
|
|
|
|
func TestLimit(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := memstore.New(0)
|
|
limiter := NewMiddleware(store)
|
|
endpoint := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
|
|
wrapped := limiter.Limit(
|
|
"test_limit",
|
|
throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
|
|
)(endpoint)
|
|
|
|
authzCtx := &authz_ctx.AuthorizationContext{}
|
|
ctx := authz_ctx.NewContext(context.Background(), authzCtx)
|
|
|
|
_, err := wrapped(ctx, struct{}{})
|
|
assert.NoError(t, err)
|
|
|
|
// Hits rate limit
|
|
_, err = wrapped(ctx, struct{}{})
|
|
assert.Error(t, err)
|
|
var rle Error
|
|
|
|
assert.True(t, errors.As(err, &rle))
|
|
assert.True(t, authzCtx.Checked())
|
|
}
|
|
|
|
func TestNewErrorMiddlewarePanics(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Errorf("The code did not panic")
|
|
}
|
|
}()
|
|
|
|
NewErrorMiddleware(nil)
|
|
}
|
|
|
|
func TestLimitOnlyWhenError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := memstore.New(1)
|
|
limiter := NewErrorMiddleware(store)
|
|
endpoint := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
|
|
wrapped := limiter.Limit(
|
|
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
|
|
)(endpoint)
|
|
|
|
// Does NOT hit any rate limits because the endpoint doesn't fail
|
|
_, err := wrapped(context.Background(), struct{}{})
|
|
assert.NoError(t, err)
|
|
_, err = wrapped(context.Background(), struct{}{})
|
|
assert.NoError(t, err)
|
|
|
|
expectedError := errors.New("error")
|
|
failingEndpoint := func(context.Context, interface{}) (interface{}, error) { return nil, expectedError }
|
|
wrappedFailer := limiter.Limit(
|
|
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
|
|
)(failingEndpoint)
|
|
|
|
_, err = wrappedFailer(context.Background(), struct{}{})
|
|
assert.ErrorIs(t, err, expectedError)
|
|
|
|
// Hits rate limit now that it fails
|
|
_, err = wrappedFailer(context.Background(), struct{}{})
|
|
assert.Error(t, err)
|
|
var rle Error
|
|
assert.True(t, errors.As(err, &rle))
|
|
}
|