package middleware import ( "context" "fmt" "strings" "time" "github.com/labstack/echo/v4" "github.com/redis/go-redis/v9" "github.com/rs/zerolog/log" "gorm.io/gorm" "github.com/treytartt/casera-api/internal/apperrors" "github.com/treytartt/casera-api/internal/models" "github.com/treytartt/casera-api/internal/services" ) const ( // AuthUserKey is the key used to store the authenticated user in the context AuthUserKey = "auth_user" // AuthTokenKey is the key used to store the token in the context AuthTokenKey = "auth_token" // TokenCacheTTL is the duration to cache tokens in Redis TokenCacheTTL = 5 * time.Minute // TokenCachePrefix is the prefix for token cache keys TokenCachePrefix = "auth_token_" ) // AuthMiddleware provides token authentication middleware type AuthMiddleware struct { db *gorm.DB cache *services.CacheService } // NewAuthMiddleware creates a new auth middleware instance func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware { return &AuthMiddleware{ db: db, cache: cache, } } // TokenAuth returns an Echo middleware that validates token authentication func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { // Extract token from Authorization header token, err := extractToken(c) if err != nil { return apperrors.Unauthorized("error.not_authenticated") } // Try to get user from cache first user, err := m.getUserFromCache(c.Request().Context(), token) if err == nil && user != nil { // Cache hit - set user in context and continue c.Set(AuthUserKey, user) c.Set(AuthTokenKey, token) return next(c) } // Cache miss - look up token in database user, err = m.getUserFromDatabase(token) if err != nil { log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed") return apperrors.Unauthorized("error.invalid_token") } // Cache the user ID for future requests if cacheErr := m.cacheUserID(c.Request().Context(), token, user.ID); cacheErr != nil { log.Warn().Err(cacheErr).Msg("Failed to cache user ID") } // Set user in context c.Set(AuthUserKey, user) c.Set(AuthTokenKey, token) return next(c) } } } // OptionalTokenAuth returns middleware that authenticates if token is present but doesn't require it func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { token, err := extractToken(c) if err != nil { // No token or invalid format - continue without user return next(c) } // Try cache first user, err := m.getUserFromCache(c.Request().Context(), token) if err == nil && user != nil { c.Set(AuthUserKey, user) c.Set(AuthTokenKey, token) return next(c) } // Try database user, err = m.getUserFromDatabase(token) if err == nil { m.cacheUserID(c.Request().Context(), token, user.ID) c.Set(AuthUserKey, user) c.Set(AuthTokenKey, token) } return next(c) } } } // extractToken extracts the token from the Authorization header func extractToken(c echo.Context) (string, error) { authHeader := c.Request().Header.Get("Authorization") if authHeader == "" { return "", fmt.Errorf("authorization header required") } // Support both "Token xxx" (Django style) and "Bearer xxx" formats parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 { return "", fmt.Errorf("invalid authorization header format") } scheme := parts[0] token := parts[1] if scheme != "Token" && scheme != "Bearer" { return "", fmt.Errorf("invalid authorization scheme: %s", scheme) } if token == "" { return "", fmt.Errorf("token is empty") } return token, nil } // getUserFromCache tries to get user from Redis cache func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) { if m.cache == nil { return nil, fmt.Errorf("cache not available") } userID, err := m.cache.GetCachedAuthToken(ctx, token) if err != nil { if err == redis.Nil { return nil, fmt.Errorf("token not in cache") } return nil, err } // Get user from database by ID var user models.User if err := m.db.First(&user, userID).Error; err != nil { // User was deleted - invalidate cache m.cache.InvalidateAuthToken(ctx, token) return nil, err } // Check if user is active if !user.IsActive { m.cache.InvalidateAuthToken(ctx, token) return nil, fmt.Errorf("user is inactive") } return &user, nil } // getUserFromDatabase looks up the token in the database func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) { var authToken models.AuthToken if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil { return nil, fmt.Errorf("token not found") } // Check if user is active if !authToken.User.IsActive { return nil, fmt.Errorf("user is inactive") } return &authToken.User, nil } // cacheUserID caches the user ID for a token func (m *AuthMiddleware) cacheUserID(ctx context.Context, token string, userID uint) error { if m.cache == nil { return nil } return m.cache.CacheAuthToken(ctx, token, userID) } // InvalidateToken removes a token from the cache func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) error { if m.cache == nil { return nil } return m.cache.InvalidateAuthToken(ctx, token) } // GetAuthUser retrieves the authenticated user from the Echo context. // Returns nil if the context value is missing or not of the expected type. func GetAuthUser(c echo.Context) *models.User { val := c.Get(AuthUserKey) if val == nil { return nil } user, ok := val.(*models.User) if !ok { return nil } return user } // GetAuthToken retrieves the auth token from the Echo context func GetAuthToken(c echo.Context) string { token := c.Get(AuthTokenKey) if token == nil { return "" } return token.(string) } // MustGetAuthUser retrieves the authenticated user or returns error with 401 func MustGetAuthUser(c echo.Context) (*models.User, error) { user := GetAuthUser(c) if user == nil { return nil, apperrors.Unauthorized("error.not_authenticated") } return user, nil } // truncateToken safely truncates a token string for logging. // Returns at most the first 8 characters followed by "...". func truncateToken(token string) string { if len(token) > 8 { return token[:8] + "..." } return token + "..." }