package middleware

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/redis/go-redis/v9"
	"github.com/rs/zerolog/log"
	"gorm.io/gorm"

	"github.com/treytartt/casera-api/internal/models"
	"github.com/treytartt/casera-api/internal/services"
)

const (
	// AuthUserKey is the key used to store the authenticated user in the context
	AuthUserKey = "auth_user"
	// AuthTokenKey is the key used to store the token in the context
	AuthTokenKey = "auth_token"
	// TokenCacheTTL is the duration to cache tokens in Redis.
	// NOTE(review): not referenced in this file; presumably consumed by
	// services.CacheService — confirm.
	TokenCacheTTL = 5 * time.Minute
	// TokenCachePrefix is the prefix for token cache keys.
	// NOTE(review): not referenced in this file; presumably consumed by
	// services.CacheService — confirm.
	TokenCachePrefix = "auth_token_"
)

// tokenLogPrefixLen is how many leading bytes of a token are written to logs,
// enough to correlate failures without logging the full credential.
const tokenLogPrefixLen = 8

// AuthMiddleware provides token authentication middleware. It resolves tokens
// through an optional cache first and falls back to the database.
type AuthMiddleware struct {
	db    *gorm.DB
	cache *services.CacheService // may be nil; all cache helpers tolerate that
}

// NewAuthMiddleware creates a new auth middleware instance. cache may be nil,
// in which case every request is resolved against the database.
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
	return &AuthMiddleware{
		db:    db,
		cache: cache,
	}
}

// TokenAuth returns a Gin middleware that requires token authentication.
//
// A missing or malformed Authorization header aborts with 401 and the parse
// error message. A token that does not resolve to an active user aborts with
// 401 "Invalid token". On success the user and token are stored in the Gin
// context under AuthUserKey and AuthTokenKey before the next handler runs.
func (m *AuthMiddleware) TokenAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		token, err := extractToken(c)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error": err.Error(),
			})
			return
		}

		user, err := m.authenticate(c.Request.Context(), token)
		if err != nil {
			// redactToken never slices past the end of a short token, unlike a
			// fixed token[:8], which panicked on tokens shorter than 8 bytes.
			log.Debug().Err(err).Str("token", redactToken(token)).Msg("Token authentication failed")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error": "Invalid token",
			})
			return
		}

		c.Set(AuthUserKey, user)
		c.Set(AuthTokenKey, token)
		c.Next()
	}
}

// OptionalTokenAuth returns middleware that authenticates if a token is
// present but does not require it. Requests with a missing, malformed, or
// unresolvable token continue without a user in the context.
func (m *AuthMiddleware) OptionalTokenAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		token, err := extractToken(c)
		if err != nil {
			// No token or invalid format - continue without user.
			c.Next()
			return
		}

		if user, authErr := m.authenticate(c.Request.Context(), token); authErr == nil {
			c.Set(AuthUserKey, user)
			c.Set(AuthTokenKey, token)
		}
		c.Next()
	}
}

// authenticate resolves token to an active user. It tries the cache first and
// falls back to the database on a miss. After a successful database lookup it
// caches the user ID; a cache write failure is logged but does not fail
// authentication.
func (m *AuthMiddleware) authenticate(ctx context.Context, token string) (*models.User, error) {
	if user, err := m.getUserFromCache(ctx, token); err == nil && user != nil {
		return user, nil
	}

	user, err := m.getUserFromDatabase(token)
	if err != nil {
		return nil, err
	}

	if cacheErr := m.cacheUserID(ctx, token, user.ID); cacheErr != nil {
		log.Warn().Err(cacheErr).Msg("Failed to cache user ID")
	}
	return user, nil
}

