fleet/server/service/middleware/ratelimit/ratelimit.go
Zach Wasserman 12d292164f
Add rate-limiting to login and password reset (#543)
Prevent abuse of these endpoints with rate limiting backed by Redis. The
limits assigned should be appropriate for almost any Fleet deployment.

Closes #530
2021-03-26 11:23:29 -07:00

81 lines
1.9 KiB
Go

package ratelimit
import (
	"context"
	"fmt"
	"math"
	"net/http"
	"reflect"
	"runtime"

	"github.com/go-kit/kit/endpoint"
	"github.com/pkg/errors"
	"github.com/throttled/throttled/v2"
)
// Middleware applies rate limits to go-kit endpoints using a
// throttled.GCRAStore as the backing store. Endpoints wrapped via Limit are
// keyed by the name of the wrapped function, so distinct functions receive
// separate quotas. Create one with NewMiddleware.
type Middleware struct {
	store throttled.GCRAStore // shared by every limiter that Limit creates
}
// NewMiddleware returns a Middleware backed by store. It panics if store is
// nil, because every limiter created by Limit is built on top of this store.
func NewMiddleware(store throttled.GCRAStore) *Middleware {
	if store != nil {
		return &Middleware{store: store}
	}
	panic("nil store")
}
// Limit returns an endpoint.Middleware that enforces quota on each endpoint it
// wraps. Every wrapped endpoint gets its own limiter and is tracked in the
// store under the name of the wrapped function, as reported by the runtime.
//
// Because the key is the function's name, endpoints that share the same
// underlying function share one quota. This happens with closures created by
// the same function literal, or with endpoints already wrapped by another
// middleware (the key then becomes that middleware's inner closure). Apply
// Limit directly to the distinct endpoint functions to keep quotas separate.
//
// Applying the returned middleware panics if the wrapped endpoint is nil or
// if throttled rejects quota. Both are configuration errors, so they surface
// at setup time rather than on the first request.
func (m *Middleware) Limit(quota throttled.RateQuota) endpoint.Middleware {
	return func(next endpoint.Endpoint) endpoint.Endpoint {
		if next == nil {
			// Without this check a nil endpoint would get the empty key "" and
			// only fail (nil func call) once a request gets through the limiter.
			panic("nil endpoint")
		}

		// Get function name to use as a key for rate limiting (each wrapped
		// function gets a separate quota).
		funcName := runtime.FuncForPC(reflect.ValueOf(next).Pointer()).Name()

		limiter, err := throttled.NewGCRARateLimiter(m.store, quota)
		if err != nil {
			panic(errors.Wrap(err, "create rate limiter"))
		}

		return func(ctx context.Context, req interface{}) (interface{}, error) {
			// Each call consumes one unit of the quota.
			limited, result, err := limiter.RateLimit(funcName, 1)
			if err != nil {
				return nil, errors.Wrap(err, "check rate limit")
			}
			if limited {
				return nil, &ratelimitError{result: result}
			}
			return next(ctx, req)
		}
	}
}
// Error is implemented by the errors returned when a request is rejected for
// exceeding its rate limit. Result exposes the limiter's decision, including
// the suggested delay before retrying.
type Error interface {
	Result() throttled.RateLimitResult
	error
}
// ratelimitError is returned by endpoints wrapped with Middleware.Limit when a
// request exceeds its quota. It carries the limiter's result so callers can
// inspect it through the Error interface.
type ratelimitError struct {
	result throttled.RateLimitResult
}

// Error reports that the limit was exceeded and how many whole seconds to
// wait before retrying (see RetryAfter for the rounding rule).
func (r ratelimitError) Error() string {
	return fmt.Sprintf("limit exceeded, retry after: %ds", r.RetryAfter())
}

// StatusCode returns 429 Too Many Requests. NOTE(review): presumably read by
// the HTTP error encoder to set the response status — confirm against caller.
func (r ratelimitError) StatusCode() int {
	return http.StatusTooManyRequests
}

// RetryAfter returns the number of whole seconds the caller should wait before
// retrying. The limiter's duration is rounded up rather than truncated, so a
// sub-second wait is reported as 1 instead of 0. A value of 0 would tell HTTP
// clients to retry immediately, and they would just be rejected again.
func (r ratelimitError) RetryAfter() int {
	return int(math.Ceil(r.result.RetryAfter.Seconds()))
}

// Result returns the raw throttled result that caused the rejection.
func (r ratelimitError) Result() throttled.RateLimitResult {
	return r.result
}