package middleware

import (
	"fmt"
	"math"
	"net/http"
	"sync"
	"time"

	"github.com/labstack/echo/v4"
	"golang.org/x/time/rate"

	"github.com/treytartt/honeydue-api/internal/dto/responses"
)

const (
	// HeaderXRateLimitLimit is the max requests allowed in the window.
	HeaderXRateLimitLimit = "X-RateLimit-Limit"
	// HeaderXRateLimitRemaining is the number of requests remaining.
	HeaderXRateLimitRemaining = "X-RateLimit-Remaining"
	// HeaderXRateLimitReset is the Unix timestamp when the window resets.
	HeaderXRateLimitReset = "X-RateLimit-Reset"
	// HeaderRetryAfter is the seconds until the client can retry (429 only).
	HeaderRetryAfter = "Retry-After"
)

const (
	// rateLimitSweepThreshold is the number of tracked clients above which
	// stale entries start being evicted.
	rateLimitSweepThreshold = 1000

	// rateLimitMaxWait caps every computed wait. It is reported when the
	// configured rate is zero or negative (the bucket never refills) and it
	// keeps very small rates from overflowing time.Duration.
	rateLimitMaxWait = 365 * 24 * time.Hour
)

// rateLimitEntry holds a per-client token-bucket limiter and the time the
// client was last seen, which drives eviction.
type rateLimitEntry struct {
	limiter  *rate.Limiter
	lastSeen time.Time
}

// rateLimitStore is a mutex-guarded in-memory store keyed by client
// identifier (AuthRateLimiter keys it by c.RealIP()). Stale entries are
// lazily evicted from inside check.
type rateLimitStore struct {
	mu        sync.Mutex
	entries   map[string]*rateLimitEntry
	rate      rate.Limit
	burst     int
	expiresIn time.Duration
	lastSweep time.Time // when the last eviction pass ran; zero until the first one
}

// rateLimitDecision is the outcome of a single rate-limit check.
type rateLimitDecision struct {
	// allowed reports whether the request should be permitted.
	allowed bool
	// remaining is the approximate number of whole tokens left.
	remaining int
	// resetAt is when the bucket will next be completely full.
	resetAt time.Time
	// retryAt is when the next single token becomes available. It equals the
	// check time when the request was allowed.
	retryAt time.Time
}

// newRateLimitStore builds an empty store whose per-client limiters refill at
// r tokens per second up to burst, and whose entries become eligible for
// eviction after being idle for expiresIn.
func newRateLimitStore(r rate.Limit, burst int, expiresIn time.Duration) *rateLimitStore {
	return &rateLimitStore{
		entries:   make(map[string]*rateLimitEntry),
		rate:      r,
		burst:     burst,
		expiresIn: expiresIn,
	}
}

// allow checks the rate limiter for the given identifier and returns:
//   - allowed: whether the request should be permitted
//   - remaining: approximate number of tokens left (requests remaining)
//   - resetAt: when the bucket will next be full
//
// It is a thin wrapper around check for callers that do not need retryAt.
func (s *rateLimitStore) allow(identifier string) (allowed bool, remaining int, resetAt time.Time) {
	d := s.check(identifier)
	return d.allowed, d.remaining, d.resetAt
}

// check consumes one token for identifier if one is available and reports the
// resulting bucket state. A limiter is created on first sight of identifier.
func (s *rateLimitStore) check(identifier string) rateLimitDecision {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	s.evictStale(now)

	entry, ok := s.entries[identifier]
	if !ok {
		entry = &rateLimitEntry{limiter: rate.NewLimiter(s.rate, s.burst)}
		s.entries[identifier] = entry
	}
	entry.lastSeen = now

	d := rateLimitDecision{
		allowed: entry.limiter.Allow(),
		retryAt: now,
	}

	// Tokens is read after Allow, so a granted request is already deducted.
	tokens := entry.limiter.Tokens()
	if whole := int(math.Floor(tokens)); whole > 0 {
		d.remaining = whole
	}

	d.resetAt = now.Add(s.timeToAccrue(float64(s.burst) - tokens))
	if !d.allowed {
		// A denied client only has to wait for the next whole token, not for
		// the bucket to refill completely.
		d.retryAt = now.Add(s.timeToAccrue(1 - tokens))
	}
	return d
}

