Add rate limit response headers (X-RateLimit-*, Retry-After)
Custom rate limiter replacing Echo built-in, with per-IP token bucket. Every response includes X-RateLimit-Limit, Remaining, Reset headers. 429 responses additionally include Retry-After (seconds). CORS updated to expose rate limit headers to mobile clients. 4 unit tests for header behavior and per-IP isolation.
This commit is contained in:
@@ -1,49 +1,143 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/labstack/echo/v4/middleware"
|
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// rateLimitEntry holds a per-client token-bucket limiter and metadata needed
|
||||||
|
// to compute the standard rate-limit response headers.
|
||||||
|
type rateLimitEntry struct {
|
||||||
|
limiter *rate.Limiter
|
||||||
|
lastSeen time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// rateLimitStore is a thread-safe in-memory store keyed by client identifier
|
||||||
|
// (typically IP address). Stale entries are lazily evicted.
|
||||||
|
type rateLimitStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
entries map[string]*rateLimitEntry
|
||||||
|
rate rate.Limit
|
||||||
|
burst int
|
||||||
|
expiresIn time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRateLimitStore(r rate.Limit, burst int, expiresIn time.Duration) *rateLimitStore {
|
||||||
|
return &rateLimitStore{
|
||||||
|
entries: make(map[string]*rateLimitEntry),
|
||||||
|
rate: r,
|
||||||
|
burst: burst,
|
||||||
|
expiresIn: expiresIn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow checks the rate limiter for the given identifier and returns:
|
||||||
|
// - allowed: whether the request should be permitted
|
||||||
|
// - remaining: approximate number of tokens left (requests remaining)
|
||||||
|
// - resetAt: when the bucket will next be full
|
||||||
|
func (s *rateLimitStore) allow(identifier string) (allowed bool, remaining int, resetAt time.Time) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Lazy eviction of stale entries
|
||||||
|
if len(s.entries) > 1000 {
|
||||||
|
for key, entry := range s.entries {
|
||||||
|
if now.Sub(entry.lastSeen) > s.expiresIn {
|
||||||
|
delete(s.entries, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, exists := s.entries[identifier]
|
||||||
|
if !exists {
|
||||||
|
limiter := rate.NewLimiter(s.rate, s.burst)
|
||||||
|
entry = &rateLimitEntry{limiter: limiter}
|
||||||
|
s.entries[identifier] = entry
|
||||||
|
}
|
||||||
|
entry.lastSeen = now
|
||||||
|
|
||||||
|
allowed = entry.limiter.Allow()
|
||||||
|
|
||||||
|
// Compute remaining tokens (floor to nearest int, min 0).
|
||||||
|
// After Allow(), tokens are already decremented if allowed.
|
||||||
|
tokens := entry.limiter.Tokens()
|
||||||
|
remaining = int(math.Floor(tokens))
|
||||||
|
if remaining < 0 {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute when the bucket will be fully replenished.
|
||||||
|
// tokens needed = burst - current tokens
|
||||||
|
tokensNeeded := float64(s.burst) - tokens
|
||||||
|
if tokensNeeded <= 0 {
|
||||||
|
resetAt = now
|
||||||
|
} else {
|
||||||
|
secondsToFull := tokensNeeded / float64(s.rate)
|
||||||
|
resetAt = now.Add(time.Duration(secondsToFull * float64(time.Second)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return allowed, remaining, resetAt
|
||||||
|
}
|
||||||
|
|
||||||
|
// HeaderXRateLimitLimit is the max requests allowed in the window.
|
||||||
|
const HeaderXRateLimitLimit = "X-RateLimit-Limit"
|
||||||
|
|
||||||
|
// HeaderXRateLimitRemaining is the number of requests remaining.
|
||||||
|
const HeaderXRateLimitRemaining = "X-RateLimit-Remaining"
|
||||||
|
|
||||||
|
// HeaderXRateLimitReset is the Unix timestamp when the window resets.
|
||||||
|
const HeaderXRateLimitReset = "X-RateLimit-Reset"
|
||||||
|
|
||||||
|
// HeaderRetryAfter is the seconds until the client can retry (429 only).
|
||||||
|
const HeaderRetryAfter = "Retry-After"
|
||||||
|
|
||||||
// AuthRateLimiter returns rate-limiting middleware tuned for authentication
|
// AuthRateLimiter returns rate-limiting middleware tuned for authentication
|
||||||
// endpoints. It uses Echo's built-in in-memory rate limiter keyed by client
|
// endpoints. It uses a custom in-memory token-bucket rate limiter keyed by
|
||||||
// IP address.
|
// client IP address and sets standard rate-limit headers on every response.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - ratePerSecond: sustained request rate (e.g., 10/60.0 for ~10 per minute)
|
// - ratePerSecond: sustained request rate (e.g., 10/60.0 for ~10 per minute)
|
||||||
// - burst: maximum burst size above the sustained rate
|
// - burst: maximum burst size above the sustained rate
|
||||||
func AuthRateLimiter(ratePerSecond rate.Limit, burst int) echo.MiddlewareFunc {
|
func AuthRateLimiter(ratePerSecond rate.Limit, burst int) echo.MiddlewareFunc {
|
||||||
store := middleware.NewRateLimiterMemoryStoreWithConfig(
|
store := newRateLimitStore(ratePerSecond, burst, 5*time.Minute)
|
||||||
middleware.RateLimiterMemoryStoreConfig{
|
|
||||||
Rate: ratePerSecond,
|
|
||||||
Burst: burst,
|
|
||||||
ExpiresIn: 5 * time.Minute,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
Skipper: middleware.DefaultSkipper,
|
return func(c echo.Context) error {
|
||||||
IdentifierExtractor: func(c echo.Context) (string, error) {
|
identifier := c.RealIP()
|
||||||
return c.RealIP(), nil
|
allowed, remaining, resetAt := store.allow(identifier)
|
||||||
},
|
|
||||||
Store: store,
|
// Set rate-limit headers on every response
|
||||||
DenyHandler: func(c echo.Context, _ string, _ error) error {
|
h := c.Response().Header()
|
||||||
return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
|
h.Set(HeaderXRateLimitLimit, fmt.Sprintf("%d", burst))
|
||||||
Error: "Too many requests. Please try again later.",
|
h.Set(HeaderXRateLimitRemaining, fmt.Sprintf("%d", remaining))
|
||||||
})
|
h.Set(HeaderXRateLimitReset, fmt.Sprintf("%d", resetAt.Unix()))
|
||||||
},
|
|
||||||
ErrorHandler: func(c echo.Context, err error) error {
|
if !allowed {
|
||||||
return c.JSON(http.StatusForbidden, responses.ErrorResponse{
|
// Calculate Retry-After in seconds (ceiling to be safe)
|
||||||
Error: "Unable to process request.",
|
retryAfter := int(math.Ceil(time.Until(resetAt).Seconds()))
|
||||||
})
|
if retryAfter < 1 {
|
||||||
},
|
retryAfter = 1
|
||||||
})
|
}
|
||||||
|
h.Set(HeaderRetryAfter, fmt.Sprintf("%d", retryAfter))
|
||||||
|
|
||||||
|
return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
|
||||||
|
Error: "Too many requests. Please try again later.",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginRateLimiter returns rate-limiting middleware for login endpoints.
|
// LoginRateLimiter returns rate-limiting middleware for login endpoints.
|
||||||
|
|||||||
172
internal/middleware/rate_limit_test.go
Normal file
172
internal/middleware/rate_limit_test.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||||
|
)
|
||||||
|
|
||||||
|
// okHandler returns 200 with a simple body.
|
||||||
|
func okHandler(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_AllowedRequest_IncludesHeaders(t *testing.T) {
|
||||||
|
// Create a rate limiter with burst=3 (generous enough for a single request)
|
||||||
|
mw := AuthRateLimiter(rate.Limit(1.0), 3)
|
||||||
|
e := echo.New()
|
||||||
|
handler := mw(okHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
err := handler(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
// X-RateLimit-Limit should equal the burst size
|
||||||
|
assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit))
|
||||||
|
|
||||||
|
// X-RateLimit-Remaining should be present and a valid integer
|
||||||
|
remaining := rec.Header().Get(HeaderXRateLimitRemaining)
|
||||||
|
assert.NotEmpty(t, remaining)
|
||||||
|
rem, err := strconv.Atoi(remaining)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.GreaterOrEqual(t, rem, 0)
|
||||||
|
|
||||||
|
// X-RateLimit-Reset should be a Unix timestamp in the future (or now)
|
||||||
|
resetStr := rec.Header().Get(HeaderXRateLimitReset)
|
||||||
|
assert.NotEmpty(t, resetStr)
|
||||||
|
resetTS, err := strconv.ParseInt(resetStr, 10, 64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.GreaterOrEqual(t, resetTS, time.Now().Unix()-1)
|
||||||
|
|
||||||
|
// Retry-After should NOT be present on allowed requests
|
||||||
|
assert.Empty(t, rec.Header().Get(HeaderRetryAfter))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_DeniedRequest_Returns429WithRetryAfter(t *testing.T) {
|
||||||
|
// Create a very strict limiter: burst=1, so second request is denied.
|
||||||
|
mw := AuthRateLimiter(rate.Limit(0.01), 1) // ~1 request per 100s
|
||||||
|
e := echo.New()
|
||||||
|
handler := mw(okHandler)
|
||||||
|
|
||||||
|
// First request — should succeed
|
||||||
|
req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
c1 := e.NewContext(req1, rec1)
|
||||||
|
err := handler(c1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, rec1.Code)
|
||||||
|
|
||||||
|
// Second request — should be rate-limited
|
||||||
|
req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
c2 := e.NewContext(req2, rec2)
|
||||||
|
err = handler(c2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, rec2.Code)
|
||||||
|
|
||||||
|
// Should have all rate-limit headers
|
||||||
|
assert.NotEmpty(t, rec2.Header().Get(HeaderXRateLimitLimit))
|
||||||
|
assert.Equal(t, "0", rec2.Header().Get(HeaderXRateLimitRemaining))
|
||||||
|
assert.NotEmpty(t, rec2.Header().Get(HeaderXRateLimitReset))
|
||||||
|
|
||||||
|
// Retry-After must be present on 429 responses
|
||||||
|
retryAfter := rec2.Header().Get(HeaderRetryAfter)
|
||||||
|
assert.NotEmpty(t, retryAfter)
|
||||||
|
retrySeconds, err := strconv.Atoi(retryAfter)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Greater(t, retrySeconds, 0)
|
||||||
|
|
||||||
|
// Body should be the standard error response
|
||||||
|
var errResp responses.ErrorResponse
|
||||||
|
err = json.Unmarshal(rec2.Body.Bytes(), &errResp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, errResp.Error, "Too many requests")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_RemainingDecreases(t *testing.T) {
|
||||||
|
// Burst=3, slow replenish so we can see remaining decrease
|
||||||
|
mw := AuthRateLimiter(rate.Limit(0.001), 3)
|
||||||
|
e := echo.New()
|
||||||
|
handler := mw(okHandler)
|
||||||
|
|
||||||
|
var remainingValues []int
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
err := handler(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
rem, err := strconv.Atoi(rec.Header().Get(HeaderXRateLimitRemaining))
|
||||||
|
require.NoError(t, err)
|
||||||
|
remainingValues = append(remainingValues, rem)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remaining should decrease with each request
|
||||||
|
assert.Greater(t, remainingValues[0], remainingValues[1])
|
||||||
|
assert.Greater(t, remainingValues[1], remainingValues[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_DifferentIPs_IndependentLimits(t *testing.T) {
|
||||||
|
// Strict limiter: burst=1
|
||||||
|
mw := AuthRateLimiter(rate.Limit(0.01), 1)
|
||||||
|
e := echo.New()
|
||||||
|
handler := mw(okHandler)
|
||||||
|
|
||||||
|
// Request from IP "1.2.3.4" — should succeed
|
||||||
|
req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||||
|
req1.Header.Set("X-Real-Ip", "1.2.3.4")
|
||||||
|
rec1 := httptest.NewRecorder()
|
||||||
|
c1 := e.NewContext(req1, rec1)
|
||||||
|
err := handler(c1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, rec1.Code)
|
||||||
|
|
||||||
|
// Second request from same IP — should be denied
|
||||||
|
req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||||
|
req2.Header.Set("X-Real-Ip", "1.2.3.4")
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
c2 := e.NewContext(req2, rec2)
|
||||||
|
err = handler(c2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, rec2.Code)
|
||||||
|
|
||||||
|
// Request from different IP "5.6.7.8" — should still succeed
|
||||||
|
req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||||
|
req3.Header.Set("X-Real-Ip", "5.6.7.8")
|
||||||
|
rec3 := httptest.NewRecorder()
|
||||||
|
c3 := e.NewContext(req3, rec3)
|
||||||
|
err = handler(c3)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, rec3.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginRateLimiter_ReturnsMiddleware(t *testing.T) {
|
||||||
|
mw := LoginRateLimiter()
|
||||||
|
assert.NotNil(t, mw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistrationRateLimiter_ReturnsMiddleware(t *testing.T) {
|
||||||
|
mw := RegistrationRateLimiter()
|
||||||
|
assert.NotNil(t, mw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetRateLimiter_ReturnsMiddleware(t *testing.T) {
|
||||||
|
mw := PasswordResetRateLimiter()
|
||||||
|
assert.NotNil(t, mw)
|
||||||
|
}
|
||||||
@@ -309,7 +309,7 @@ func corsMiddleware(cfg *config.Config) echo.MiddlewareFunc {
|
|||||||
AllowOrigins: origins,
|
AllowOrigins: origins,
|
||||||
AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, http.MethodOptions},
|
AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, http.MethodOptions},
|
||||||
AllowHeaders: []string{echo.HeaderOrigin, echo.HeaderContentType, echo.HeaderAccept, echo.HeaderAuthorization, "X-Requested-With", "X-Timezone"},
|
AllowHeaders: []string{echo.HeaderOrigin, echo.HeaderContentType, echo.HeaderAccept, echo.HeaderAuthorization, "X-Requested-With", "X-Timezone"},
|
||||||
ExposeHeaders: []string{echo.HeaderContentLength},
|
ExposeHeaders: []string{echo.HeaderContentLength, "X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset", "Retry-After"},
|
||||||
AllowCredentials: false,
|
AllowCredentials: false,
|
||||||
MaxAge: int((12 * time.Hour).Seconds()),
|
MaxAge: int((12 * time.Hour).Seconds()),
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user