Add rate limit response headers (X-RateLimit-*, Retry-After)
Custom rate limiter replacing Echo built-in, with per-IP token bucket. Every response includes X-RateLimit-Limit, Remaining, Reset headers. 429 responses additionally include Retry-After (seconds). CORS updated to expose rate limit headers to mobile clients. 4 unit tests for header behavior and per-IP isolation.
This commit is contained in:
@@ -1,49 +1,143 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||
)
|
||||
|
||||
// rateLimitEntry holds a per-client token-bucket limiter and metadata needed
|
||||
// to compute the standard rate-limit response headers.
|
||||
type rateLimitEntry struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// rateLimitStore is a thread-safe in-memory store keyed by client identifier
|
||||
// (typically IP address). Stale entries are lazily evicted.
|
||||
type rateLimitStore struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*rateLimitEntry
|
||||
rate rate.Limit
|
||||
burst int
|
||||
expiresIn time.Duration
|
||||
}
|
||||
|
||||
func newRateLimitStore(r rate.Limit, burst int, expiresIn time.Duration) *rateLimitStore {
|
||||
return &rateLimitStore{
|
||||
entries: make(map[string]*rateLimitEntry),
|
||||
rate: r,
|
||||
burst: burst,
|
||||
expiresIn: expiresIn,
|
||||
}
|
||||
}
|
||||
|
||||
// allow checks the rate limiter for the given identifier and returns:
|
||||
// - allowed: whether the request should be permitted
|
||||
// - remaining: approximate number of tokens left (requests remaining)
|
||||
// - resetAt: when the bucket will next be full
|
||||
func (s *rateLimitStore) allow(identifier string) (allowed bool, remaining int, resetAt time.Time) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Lazy eviction of stale entries
|
||||
if len(s.entries) > 1000 {
|
||||
for key, entry := range s.entries {
|
||||
if now.Sub(entry.lastSeen) > s.expiresIn {
|
||||
delete(s.entries, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
entry, exists := s.entries[identifier]
|
||||
if !exists {
|
||||
limiter := rate.NewLimiter(s.rate, s.burst)
|
||||
entry = &rateLimitEntry{limiter: limiter}
|
||||
s.entries[identifier] = entry
|
||||
}
|
||||
entry.lastSeen = now
|
||||
|
||||
allowed = entry.limiter.Allow()
|
||||
|
||||
// Compute remaining tokens (floor to nearest int, min 0).
|
||||
// After Allow(), tokens are already decremented if allowed.
|
||||
tokens := entry.limiter.Tokens()
|
||||
remaining = int(math.Floor(tokens))
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
// Compute when the bucket will be fully replenished.
|
||||
// tokens needed = burst - current tokens
|
||||
tokensNeeded := float64(s.burst) - tokens
|
||||
if tokensNeeded <= 0 {
|
||||
resetAt = now
|
||||
} else {
|
||||
secondsToFull := tokensNeeded / float64(s.rate)
|
||||
resetAt = now.Add(time.Duration(secondsToFull * float64(time.Second)))
|
||||
}
|
||||
|
||||
return allowed, remaining, resetAt
|
||||
}
|
||||
|
||||
// HeaderXRateLimitLimit is the max requests allowed in the window.
|
||||
const HeaderXRateLimitLimit = "X-RateLimit-Limit"
|
||||
|
||||
// HeaderXRateLimitRemaining is the number of requests remaining.
|
||||
const HeaderXRateLimitRemaining = "X-RateLimit-Remaining"
|
||||
|
||||
// HeaderXRateLimitReset is the Unix timestamp when the window resets.
|
||||
const HeaderXRateLimitReset = "X-RateLimit-Reset"
|
||||
|
||||
// HeaderRetryAfter is the seconds until the client can retry (429 only).
|
||||
const HeaderRetryAfter = "Retry-After"
|
||||
|
||||
// AuthRateLimiter returns rate-limiting middleware tuned for authentication
|
||||
// endpoints. It uses Echo's built-in in-memory rate limiter keyed by client
|
||||
// IP address.
|
||||
// endpoints. It uses a custom in-memory token-bucket rate limiter keyed by
|
||||
// client IP address and sets standard rate-limit headers on every response.
|
||||
//
|
||||
// Parameters:
|
||||
// - ratePerSecond: sustained request rate (e.g., 10/60.0 for ~10 per minute)
|
||||
// - burst: maximum burst size above the sustained rate
|
||||
func AuthRateLimiter(ratePerSecond rate.Limit, burst int) echo.MiddlewareFunc {
|
||||
store := middleware.NewRateLimiterMemoryStoreWithConfig(
|
||||
middleware.RateLimiterMemoryStoreConfig{
|
||||
Rate: ratePerSecond,
|
||||
Burst: burst,
|
||||
ExpiresIn: 5 * time.Minute,
|
||||
},
|
||||
)
|
||||
store := newRateLimitStore(ratePerSecond, burst, 5*time.Minute)
|
||||
|
||||
return middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{
|
||||
Skipper: middleware.DefaultSkipper,
|
||||
IdentifierExtractor: func(c echo.Context) (string, error) {
|
||||
return c.RealIP(), nil
|
||||
},
|
||||
Store: store,
|
||||
DenyHandler: func(c echo.Context, _ string, _ error) error {
|
||||
return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
|
||||
Error: "Too many requests. Please try again later.",
|
||||
})
|
||||
},
|
||||
ErrorHandler: func(c echo.Context, err error) error {
|
||||
return c.JSON(http.StatusForbidden, responses.ErrorResponse{
|
||||
Error: "Unable to process request.",
|
||||
})
|
||||
},
|
||||
})
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
identifier := c.RealIP()
|
||||
allowed, remaining, resetAt := store.allow(identifier)
|
||||
|
||||
// Set rate-limit headers on every response
|
||||
h := c.Response().Header()
|
||||
h.Set(HeaderXRateLimitLimit, fmt.Sprintf("%d", burst))
|
||||
h.Set(HeaderXRateLimitRemaining, fmt.Sprintf("%d", remaining))
|
||||
h.Set(HeaderXRateLimitReset, fmt.Sprintf("%d", resetAt.Unix()))
|
||||
|
||||
if !allowed {
|
||||
// Calculate Retry-After in seconds (ceiling to be safe)
|
||||
retryAfter := int(math.Ceil(time.Until(resetAt).Seconds()))
|
||||
if retryAfter < 1 {
|
||||
retryAfter = 1
|
||||
}
|
||||
h.Set(HeaderRetryAfter, fmt.Sprintf("%d", retryAfter))
|
||||
|
||||
return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
|
||||
Error: "Too many requests. Please try again later.",
|
||||
})
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LoginRateLimiter returns rate-limiting middleware for login endpoints.
|
||||
|
||||
Reference in New Issue
Block a user