// evictStale removes entries idle for longer than expiresIn. It only runs
// once the store holds more than rateLimitSweepThreshold entries, and at most
// once per expiresIn/2, so a store full of live clients is not rescanned
// under the lock on every request. The caller must hold s.mu.
func (s *rateLimitStore) evictStale(now time.Time) {
	if len(s.entries) <= rateLimitSweepThreshold {
		return
	}
	if now.Sub(s.lastSweep) < s.expiresIn/2 {
		return
	}
	s.lastSweep = now

	for key, entry := range s.entries {
		if now.Sub(entry.lastSeen) > s.expiresIn {
			delete(s.entries, key)
		}
	}
}

// timeToAccrue returns how long the bucket needs to gain the given number of
// tokens at the configured rate. The result is zero when nothing is needed
// and is capped at rateLimitMaxWait, so a zero, negative, or very small rate
// never reaches an out-of-range float-to-Duration conversion.
func (s *rateLimitStore) timeToAccrue(tokens float64) time.Duration {
	if math.IsNaN(tokens) || tokens <= 0 {
		return 0
	}
	if s.rate <= 0 {
		return rateLimitMaxWait
	}

	seconds := tokens / float64(s.rate)
	if seconds >= rateLimitMaxWait.Seconds() {
		return rateLimitMaxWait
	}
	return time.Duration(seconds * float64(time.Second))
}

// AuthRateLimiter returns rate-limiting middleware tuned for authentication
// endpoints. It uses a custom in-memory token-bucket rate limiter keyed by
// c.RealIP() and sets the X-RateLimit-* headers on every response. Rejected
// requests get a 429 JSON error plus a Retry-After header giving the whole
// seconds (minimum 1) until the next token is available.
//
// Parameters:
//   - ratePerSecond: sustained request rate (e.g., 10/60.0 for ~10 per minute)
//   - burst: maximum burst size above the sustained rate
func AuthRateLimiter(ratePerSecond rate.Limit, burst int) echo.MiddlewareFunc {
	store := newRateLimitStore(ratePerSecond, burst, 5*time.Minute)
	// The limit header never changes for a given middleware instance.
	limitValue := fmt.Sprintf("%d", burst)

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			d := store.check(c.RealIP())

			// Set rate-limit headers on every response.
			h := c.Response().Header()
			h.Set(HeaderXRateLimitLimit, limitValue)
			h.Set(HeaderXRateLimitRemaining, fmt.Sprintf("%d", d.remaining))
			h.Set(HeaderXRateLimitReset, fmt.Sprintf("%d", d.resetAt.Unix()))

			if d.allowed {
				return next(c)
			}

			// Round up so the client never retries a fraction of a second
			// early, and never advertise less than one second.
			retryAfter := int(math.Ceil(time.Until(d.retryAt).Seconds()))
			if retryAfter < 1 {
				retryAfter = 1
			}
			h.Set(HeaderRetryAfter, fmt.Sprintf("%d", retryAfter))

			return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
				Error: "Too many requests. Please try again later.",
			})
		}
	}
}

// LoginRateLimiter returns rate-limiting middleware for login endpoints.
// Allows 10 requests per minute with a burst of 5.
func LoginRateLimiter() echo.MiddlewareFunc {
	// 10 requests per 60 seconds = ~0.167 req/s, burst 5
	return AuthRateLimiter(rate.Limit(10.0/60.0), 5)
}

// RegistrationRateLimiter returns rate-limiting middleware for registration
// endpoints. Allows 5 requests per minute with a burst of 3.
func RegistrationRateLimiter() echo.MiddlewareFunc {
	// 5 requests per 60 seconds = ~0.083 req/s, burst 3
	return AuthRateLimiter(rate.Limit(5.0/60.0), 3)
}

// PasswordResetRateLimiter returns rate-limiting middleware for password
// reset endpoints. Allows 3 requests per minute with a burst of 2.
func PasswordResetRateLimiter() echo.MiddlewareFunc {
	// 3 requests per 60 seconds = 0.05 req/s, burst 2
	return AuthRateLimiter(rate.Limit(3.0/60.0), 2)
}