package middleware

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strconv"
	"testing"
	"time"

	"github.com/labstack/echo/v4"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/time/rate"

	"github.com/treytartt/honeydue-api/internal/dto/responses"
)

// okHandler returns 200 with a simple body.
func okHandler(c echo.Context) error {
	return c.String(http.StatusOK, "ok")
}

// serveAuthRequest sends a POST /api/auth/login/ through handler and returns
// the recorded response. When clientIP is non-empty it is sent in the
// X-Real-Ip header so tests can simulate distinct clients; otherwise the
// request carries only httptest's default RemoteAddr. The handler chain is
// required to return a nil error (the middleware writes 429s itself).
func serveAuthRequest(t *testing.T, e *echo.Echo, handler echo.HandlerFunc, clientIP string) *httptest.ResponseRecorder {
	t.Helper()

	req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
	if clientIP != "" {
		req.Header.Set(echo.HeaderXRealIP, clientIP)
	}
	rec := httptest.NewRecorder()
	require.NoError(t, handler(e.NewContext(req, rec)))
	return rec
}

func TestRateLimiter_AllowedRequest_IncludesHeaders(t *testing.T) {
	// burst=3 is generous enough that a single request is always allowed.
	const burst = 3
	mw := AuthRateLimiter(rate.Limit(1.0), burst)
	e := echo.New()

	rec := serveAuthRequest(t, e, mw(okHandler), "")
	assert.Equal(t, http.StatusOK, rec.Code)

	// X-RateLimit-Limit should equal the burst size.
	assert.Equal(t, strconv.Itoa(burst), rec.Header().Get(HeaderXRateLimitLimit))

	// X-RateLimit-Remaining must be an integer within [0, burst]; a token
	// bucket can never report more tokens than its capacity.
	remaining := rec.Header().Get(HeaderXRateLimitRemaining)
	require.NotEmpty(t, remaining)
	rem, err := strconv.Atoi(remaining)
	require.NoError(t, err)
	assert.GreaterOrEqual(t, rem, 0)
	assert.LessOrEqual(t, rem, burst)

	// X-RateLimit-Reset should be a Unix timestamp in the future (or now).
	// One second of slack covers the clock ticking over during the request.
	resetStr := rec.Header().Get(HeaderXRateLimitReset)
	require.NotEmpty(t, resetStr)
	resetTS, err := strconv.ParseInt(resetStr, 10, 64)
	require.NoError(t, err)
	assert.GreaterOrEqual(t, resetTS, time.Now().Unix()-1)

	// Retry-After is reserved for denied requests.
	assert.Empty(t, rec.Header().Get(HeaderRetryAfter))
}

func TestRateLimiter_DeniedRequest_Returns429WithRetryAfter(t *testing.T) {
	// burst=1 with ~1 token per 100s: the first request spends the only
	// token, so the second is denied.
	mw := AuthRateLimiter(rate.Limit(0.01), 1)
	e := echo.New()

	// Count invocations of the wrapped handler to prove that a denied
	// request is short-circuited by the middleware.
	calls := 0
	handler := mw(func(c echo.Context) error {
		calls++
		return okHandler(c)
	})

	// First request — should succeed.
	first := serveAuthRequest(t, e, handler, "")
	assert.Equal(t, http.StatusOK, first.Code)

	// Second request — should be rate-limited and never reach the handler.
	denied := serveAuthRequest(t, e, handler, "")
	assert.Equal(t, http.StatusTooManyRequests, denied.Code)
	assert.Equal(t, 1, calls, "rate-limited request must not reach the wrapped handler")

	// All rate-limit headers are present; the limit mirrors the burst size.
	assert.Equal(t, "1", denied.Header().Get(HeaderXRateLimitLimit))
	assert.Equal(t, "0", denied.Header().Get(HeaderXRateLimitRemaining))
	assert.NotEmpty(t, denied.Header().Get(HeaderXRateLimitReset))

	// Retry-After must be present on 429 responses and be a positive
	// number of seconds.
	retryAfter := denied.Header().Get(HeaderRetryAfter)
	require.NotEmpty(t, retryAfter)
	retrySeconds, err := strconv.Atoi(retryAfter)
	require.NoError(t, err)
	assert.Greater(t, retrySeconds, 0)

	// Body should be the standard error response.
	var errResp responses.ErrorResponse
	require.NoError(t, json.Unmarshal(denied.Body.Bytes(), &errResp))
	assert.Contains(t, errResp.Error, "Too many requests")
}

func TestRateLimiter_RemainingDecreases(t *testing.T) {
	// burst=3 with a very slow refill so each request visibly spends a token.
	const burst = 3
	mw := AuthRateLimiter(rate.Limit(0.001), burst)
	e := echo.New()
	handler := mw(okHandler)

	remainingValues := make([]int, 0, burst)
	for i := 0; i < burst; i++ {
		rec := serveAuthRequest(t, e, handler, "")
		assert.Equal(t, http.StatusOK, rec.Code)

		rem, err := strconv.Atoi(rec.Header().Get(HeaderXRateLimitRemaining))
		require.NoError(t, err)
		remainingValues = append(remainingValues, rem)
	}

	// Remaining should strictly decrease with each request.
	for i := 1; i < len(remainingValues); i++ {
		assert.Greater(t, remainingValues[i-1], remainingValues[i], "request %d", i+1)
	}
}

func TestRateLimiter_DifferentIPs_IndependentLimits(t *testing.T) {
	// Strict limiter: burst=1, so each client gets exactly one request.
	mw := AuthRateLimiter(rate.Limit(0.01), 1)
	e := echo.New()
	handler := mw(okHandler)

	// First request from 1.2.3.4 — should succeed.
	assert.Equal(t, http.StatusOK, serveAuthRequest(t, e, handler, "1.2.3.4").Code)

	// Second request from the same IP — should be denied.
	assert.Equal(t, http.StatusTooManyRequests, serveAuthRequest(t, e, handler, "1.2.3.4").Code)

	// A different IP has its own bucket and should still succeed.
	assert.Equal(t, http.StatusOK, serveAuthRequest(t, e, handler, "5.6.7.8").Code)
}

// The preset constructors below are smoke-tested for construction only.

func TestLoginRateLimiter_ReturnsMiddleware(t *testing.T) {
	mw := LoginRateLimiter()
	assert.NotNil(t, mw)
}

func TestRegistrationRateLimiter_ReturnsMiddleware(t *testing.T) {
	mw := RegistrationRateLimiter()
	assert.NotNil(t, mw)
}

func TestPasswordResetRateLimiter_ReturnsMiddleware(t *testing.T) {
	mw := PasswordResetRateLimiter()
	assert.NotNil(t, mw)
}