Custom rate limiter replacing Echo's built-in one, using a per-IP token bucket. Every response includes the X-RateLimit-Limit, X-RateLimit-Remaining, and X-RateLimit-Reset headers; 429 responses additionally include Retry-After (in seconds). CORS is updated to expose the rate-limit headers to mobile clients. Adds 4 unit tests covering header behavior and per-IP isolation.
163 lines
5.0 KiB
Go
163 lines
5.0 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"golang.org/x/time/rate"
|
|
|
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
|
)
|
|
|
|
// rateLimitEntry holds a per-client token-bucket limiter and metadata needed
|
|
// to compute the standard rate-limit response headers.
|
|
type rateLimitEntry struct {
|
|
limiter *rate.Limiter
|
|
lastSeen time.Time
|
|
}
|
|
|
|
// rateLimitStore is a thread-safe in-memory store keyed by client identifier
|
|
// (typically IP address). Stale entries are lazily evicted.
|
|
type rateLimitStore struct {
|
|
mu sync.Mutex
|
|
entries map[string]*rateLimitEntry
|
|
rate rate.Limit
|
|
burst int
|
|
expiresIn time.Duration
|
|
}
|
|
|
|
func newRateLimitStore(r rate.Limit, burst int, expiresIn time.Duration) *rateLimitStore {
|
|
return &rateLimitStore{
|
|
entries: make(map[string]*rateLimitEntry),
|
|
rate: r,
|
|
burst: burst,
|
|
expiresIn: expiresIn,
|
|
}
|
|
}
|
|
|
|
// allow checks the rate limiter for the given identifier and returns:
|
|
// - allowed: whether the request should be permitted
|
|
// - remaining: approximate number of tokens left (requests remaining)
|
|
// - resetAt: when the bucket will next be full
|
|
func (s *rateLimitStore) allow(identifier string) (allowed bool, remaining int, resetAt time.Time) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
|
|
// Lazy eviction of stale entries
|
|
if len(s.entries) > 1000 {
|
|
for key, entry := range s.entries {
|
|
if now.Sub(entry.lastSeen) > s.expiresIn {
|
|
delete(s.entries, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
entry, exists := s.entries[identifier]
|
|
if !exists {
|
|
limiter := rate.NewLimiter(s.rate, s.burst)
|
|
entry = &rateLimitEntry{limiter: limiter}
|
|
s.entries[identifier] = entry
|
|
}
|
|
entry.lastSeen = now
|
|
|
|
allowed = entry.limiter.Allow()
|
|
|
|
// Compute remaining tokens (floor to nearest int, min 0).
|
|
// After Allow(), tokens are already decremented if allowed.
|
|
tokens := entry.limiter.Tokens()
|
|
remaining = int(math.Floor(tokens))
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
|
|
// Compute when the bucket will be fully replenished.
|
|
// tokens needed = burst - current tokens
|
|
tokensNeeded := float64(s.burst) - tokens
|
|
if tokensNeeded <= 0 {
|
|
resetAt = now
|
|
} else {
|
|
secondsToFull := tokensNeeded / float64(s.rate)
|
|
resetAt = now.Add(time.Duration(secondsToFull * float64(time.Second)))
|
|
}
|
|
|
|
return allowed, remaining, resetAt
|
|
}
|
|
|
|
// Rate-limit response header names written by AuthRateLimiter.
const (
	// HeaderXRateLimitLimit carries the bucket capacity (the burst size),
	// i.e. the maximum number of requests a client may make back-to-back.
	HeaderXRateLimitLimit = "X-RateLimit-Limit"

	// HeaderXRateLimitRemaining carries the number of whole tokens left in
	// the client's bucket, i.e. how many more requests would be allowed now.
	HeaderXRateLimitRemaining = "X-RateLimit-Remaining"

	// HeaderXRateLimitReset carries the Unix timestamp (in seconds) at which
	// the client's bucket will be fully replenished. The limiter is a token
	// bucket rather than a fixed window, so individual requests may be
	// allowed again before this time.
	HeaderXRateLimitReset = "X-RateLimit-Reset"

	// HeaderRetryAfter carries the number of seconds the client should wait
	// before retrying. It is only set on 429 responses.
	HeaderRetryAfter = "Retry-After"
)
|
|
|
|
// AuthRateLimiter returns rate-limiting middleware tuned for authentication
|
|
// endpoints. It uses a custom in-memory token-bucket rate limiter keyed by
|
|
// client IP address and sets standard rate-limit headers on every response.
|
|
//
|
|
// Parameters:
|
|
// - ratePerSecond: sustained request rate (e.g., 10/60.0 for ~10 per minute)
|
|
// - burst: maximum burst size above the sustained rate
|
|
func AuthRateLimiter(ratePerSecond rate.Limit, burst int) echo.MiddlewareFunc {
|
|
store := newRateLimitStore(ratePerSecond, burst, 5*time.Minute)
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
identifier := c.RealIP()
|
|
allowed, remaining, resetAt := store.allow(identifier)
|
|
|
|
// Set rate-limit headers on every response
|
|
h := c.Response().Header()
|
|
h.Set(HeaderXRateLimitLimit, fmt.Sprintf("%d", burst))
|
|
h.Set(HeaderXRateLimitRemaining, fmt.Sprintf("%d", remaining))
|
|
h.Set(HeaderXRateLimitReset, fmt.Sprintf("%d", resetAt.Unix()))
|
|
|
|
if !allowed {
|
|
// Calculate Retry-After in seconds (ceiling to be safe)
|
|
retryAfter := int(math.Ceil(time.Until(resetAt).Seconds()))
|
|
if retryAfter < 1 {
|
|
retryAfter = 1
|
|
}
|
|
h.Set(HeaderRetryAfter, fmt.Sprintf("%d", retryAfter))
|
|
|
|
return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
|
|
Error: "Too many requests. Please try again later.",
|
|
})
|
|
}
|
|
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
// LoginRateLimiter returns rate-limiting middleware for login endpoints.
|
|
// Allows 10 requests per minute with a burst of 5.
|
|
func LoginRateLimiter() echo.MiddlewareFunc {
|
|
// 10 requests per 60 seconds = ~0.167 req/s, burst 5
|
|
return AuthRateLimiter(rate.Limit(10.0/60.0), 5)
|
|
}
|
|
|
|
// RegistrationRateLimiter returns rate-limiting middleware for registration
|
|
// endpoints. Allows 5 requests per minute with a burst of 3.
|
|
func RegistrationRateLimiter() echo.MiddlewareFunc {
|
|
// 5 requests per 60 seconds = ~0.083 req/s, burst 3
|
|
return AuthRateLimiter(rate.Limit(5.0/60.0), 3)
|
|
}
|
|
|
|
// PasswordResetRateLimiter returns rate-limiting middleware for password
|
|
// reset endpoints. Allows 3 requests per minute with a burst of 2.
|
|
func PasswordResetRateLimiter() echo.MiddlewareFunc {
|
|
// 3 requests per 60 seconds = 0.05 req/s, burst 2
|
|
return AuthRateLimiter(rate.Limit(3.0/60.0), 2)
|
|
}
|