mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 17:05:18 +00:00
12d292164f
Prevent abuse of these endpoints with rate limiting backed by Redis. The limits assigned should be appropriate for almost any Fleet deployment. Closes #530
81 lines
1.9 KiB
Go
81 lines
1.9 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"reflect"
|
|
"runtime"
|
|
|
|
"github.com/go-kit/kit/endpoint"
|
|
"github.com/pkg/errors"
|
|
"github.com/throttled/throttled/v2"
|
|
)
|
|
|
|
// Middleware is a rate limiting middleware using the provided store. Each
|
|
// function wrapped by the rate limiter receives a separate quota.
|
|
type Middleware struct {
|
|
store throttled.GCRAStore
|
|
}
|
|
|
|
// NewMiddleware initializes the middleware with the provided store.
|
|
func NewMiddleware(store throttled.GCRAStore) *Middleware {
|
|
if store == nil {
|
|
panic("nil store")
|
|
}
|
|
|
|
return &Middleware{store: store}
|
|
}
|
|
|
|
// Limit returns a new middleware function enforcing the provided quota.
|
|
func (m *Middleware) Limit(quota throttled.RateQuota) endpoint.Middleware {
|
|
return func(next endpoint.Endpoint) endpoint.Endpoint {
|
|
// Get function name to use as a key for rate limiting (each wrapped function
|
|
// gets a separate quota)
|
|
funcName := runtime.FuncForPC(reflect.ValueOf(next).Pointer()).Name()
|
|
|
|
limiter, err := throttled.NewGCRARateLimiter(m.store, quota)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return func(ctx context.Context, req interface{}) (response interface{}, err error) {
|
|
limited, result, err := limiter.RateLimit(funcName, 1)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "check rate limit")
|
|
}
|
|
if limited {
|
|
return nil, &ratelimitError{result: result}
|
|
}
|
|
|
|
return next(ctx, req)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Error is the interface for rate limiting errors.
|
|
type Error interface {
|
|
error
|
|
Result() throttled.RateLimitResult
|
|
}
|
|
|
|
type ratelimitError struct {
|
|
result throttled.RateLimitResult
|
|
}
|
|
|
|
func (r ratelimitError) Error() string {
|
|
return fmt.Sprintf("limit exceeded, retry after: %ds", int(r.result.RetryAfter.Seconds()))
|
|
}
|
|
|
|
func (r ratelimitError) StatusCode() int {
|
|
return http.StatusTooManyRequests
|
|
}
|
|
|
|
func (r ratelimitError) RetryAfter() int {
|
|
return int(r.result.RetryAfter.Seconds())
|
|
}
|
|
|
|
func (r ratelimitError) Result() throttled.RateLimitResult {
|
|
return r.result
|
|
}
|