package middleware import ( "context" "fmt" "strings" "time" "github.com/labstack/echo/v4" "github.com/redis/go-redis/v9" "github.com/rs/zerolog/log" "gorm.io/gorm" "github.com/treytartt/honeydue-api/internal/apperrors" "github.com/treytartt/honeydue-api/internal/config" "github.com/treytartt/honeydue-api/internal/models" "github.com/treytartt/honeydue-api/internal/services" ) const ( // AuthUserKey is the key used to store the authenticated user in the context AuthUserKey = "auth_user" // AuthTokenKey is the key used to store the token in the context AuthTokenKey = "auth_token" // TokenCacheTTL is the duration to cache tokens in Redis TokenCacheTTL = 5 * time.Minute // TokenCachePrefix is the prefix for token cache keys TokenCachePrefix = "auth_token_" // UserCacheTTL is how long full user records are cached in memory to // avoid hitting the database on every authenticated request. UserCacheTTL = 30 * time.Second // DefaultTokenExpiryDays is the default number of days before a token expires. DefaultTokenExpiryDays = 90 ) // AuthMiddleware provides token authentication middleware type AuthMiddleware struct { db *gorm.DB cache *services.CacheService userCache *UserCache tokenExpiryDays int } // NewAuthMiddleware creates a new auth middleware instance func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware { return &AuthMiddleware{ db: db, cache: cache, userCache: NewUserCache(UserCacheTTL), tokenExpiryDays: DefaultTokenExpiryDays, } } // NewAuthMiddlewareWithConfig creates a new auth middleware instance with configuration func NewAuthMiddlewareWithConfig(db *gorm.DB, cache *services.CacheService, cfg *config.Config) *AuthMiddleware { expiryDays := DefaultTokenExpiryDays if cfg != nil && cfg.Security.TokenExpiryDays > 0 { expiryDays = cfg.Security.TokenExpiryDays } return &AuthMiddleware{ db: db, cache: cache, userCache: NewUserCache(UserCacheTTL), tokenExpiryDays: expiryDays, } } // TokenExpiryDuration returns the token expiry duration. func (m *AuthMiddleware) TokenExpiryDuration() time.Duration { return time.Duration(m.tokenExpiryDays) * 24 * time.Hour } // isTokenExpired checks if a token's created timestamp indicates expiry. func (m *AuthMiddleware) isTokenExpired(created time.Time) bool { if created.IsZero() { return false // Legacy tokens without created time are not expired } return time.Since(created) > m.TokenExpiryDuration() } // TokenAuth returns an Echo middleware that validates token authentication func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { // Extract token from Authorization header token, err := extractToken(c) if err != nil { return apperrors.Unauthorized("error.not_authenticated") } // Try to get user from cache first (includes expiry check) user, err := m.getUserFromCache(c.Request().Context(), token) if err == nil && user != nil { // Cache hit - set user in context and continue c.Set(AuthUserKey, user) c.Set(AuthTokenKey, token) return next(c) } // Check if the cache indicated token expiry if err != nil && err.Error() == "token expired" { return apperrors.Unauthorized("error.token_expired") } // Cache miss - look up token in database user, authToken, err := m.getUserFromDatabaseWithToken(token) if err != nil { log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed") return apperrors.Unauthorized("error.invalid_token") } // Check token expiry if m.isTokenExpired(authToken.Created) { log.Debug().Str("token", truncateToken(token)).Time("created", authToken.Created).Msg("Token expired") return apperrors.Unauthorized("error.token_expired") } // Cache the user ID and token creation time for future requests if cacheErr := m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created); cacheErr != nil { log.Warn().Err(cacheErr).Msg("Failed to cache token info") } // Set user in context c.Set(AuthUserKey, user) c.Set(AuthTokenKey, token) return next(c) } } } // OptionalTokenAuth returns middleware that authenticates if token is present but doesn't require it func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { token, err := extractToken(c) if err != nil { // No token or invalid format - continue without user return next(c) } // Try cache first user, err := m.getUserFromCache(c.Request().Context(), token) if err == nil && user != nil { c.Set(AuthUserKey, user) c.Set(AuthTokenKey, token) return next(c) } // Try database user, authToken, err := m.getUserFromDatabaseWithToken(token) if err == nil && !m.isTokenExpired(authToken.Created) { m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created) c.Set(AuthUserKey, user) c.Set(AuthTokenKey, token) } return next(c) } } } // extractToken extracts the token from the Authorization header func extractToken(c echo.Context) (string, error) { authHeader := c.Request().Header.Get("Authorization") if authHeader == "" { return "", fmt.Errorf("authorization header required") } // Support both "Token xxx" (Django style) and "Bearer xxx" formats parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 { return "", fmt.Errorf("invalid authorization header format") } scheme := parts[0] token := parts[1] if scheme != "Token" && scheme != "Bearer" { return "", fmt.Errorf("invalid authorization scheme: %s", scheme) } if token == "" { return "", fmt.Errorf("token is empty") } return token, nil } // getUserFromCache tries to get user from Redis cache, then from the // in-memory user cache, before falling back to the database. // Returns a "token expired" error if the cached creation time indicates expiry. func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) { if m.cache == nil { return nil, fmt.Errorf("cache not available") } userID, createdUnix, err := m.cache.GetCachedAuthTokenWithCreated(ctx, token) if err != nil { if err == redis.Nil { return nil, fmt.Errorf("token not in cache") } return nil, err } // Check token expiry from cached creation time if createdUnix > 0 { created := time.Unix(createdUnix, 0) if m.isTokenExpired(created) { m.cache.InvalidateAuthToken(ctx, token) return nil, fmt.Errorf("token expired") } } // Try in-memory user cache first to avoid a DB round-trip if cached := m.userCache.Get(userID); cached != nil { if !cached.IsActive { m.cache.InvalidateAuthToken(ctx, token) m.userCache.Invalidate(userID) return nil, fmt.Errorf("user is inactive") } return cached, nil } // In-memory cache miss — fetch from database var user models.User if err := m.db.First(&user, userID).Error; err != nil { // User was deleted - invalidate caches m.cache.InvalidateAuthToken(ctx, token) return nil, err } // Check if user is active if !user.IsActive { m.cache.InvalidateAuthToken(ctx, token) return nil, fmt.Errorf("user is inactive") } // Store in in-memory cache for subsequent requests m.userCache.Set(&user) return &user, nil } // getUserFromDatabaseWithToken looks up the token in the database and returns // both the user and the auth token record (for expiry checking). func (m *AuthMiddleware) getUserFromDatabaseWithToken(token string) (*models.User, *models.AuthToken, error) { var authToken models.AuthToken if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil { return nil, nil, fmt.Errorf("token not found") } // Check if user is active if !authToken.User.IsActive { return nil, nil, fmt.Errorf("user is inactive") } // Store in in-memory cache for subsequent requests m.userCache.Set(&authToken.User) return &authToken.User, &authToken, nil } // getUserFromDatabase looks up the token in the database and caches the // resulting user record in memory. // Deprecated: Use getUserFromDatabaseWithToken for new code paths that need expiry checking. func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) { user, _, err := m.getUserFromDatabaseWithToken(token) return user, err } // cacheTokenInfo caches the user ID and token creation time for a token func (m *AuthMiddleware) cacheTokenInfo(ctx context.Context, token string, userID uint, created time.Time) error { if m.cache == nil { return nil } return m.cache.CacheAuthTokenWithCreated(ctx, token, userID, created.Unix()) } // cacheUserID caches the user ID for a token func (m *AuthMiddleware) cacheUserID(ctx context.Context, token string, userID uint) error { if m.cache == nil { return nil } return m.cache.CacheAuthToken(ctx, token, userID) } // InvalidateToken removes a token from the cache func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) error { if m.cache == nil { return nil } return m.cache.InvalidateAuthToken(ctx, token) } // GetAuthUser retrieves the authenticated user from the Echo context. // Returns nil if the context value is missing or not of the expected type. func GetAuthUser(c echo.Context) *models.User { val := c.Get(AuthUserKey) if val == nil { return nil } user, ok := val.(*models.User) if !ok { return nil } return user } // GetAuthToken retrieves the auth token from the Echo context func GetAuthToken(c echo.Context) string { token := c.Get(AuthTokenKey) if token == nil { return "" } tokenStr, ok := token.(string) if !ok { return "" } return tokenStr } // MustGetAuthUser retrieves the authenticated user or returns error with 401 func MustGetAuthUser(c echo.Context) (*models.User, error) { user := GetAuthUser(c) if user == nil { return nil, apperrors.Unauthorized("error.not_authenticated") } return user, nil } // truncateToken safely truncates a token string for logging. // Returns at most the first 8 characters followed by "...". func truncateToken(token string) string { if len(token) > 8 { return token[:8] + "..." } return token + "..." }