Add rate limit response headers (X-RateLimit-*, Retry-After)

Custom rate limiter replacing Echo built-in, with per-IP token bucket.
Every response includes X-RateLimit-Limit, Remaining, Reset headers.
429 responses additionally include Retry-After (seconds).
CORS updated to expose rate limit headers to mobile clients.
4 unit tests for header behavior and per-IP isolation.
This commit is contained in:
Trey T
2026-03-26 14:36:48 -05:00
parent b679f28e55
commit 6df27f203b
3 changed files with 294 additions and 28 deletions

View File

@@ -1,49 +1,143 @@
package middleware
import (
"fmt"
"math"
"net/http"
"sync"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"golang.org/x/time/rate"
"github.com/treytartt/honeydue-api/internal/dto/responses"
)
// rateLimitEntry holds a per-client token-bucket limiter and metadata needed
// to compute the standard rate-limit response headers.
type rateLimitEntry struct {
limiter *rate.Limiter
lastSeen time.Time
}
// rateLimitStore is a thread-safe in-memory store keyed by client identifier
// (typically IP address). Stale entries are lazily evicted.
type rateLimitStore struct {
mu sync.Mutex
entries map[string]*rateLimitEntry
rate rate.Limit
burst int
expiresIn time.Duration
}
func newRateLimitStore(r rate.Limit, burst int, expiresIn time.Duration) *rateLimitStore {
return &rateLimitStore{
entries: make(map[string]*rateLimitEntry),
rate: r,
burst: burst,
expiresIn: expiresIn,
}
}
// allow checks the rate limiter for the given identifier and returns:
// - allowed: whether the request should be permitted
// - remaining: approximate number of tokens left (requests remaining)
// - resetAt: when the bucket will next be full
func (s *rateLimitStore) allow(identifier string) (allowed bool, remaining int, resetAt time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
// Lazy eviction of stale entries
if len(s.entries) > 1000 {
for key, entry := range s.entries {
if now.Sub(entry.lastSeen) > s.expiresIn {
delete(s.entries, key)
}
}
}
entry, exists := s.entries[identifier]
if !exists {
limiter := rate.NewLimiter(s.rate, s.burst)
entry = &rateLimitEntry{limiter: limiter}
s.entries[identifier] = entry
}
entry.lastSeen = now
allowed = entry.limiter.Allow()
// Compute remaining tokens (floor to nearest int, min 0).
// After Allow(), tokens are already decremented if allowed.
tokens := entry.limiter.Tokens()
remaining = int(math.Floor(tokens))
if remaining < 0 {
remaining = 0
}
// Compute when the bucket will be fully replenished.
// tokens needed = burst - current tokens
tokensNeeded := float64(s.burst) - tokens
if tokensNeeded <= 0 {
resetAt = now
} else {
secondsToFull := tokensNeeded / float64(s.rate)
resetAt = now.Add(time.Duration(secondsToFull * float64(time.Second)))
}
return allowed, remaining, resetAt
}
// HeaderXRateLimitLimit is the max requests allowed in the window.
const HeaderXRateLimitLimit = "X-RateLimit-Limit"
// HeaderXRateLimitRemaining is the number of requests remaining.
const HeaderXRateLimitRemaining = "X-RateLimit-Remaining"
// HeaderXRateLimitReset is the Unix timestamp when the window resets.
const HeaderXRateLimitReset = "X-RateLimit-Reset"
// HeaderRetryAfter is the seconds until the client can retry (429 only).
const HeaderRetryAfter = "Retry-After"
// AuthRateLimiter returns rate-limiting middleware tuned for authentication
// endpoints. It uses Echo's built-in in-memory rate limiter keyed by client
// IP address.
// endpoints. It uses a custom in-memory token-bucket rate limiter keyed by
// client IP address and sets standard rate-limit headers on every response.
//
// Parameters:
// - ratePerSecond: sustained request rate (e.g., 10/60.0 for ~10 per minute)
// - burst: maximum burst size above the sustained rate
func AuthRateLimiter(ratePerSecond rate.Limit, burst int) echo.MiddlewareFunc {
store := middleware.NewRateLimiterMemoryStoreWithConfig(
middleware.RateLimiterMemoryStoreConfig{
Rate: ratePerSecond,
Burst: burst,
ExpiresIn: 5 * time.Minute,
},
)
store := newRateLimitStore(ratePerSecond, burst, 5*time.Minute)
return middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{
Skipper: middleware.DefaultSkipper,
IdentifierExtractor: func(c echo.Context) (string, error) {
return c.RealIP(), nil
},
Store: store,
DenyHandler: func(c echo.Context, _ string, _ error) error {
return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
Error: "Too many requests. Please try again later.",
})
},
ErrorHandler: func(c echo.Context, err error) error {
return c.JSON(http.StatusForbidden, responses.ErrorResponse{
Error: "Unable to process request.",
})
},
})
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
identifier := c.RealIP()
allowed, remaining, resetAt := store.allow(identifier)
// Set rate-limit headers on every response
h := c.Response().Header()
h.Set(HeaderXRateLimitLimit, fmt.Sprintf("%d", burst))
h.Set(HeaderXRateLimitRemaining, fmt.Sprintf("%d", remaining))
h.Set(HeaderXRateLimitReset, fmt.Sprintf("%d", resetAt.Unix()))
if !allowed {
// Calculate Retry-After in seconds (ceiling to be safe)
retryAfter := int(math.Ceil(time.Until(resetAt).Seconds()))
if retryAfter < 1 {
retryAfter = 1
}
h.Set(HeaderRetryAfter, fmt.Sprintf("%d", retryAfter))
return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
Error: "Too many requests. Please try again later.",
})
}
return next(c)
}
}
}
// LoginRateLimiter returns rate-limiting middleware for login endpoints.