Custom rate limiter that replaces Echo's built-in one and uses a per-IP token bucket. Every response includes the X-RateLimit-Limit, X-RateLimit-Remaining, and X-RateLimit-Reset headers; 429 responses additionally include Retry-After (in seconds). CORS is updated to expose the rate-limit headers to mobile clients. Includes 4 unit tests covering header behavior and per-IP isolation.
173 lines
5.3 KiB
Go
173 lines
5.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/time/rate"
|
|
|
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
|
)
|
|
|
|
// okHandler returns 200 with a simple body.
|
|
func okHandler(c echo.Context) error {
|
|
return c.String(http.StatusOK, "ok")
|
|
}
|
|
|
|
func TestRateLimiter_AllowedRequest_IncludesHeaders(t *testing.T) {
|
|
// Create a rate limiter with burst=3 (generous enough for a single request)
|
|
mw := AuthRateLimiter(rate.Limit(1.0), 3)
|
|
e := echo.New()
|
|
handler := mw(okHandler)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
|
|
err := handler(c)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
|
// X-RateLimit-Limit should equal the burst size
|
|
assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit))
|
|
|
|
// X-RateLimit-Remaining should be present and a valid integer
|
|
remaining := rec.Header().Get(HeaderXRateLimitRemaining)
|
|
assert.NotEmpty(t, remaining)
|
|
rem, err := strconv.Atoi(remaining)
|
|
require.NoError(t, err)
|
|
assert.GreaterOrEqual(t, rem, 0)
|
|
|
|
// X-RateLimit-Reset should be a Unix timestamp in the future (or now)
|
|
resetStr := rec.Header().Get(HeaderXRateLimitReset)
|
|
assert.NotEmpty(t, resetStr)
|
|
resetTS, err := strconv.ParseInt(resetStr, 10, 64)
|
|
require.NoError(t, err)
|
|
assert.GreaterOrEqual(t, resetTS, time.Now().Unix()-1)
|
|
|
|
// Retry-After should NOT be present on allowed requests
|
|
assert.Empty(t, rec.Header().Get(HeaderRetryAfter))
|
|
}
|
|
|
|
func TestRateLimiter_DeniedRequest_Returns429WithRetryAfter(t *testing.T) {
|
|
// Create a very strict limiter: burst=1, so second request is denied.
|
|
mw := AuthRateLimiter(rate.Limit(0.01), 1) // ~1 request per 100s
|
|
e := echo.New()
|
|
handler := mw(okHandler)
|
|
|
|
// First request — should succeed
|
|
req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
|
rec1 := httptest.NewRecorder()
|
|
c1 := e.NewContext(req1, rec1)
|
|
err := handler(c1)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusOK, rec1.Code)
|
|
|
|
// Second request — should be rate-limited
|
|
req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
|
rec2 := httptest.NewRecorder()
|
|
c2 := e.NewContext(req2, rec2)
|
|
err = handler(c2)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusTooManyRequests, rec2.Code)
|
|
|
|
// Should have all rate-limit headers
|
|
assert.NotEmpty(t, rec2.Header().Get(HeaderXRateLimitLimit))
|
|
assert.Equal(t, "0", rec2.Header().Get(HeaderXRateLimitRemaining))
|
|
assert.NotEmpty(t, rec2.Header().Get(HeaderXRateLimitReset))
|
|
|
|
// Retry-After must be present on 429 responses
|
|
retryAfter := rec2.Header().Get(HeaderRetryAfter)
|
|
assert.NotEmpty(t, retryAfter)
|
|
retrySeconds, err := strconv.Atoi(retryAfter)
|
|
require.NoError(t, err)
|
|
assert.Greater(t, retrySeconds, 0)
|
|
|
|
// Body should be the standard error response
|
|
var errResp responses.ErrorResponse
|
|
err = json.Unmarshal(rec2.Body.Bytes(), &errResp)
|
|
require.NoError(t, err)
|
|
assert.Contains(t, errResp.Error, "Too many requests")
|
|
}
|
|
|
|
func TestRateLimiter_RemainingDecreases(t *testing.T) {
|
|
// Burst=3, slow replenish so we can see remaining decrease
|
|
mw := AuthRateLimiter(rate.Limit(0.001), 3)
|
|
e := echo.New()
|
|
handler := mw(okHandler)
|
|
|
|
var remainingValues []int
|
|
for i := 0; i < 3; i++ {
|
|
req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
err := handler(c)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
|
rem, err := strconv.Atoi(rec.Header().Get(HeaderXRateLimitRemaining))
|
|
require.NoError(t, err)
|
|
remainingValues = append(remainingValues, rem)
|
|
}
|
|
|
|
// Remaining should decrease with each request
|
|
assert.Greater(t, remainingValues[0], remainingValues[1])
|
|
assert.Greater(t, remainingValues[1], remainingValues[2])
|
|
}
|
|
|
|
func TestRateLimiter_DifferentIPs_IndependentLimits(t *testing.T) {
|
|
// Strict limiter: burst=1
|
|
mw := AuthRateLimiter(rate.Limit(0.01), 1)
|
|
e := echo.New()
|
|
handler := mw(okHandler)
|
|
|
|
// Request from IP "1.2.3.4" — should succeed
|
|
req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
|
req1.Header.Set("X-Real-Ip", "1.2.3.4")
|
|
rec1 := httptest.NewRecorder()
|
|
c1 := e.NewContext(req1, rec1)
|
|
err := handler(c1)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusOK, rec1.Code)
|
|
|
|
// Second request from same IP — should be denied
|
|
req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
|
req2.Header.Set("X-Real-Ip", "1.2.3.4")
|
|
rec2 := httptest.NewRecorder()
|
|
c2 := e.NewContext(req2, rec2)
|
|
err = handler(c2)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusTooManyRequests, rec2.Code)
|
|
|
|
// Request from different IP "5.6.7.8" — should still succeed
|
|
req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
|
req3.Header.Set("X-Real-Ip", "5.6.7.8")
|
|
rec3 := httptest.NewRecorder()
|
|
c3 := e.NewContext(req3, rec3)
|
|
err = handler(c3)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, http.StatusOK, rec3.Code)
|
|
}
|
|
|
|
func TestLoginRateLimiter_ReturnsMiddleware(t *testing.T) {
|
|
mw := LoginRateLimiter()
|
|
assert.NotNil(t, mw)
|
|
}
|
|
|
|
func TestRegistrationRateLimiter_ReturnsMiddleware(t *testing.T) {
|
|
mw := RegistrationRateLimiter()
|
|
assert.NotNil(t, mw)
|
|
}
|
|
|
|
func TestPasswordResetRateLimiter_ReturnsMiddleware(t *testing.T) {
|
|
mw := PasswordResetRateLimiter()
|
|
assert.NotNil(t, mw)
|
|
}
|