// extractToken extracts the token from the Authorization header.
//
// Both "Token xxx" (Django style) and "Bearer xxx" are accepted. The scheme is
// matched case-insensitively, as RFC 7235 specifies, and surrounding
// whitespace is trimmed from the token.
func extractToken(c *gin.Context) (string, error) {
	authHeader := c.GetHeader("Authorization")
	if authHeader == "" {
		return "", errors.New("authorization header required")
	}

	scheme, token, found := strings.Cut(authHeader, " ")
	if !found {
		return "", errors.New("invalid authorization header format")
	}

	if !strings.EqualFold(scheme, "Token") && !strings.EqualFold(scheme, "Bearer") {
		return "", fmt.Errorf("invalid authorization scheme: %s", scheme)
	}

	token = strings.TrimSpace(token)
	if token == "" {
		return "", errors.New("token is empty")
	}

	return token, nil
}

// getUserFromCache resolves token through the cache, then loads the user by ID.
//
// It returns an error when the cache is unavailable, the token is not cached,
// the user no longer exists, or the user is inactive. The cache entry is
// evicted only when it is known to be stale: the user is missing or inactive.
// A transient database error leaves the entry in place.
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
	if m.cache == nil {
		return nil, errors.New("cache not available")
	}

	userID, err := m.cache.GetCachedAuthToken(ctx, token)
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, errors.New("token not in cache")
		}
		return nil, err
	}

	var user models.User
	if err := m.db.WithContext(ctx).First(&user, userID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			// User was deleted - the cached mapping is stale.
			m.evictCachedToken(ctx, token)
		}
		return nil, err
	}

	if !user.IsActive {
		m.evictCachedToken(ctx, token)
		return nil, errors.New("user is inactive")
	}

	return &user, nil
}

// evictCachedToken removes token from the cache on a best-effort basis. A
// failure is logged rather than returned, because the caller is already
// rejecting the token.
func (m *AuthMiddleware) evictCachedToken(ctx context.Context, token string) {
	if err := m.InvalidateToken(ctx, token); err != nil {
		log.Warn().Err(err).Msg("Failed to invalidate cached auth token")
	}
}

// getUserFromDatabase looks up the token in the database and returns its
// associated user, which is loaded via Preload("User").
//
// A missing token yields "token not found". Any other database failure is
// wrapped and returned, so outages are not disguised as bad credentials in
// logs. An inactive user yields "user is inactive".
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) {
	var authToken models.AuthToken
	if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("token not found")
		}
		return nil, fmt.Errorf("looking up token: %w", err)
	}

	if !authToken.User.IsActive {
		return nil, errors.New("user is inactive")
	}

	return &authToken.User, nil
}

// cacheUserID caches the user ID for a token. It is a no-op when no cache is
// configured.
func (m *AuthMiddleware) cacheUserID(ctx context.Context, token string, userID uint) error {
	if m.cache == nil {
		return nil
	}
	return m.cache.CacheAuthToken(ctx, token, userID)
}

// InvalidateToken removes a token from the cache. It is a no-op when no cache
// is configured.
func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) error {
	if m.cache == nil {
		return nil
	}
	return m.cache.InvalidateAuthToken(ctx, token)
}

// GetAuthUser retrieves the authenticated user from the Gin context. It
// returns nil when no user is set, or when the stored value is not a
// *models.User.
func GetAuthUser(c *gin.Context) *models.User {
	value, exists := c.Get(AuthUserKey)
	if !exists {
		return nil
	}
	// The comma-ok assertion avoids a panic if something else was stored
	// under this key.
	user, _ := value.(*models.User)
	return user
}

// GetAuthToken retrieves the auth token from the Gin context. It returns ""
// when no token is set, or when the stored value is not a string.
func GetAuthToken(c *gin.Context) string {
	return c.GetString(AuthTokenKey)
}

// MustGetAuthUser retrieves the authenticated user, or aborts with 401 and
// returns nil when there is none. Callers must return when the result is nil.
func MustGetAuthUser(c *gin.Context) *models.User {
	user := GetAuthUser(c)
	if user == nil {
		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
			"error": "Authentication required",
		})
		return nil
	}
	return user
}

// redactToken returns at most the first tokenLogPrefixLen bytes of token
// followed by "...". It is safe for tokens of any length.
func redactToken(token string) string {
	if len(token) > tokenLogPrefixLen {
		token = token[:tokenLogPrefixLen]
	}
	return token + "..."
}