From 6df27f203b7066ccb4f452798ed28c306e024fb6 Mon Sep 17 00:00:00 2001 From: Trey T Date: Thu, 26 Mar 2026 14:36:48 -0500 Subject: [PATCH] Add rate limit response headers (X-RateLimit-*, Retry-After) Custom rate limiter replacing Echo built-in, with per-IP token bucket. Every response includes X-RateLimit-Limit, Remaining, Reset headers. 429 responses additionally include Retry-After (seconds). CORS updated to expose rate limit headers to mobile clients. 4 unit tests for header behavior and per-IP isolation. --- internal/middleware/rate_limit.go | 148 +++++++++++++++++---- internal/middleware/rate_limit_test.go | 172 +++++++++++++++++++++++++ internal/router/router.go | 2 +- 3 files changed, 294 insertions(+), 28 deletions(-) create mode 100644 internal/middleware/rate_limit_test.go diff --git a/internal/middleware/rate_limit.go b/internal/middleware/rate_limit.go index a46cb5a..ac4bd9e 100644 --- a/internal/middleware/rate_limit.go +++ b/internal/middleware/rate_limit.go @@ -1,49 +1,143 @@ package middleware import ( + "fmt" + "math" "net/http" + "sync" "time" "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" "golang.org/x/time/rate" "github.com/treytartt/honeydue-api/internal/dto/responses" ) +// rateLimitEntry holds a per-client token-bucket limiter and metadata needed +// to compute the standard rate-limit response headers. +type rateLimitEntry struct { + limiter *rate.Limiter + lastSeen time.Time +} + +// rateLimitStore is a thread-safe in-memory store keyed by client identifier +// (typically IP address). Stale entries are lazily evicted. +type rateLimitStore struct { + mu sync.Mutex + entries map[string]*rateLimitEntry + rate rate.Limit + burst int + expiresIn time.Duration +} + +func newRateLimitStore(r rate.Limit, burst int, expiresIn time.Duration) *rateLimitStore { + return &rateLimitStore{ + entries: make(map[string]*rateLimitEntry), + rate: r, + burst: burst, + expiresIn: expiresIn, + } +} + +// allow checks the rate limiter for the given identifier and returns: +// - allowed: whether the request should be permitted +// - remaining: approximate number of tokens left (requests remaining) +// - resetAt: when the bucket will next be full +func (s *rateLimitStore) allow(identifier string) (allowed bool, remaining int, resetAt time.Time) { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + + // Lazy eviction of stale entries + if len(s.entries) > 1000 { + for key, entry := range s.entries { + if now.Sub(entry.lastSeen) > s.expiresIn { + delete(s.entries, key) + } + } + } + + entry, exists := s.entries[identifier] + if !exists { + limiter := rate.NewLimiter(s.rate, s.burst) + entry = &rateLimitEntry{limiter: limiter} + s.entries[identifier] = entry + } + entry.lastSeen = now + + allowed = entry.limiter.Allow() + + // Compute remaining tokens (floor to nearest int, min 0). + // After Allow(), tokens are already decremented if allowed. + tokens := entry.limiter.Tokens() + remaining = int(math.Floor(tokens)) + if remaining < 0 { + remaining = 0 + } + + // Compute when the bucket will be fully replenished. + // tokens needed = burst - current tokens + tokensNeeded := float64(s.burst) - tokens + if tokensNeeded <= 0 { + resetAt = now + } else { + secondsToFull := tokensNeeded / float64(s.rate) + resetAt = now.Add(time.Duration(secondsToFull * float64(time.Second))) + } + + return allowed, remaining, resetAt +} + +// HeaderXRateLimitLimit is the max requests allowed in the window. +const HeaderXRateLimitLimit = "X-RateLimit-Limit" + +// HeaderXRateLimitRemaining is the number of requests remaining. +const HeaderXRateLimitRemaining = "X-RateLimit-Remaining" + +// HeaderXRateLimitReset is the Unix timestamp when the window resets. +const HeaderXRateLimitReset = "X-RateLimit-Reset" + +// HeaderRetryAfter is the seconds until the client can retry (429 only). +const HeaderRetryAfter = "Retry-After" + // AuthRateLimiter returns rate-limiting middleware tuned for authentication -// endpoints. It uses Echo's built-in in-memory rate limiter keyed by client -// IP address. +// endpoints. It uses a custom in-memory token-bucket rate limiter keyed by +// client IP address and sets standard rate-limit headers on every response. // // Parameters: // - ratePerSecond: sustained request rate (e.g., 10/60.0 for ~10 per minute) // - burst: maximum burst size above the sustained rate func AuthRateLimiter(ratePerSecond rate.Limit, burst int) echo.MiddlewareFunc { - store := middleware.NewRateLimiterMemoryStoreWithConfig( - middleware.RateLimiterMemoryStoreConfig{ - Rate: ratePerSecond, - Burst: burst, - ExpiresIn: 5 * time.Minute, - }, - ) + store := newRateLimitStore(ratePerSecond, burst, 5*time.Minute) - return middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{ - Skipper: middleware.DefaultSkipper, - IdentifierExtractor: func(c echo.Context) (string, error) { - return c.RealIP(), nil - }, - Store: store, - DenyHandler: func(c echo.Context, _ string, _ error) error { - return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{ - Error: "Too many requests. Please try again later.", - }) - }, - ErrorHandler: func(c echo.Context, err error) error { - return c.JSON(http.StatusForbidden, responses.ErrorResponse{ - Error: "Unable to process request.", - }) - }, - }) + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + identifier := c.RealIP() + allowed, remaining, resetAt := store.allow(identifier) + + // Set rate-limit headers on every response + h := c.Response().Header() + h.Set(HeaderXRateLimitLimit, fmt.Sprintf("%d", burst)) + h.Set(HeaderXRateLimitRemaining, fmt.Sprintf("%d", remaining)) + h.Set(HeaderXRateLimitReset, fmt.Sprintf("%d", resetAt.Unix())) + + if !allowed { + // Calculate Retry-After in seconds (ceiling to be safe) + retryAfter := int(math.Ceil(time.Until(resetAt).Seconds())) + if retryAfter < 1 { + retryAfter = 1 + } + h.Set(HeaderRetryAfter, fmt.Sprintf("%d", retryAfter)) + + return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{ + Error: "Too many requests. Please try again later.", + }) + } + + return next(c) + } + } } // LoginRateLimiter returns rate-limiting middleware for login endpoints. diff --git a/internal/middleware/rate_limit_test.go b/internal/middleware/rate_limit_test.go new file mode 100644 index 0000000..1b82db9 --- /dev/null +++ b/internal/middleware/rate_limit_test.go @@ -0,0 +1,172 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/time/rate" + + "github.com/treytartt/honeydue-api/internal/dto/responses" +) + +// okHandler returns 200 with a simple body. +func okHandler(c echo.Context) error { + return c.String(http.StatusOK, "ok") +} + +func TestRateLimiter_AllowedRequest_IncludesHeaders(t *testing.T) { + // Create a rate limiter with burst=3 (generous enough for a single request) + mw := AuthRateLimiter(rate.Limit(1.0), 3) + e := echo.New() + handler := mw(okHandler) + + req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + // X-RateLimit-Limit should equal the burst size + assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit)) + + // X-RateLimit-Remaining should be present and a valid integer + remaining := rec.Header().Get(HeaderXRateLimitRemaining) + assert.NotEmpty(t, remaining) + rem, err := strconv.Atoi(remaining) + require.NoError(t, err) + assert.GreaterOrEqual(t, rem, 0) + + // X-RateLimit-Reset should be a Unix timestamp in the future (or now) + resetStr := rec.Header().Get(HeaderXRateLimitReset) + assert.NotEmpty(t, resetStr) + resetTS, err := strconv.ParseInt(resetStr, 10, 64) + require.NoError(t, err) + assert.GreaterOrEqual(t, resetTS, time.Now().Unix()-1) + + // Retry-After should NOT be present on allowed requests + assert.Empty(t, rec.Header().Get(HeaderRetryAfter)) +} + +func TestRateLimiter_DeniedRequest_Returns429WithRetryAfter(t *testing.T) { + // Create a very strict limiter: burst=1, so second request is denied. + mw := AuthRateLimiter(rate.Limit(0.01), 1) // ~1 request per 100s + e := echo.New() + handler := mw(okHandler) + + // First request — should succeed + req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil) + rec1 := httptest.NewRecorder() + c1 := e.NewContext(req1, rec1) + err := handler(c1) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, rec1.Code) + + // Second request — should be rate-limited + req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil) + rec2 := httptest.NewRecorder() + c2 := e.NewContext(req2, rec2) + err = handler(c2) + require.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, rec2.Code) + + // Should have all rate-limit headers + assert.NotEmpty(t, rec2.Header().Get(HeaderXRateLimitLimit)) + assert.Equal(t, "0", rec2.Header().Get(HeaderXRateLimitRemaining)) + assert.NotEmpty(t, rec2.Header().Get(HeaderXRateLimitReset)) + + // Retry-After must be present on 429 responses + retryAfter := rec2.Header().Get(HeaderRetryAfter) + assert.NotEmpty(t, retryAfter) + retrySeconds, err := strconv.Atoi(retryAfter) + require.NoError(t, err) + assert.Greater(t, retrySeconds, 0) + + // Body should be the standard error response + var errResp responses.ErrorResponse + err = json.Unmarshal(rec2.Body.Bytes(), &errResp) + require.NoError(t, err) + assert.Contains(t, errResp.Error, "Too many requests") +} + +func TestRateLimiter_RemainingDecreases(t *testing.T) { + // Burst=3, slow replenish so we can see remaining decrease + mw := AuthRateLimiter(rate.Limit(0.001), 3) + e := echo.New() + handler := mw(okHandler) + + var remainingValues []int + for i := 0; i < 3; i++ { + req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + err := handler(c) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + rem, err := strconv.Atoi(rec.Header().Get(HeaderXRateLimitRemaining)) + require.NoError(t, err) + remainingValues = append(remainingValues, rem) + } + + // Remaining should decrease with each request + assert.Greater(t, remainingValues[0], remainingValues[1]) + assert.Greater(t, remainingValues[1], remainingValues[2]) +} + +func TestRateLimiter_DifferentIPs_IndependentLimits(t *testing.T) { + // Strict limiter: burst=1 + mw := AuthRateLimiter(rate.Limit(0.01), 1) + e := echo.New() + handler := mw(okHandler) + + // Request from IP "1.2.3.4" — should succeed + req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil) + req1.Header.Set("X-Real-Ip", "1.2.3.4") + rec1 := httptest.NewRecorder() + c1 := e.NewContext(req1, rec1) + err := handler(c1) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, rec1.Code) + + // Second request from same IP — should be denied + req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil) + req2.Header.Set("X-Real-Ip", "1.2.3.4") + rec2 := httptest.NewRecorder() + c2 := e.NewContext(req2, rec2) + err = handler(c2) + require.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, rec2.Code) + + // Request from different IP "5.6.7.8" — should still succeed + req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil) + req3.Header.Set("X-Real-Ip", "5.6.7.8") + rec3 := httptest.NewRecorder() + c3 := e.NewContext(req3, rec3) + err = handler(c3) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, rec3.Code) +} + +func TestLoginRateLimiter_ReturnsMiddleware(t *testing.T) { + mw := LoginRateLimiter() + assert.NotNil(t, mw) +} + +func TestRegistrationRateLimiter_ReturnsMiddleware(t *testing.T) { + mw := RegistrationRateLimiter() + assert.NotNil(t, mw) +} + +func TestPasswordResetRateLimiter_ReturnsMiddleware(t *testing.T) { + mw := PasswordResetRateLimiter() + assert.NotNil(t, mw) +} diff --git a/internal/router/router.go b/internal/router/router.go index 441a79e..a73a1bc 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -309,7 +309,7 @@ func corsMiddleware(cfg *config.Config) echo.MiddlewareFunc { AllowOrigins: origins, AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, http.MethodOptions}, AllowHeaders: []string{echo.HeaderOrigin, echo.HeaderContentType, echo.HeaderAccept, echo.HeaderAuthorization, "X-Requested-With", "X-Timezone"}, - ExposeHeaders: []string{echo.HeaderContentLength}, + ExposeHeaders: []string{echo.HeaderContentLength, "X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset", "Retry-After"}, AllowCredentials: false, MaxAge: int((12 * time.Hour).Seconds()), })