package repositories

import (
	"context"
	"errors"
	"strings"
	"time"

	"gorm.io/gorm"

	"github.com/treytartt/honeydue-api/internal/models"
)

// FindByKratosID finds a user by Kratos identity UUID.
//
// It returns ErrUserNotFound when no row has the given kratos_id, and the raw
// database error for any other failure. An empty kratosID is rejected up front
// with ErrUserNotFound: an empty string cannot identify a Kratos identity, and
// querying with it could match a row whose kratos_id happens to be blank.
func (r *UserRepository) FindByKratosID(kratosID string) (*models.User, error) {
	if kratosID == "" {
		return nil, ErrUserNotFound
	}

	var user models.User
	err := r.db.Where("kratos_id = ?", kratosID).First(&user).Error
	switch {
	case err == nil:
		return &user, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// Translate the ORM sentinel into the repository's own sentinel so
		// callers do not need to import gorm to detect a miss.
		return nil, ErrUserNotFound
	default:
		return nil, err
	}
}

// Sentinel errors exposed by the user repository. Compare with errors.Is.
var (
	// ErrUserNotFound is returned by the lookup methods when no user matches.
	ErrUserNotFound = errors.New("user not found")
	// ErrUserExists is declared for callers reporting a duplicate user.
	ErrUserExists = errors.New("user already exists")
)

// UserRepository handles user-related database operations.
type UserRepository struct {
	db *gorm.DB
}

// NewUserRepository creates a new user repository backed by db.
func NewUserRepository(db *gorm.DB) *UserRepository {
	return &UserRepository{db: db}
}

// DB returns the underlying *gorm.DB connection. This is useful when callers
// need to pass the connection (e.g., a transaction) to methods that accept *gorm.DB.
func (r *UserRepository) DB() *gorm.DB {
	return r.db
}

// Transaction runs fn inside a database transaction. The callback receives a
// new UserRepository backed by the transaction so all operations within fn
// share the same transactional connection.
func (r *UserRepository) Transaction(fn func(txRepo *UserRepository) error) error { return r.db.Transaction(func(tx *gorm.DB) error { txRepo := &UserRepository{db: tx} return fn(txRepo) }) } // FindByID finds a user by ID func (r *UserRepository) FindByID(id uint) (*models.User, error) { var user models.User if err := r.db.First(&user, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrUserNotFound } return nil, err } return &user, nil } // FindByIDWithProfile finds a user by ID with profile preloaded func (r *UserRepository) FindByIDWithProfile(id uint) (*models.User, error) { var user models.User if err := r.db.Preload("Profile").First(&user, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrUserNotFound } return nil, err } return &user, nil } // FindByUsername finds a user by username (case-insensitive) func (r *UserRepository) FindByUsername(username string) (*models.User, error) { var user models.User if err := r.db.Where("LOWER(username) = LOWER(?)", username).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrUserNotFound } return nil, err } return &user, nil } // FindByEmail finds a user by email (case-insensitive) func (r *UserRepository) FindByEmail(email string) (*models.User, error) { var user models.User if err := r.db.Where("LOWER(email) = LOWER(?)", email).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrUserNotFound } return nil, err } return &user, nil } // FindByUsernameOrEmail finds a user by username or email with profile preloaded func (r *UserRepository) FindByUsernameOrEmail(identifier string) (*models.User, error) { var user models.User if err := r.db.Preload("Profile").Where("LOWER(username) = LOWER(?) 
OR LOWER(email) = LOWER(?)", identifier, identifier).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrUserNotFound } return nil, err } return &user, nil } // Create creates a new user func (r *UserRepository) Create(user *models.User) error { return r.db.Create(user).Error } // Update updates a user func (r *UserRepository) Update(user *models.User) error { return r.db.Save(user).Error } // UpdateLastLogin updates the user's last login timestamp func (r *UserRepository) UpdateLastLogin(userID uint) error { now := time.Now().UTC() return r.db.Model(&models.User{}).Where("id = ?", userID).Update("last_login", now).Error } // ExistsByUsername checks if a username exists func (r *UserRepository) ExistsByUsername(username string) (bool, error) { var count int64 if err := r.db.Model(&models.User{}).Where("LOWER(username) = LOWER(?)", username).Count(&count).Error; err != nil { return false, err } return count > 0, nil } // ExistsByEmail checks if an email exists func (r *UserRepository) ExistsByEmail(email string) (bool, error) { var count int64 if err := r.db.Model(&models.User{}).Where("LOWER(email) = LOWER(?)", email).Count(&count).Error; err != nil { return false, err } return count > 0, nil } // --- User Profile Methods --- // GetOrCreateProfile gets or creates a user profile func (r *UserRepository) GetOrCreateProfile(userID uint) (*models.UserProfile, error) { var profile models.UserProfile result := r.db.Where("user_id = ?", userID).First(&profile) if errors.Is(result.Error, gorm.ErrRecordNotFound) { profile = models.UserProfile{UserID: userID} if err := r.db.Create(&profile).Error; err != nil { return nil, err } } else if result.Error != nil { return nil, result.Error } return &profile, nil } // UpdateProfile updates a user profile func (r *UserRepository) UpdateProfile(profile *models.UserProfile) error { return r.db.Save(profile).Error } // SetProfileVerified sets the profile verified status func (r *UserRepository) 
SetProfileVerified(userID uint, verified bool) error { return r.db.Model(&models.UserProfile{}).Where("user_id = ?", userID).Update("verified", verified).Error } // --- Search Methods --- // SearchUsers searches users by username, email, first name, or last name func (r *UserRepository) SearchUsers(query string, limit, offset int) ([]models.User, int64, error) { var users []models.User var total int64 searchQuery := "%" + escapeLikeWildcards(strings.ToLower(query)) + "%" baseQuery := r.db.Model(&models.User{}). Where("LOWER(username) LIKE ? OR LOWER(email) LIKE ? OR LOWER(first_name) LIKE ? OR LOWER(last_name) LIKE ?", searchQuery, searchQuery, searchQuery, searchQuery) if err := baseQuery.Count(&total).Error; err != nil { return nil, 0, err } if err := baseQuery.Offset(offset).Limit(limit).Find(&users).Error; err != nil { return nil, 0, err } return users, total, nil } // ListUsers lists all users with pagination func (r *UserRepository) ListUsers(limit, offset int) ([]models.User, int64, error) { var users []models.User var total int64 if err := r.db.Model(&models.User{}).Count(&total).Error; err != nil { return nil, 0, err } if err := r.db.Offset(offset).Limit(limit).Find(&users).Error; err != nil { return nil, 0, err } return users, total, nil } // FindUsersInSharedResidences finds users that share at least one residence with the given user func (r *UserRepository) FindUsersInSharedResidences(userID uint) ([]models.User, error) { var users []models.User // Find all users that share a residence with the given user // This includes: // 1. Owners of residences where current user is a member // 2. Members of residences owned by current user // 3. Members of residences where current user is also a member err := r.db.Raw(` SELECT DISTINCT u.* FROM auth_user u WHERE u.id != ? 
AND u.is_active = true AND ( -- Users who own residences where current user is a shared user u.id IN ( SELECT r.owner_id FROM residence_residence r INNER JOIN residence_residence_users ru ON r.id = ru.residence_id WHERE ru.user_id = ? AND r.is_active = true ) OR -- Users who are shared users of residences owned by current user u.id IN ( SELECT ru.user_id FROM residence_residence_users ru INNER JOIN residence_residence r ON ru.residence_id = r.id WHERE r.owner_id = ? AND r.is_active = true ) OR -- Users who share a residence with current user (both are shared users) u.id IN ( SELECT ru2.user_id FROM residence_residence_users ru1 INNER JOIN residence_residence_users ru2 ON ru1.residence_id = ru2.residence_id WHERE ru1.user_id = ? AND ru2.user_id != ? ) ) `, userID, userID, userID, userID, userID).Scan(&users).Error return users, err } // FindUserIfSharedResidence finds a user if they share a residence with the requesting user func (r *UserRepository) FindUserIfSharedResidence(targetUserID, requestingUserID uint) (*models.User, error) { var user models.User err := r.db.Raw(` SELECT u.* FROM auth_user u WHERE u.id = ? AND u.is_active = true AND ( u.id = ? OR -- Target owns a residence where requester is a member u.id IN ( SELECT r.owner_id FROM residence_residence r INNER JOIN residence_residence_users ru ON r.id = ru.residence_id WHERE ru.user_id = ? AND r.is_active = true ) OR -- Target is a member of a residence owned by requester u.id IN ( SELECT ru.user_id FROM residence_residence_users ru INNER JOIN residence_residence r ON ru.residence_id = r.id WHERE r.owner_id = ? AND r.is_active = true ) OR -- Target shares a residence with requester (both are shared users) u.id IN ( SELECT ru2.user_id FROM residence_residence_users ru1 INNER JOIN residence_residence_users ru2 ON ru1.residence_id = ru2.residence_id WHERE ru1.user_id = ? 
) ) LIMIT 1 `, targetUserID, requestingUserID, requestingUserID, requestingUserID, requestingUserID).Scan(&user).Error if err != nil { return nil, err } if user.ID == 0 { return nil, nil } return &user, nil } // FindProfilesInSharedResidences finds user profiles for users in shared residences func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.UserProfile, error) { var profiles []models.UserProfile err := r.db.Raw(` SELECT p.* FROM user_userprofile p INNER JOIN auth_user u ON p.user_id = u.id WHERE u.is_active = true AND ( u.id = ? OR -- Users who own residences where current user is a shared user u.id IN ( SELECT r.owner_id FROM residence_residence r INNER JOIN residence_residence_users ru ON r.id = ru.residence_id WHERE ru.user_id = ? AND r.is_active = true ) OR -- Users who are shared users of residences owned by current user u.id IN ( SELECT ru.user_id FROM residence_residence_users ru INNER JOIN residence_residence r ON ru.residence_id = r.id WHERE r.owner_id = ? AND r.is_active = true ) OR -- Users who share a residence with current user (both are shared users) u.id IN ( SELECT ru2.user_id FROM residence_residence_users ru1 INNER JOIN residence_residence_users ru2 ON ru1.residence_id = ru2.residence_id WHERE ru1.user_id = ? ) ) `, userID, userID, userID, userID).Scan(&profiles).Error return profiles, err } // FindAuthProvider returns "kratos" for all Kratos-managed users (the sole // provider after the Ory Kratos migration). Kept for compatibility with // callers that still check the provider string. func (r *UserRepository) FindAuthProvider(_ uint) (string, error) { return "kratos", nil } // --- Account Deletion --- // DeleteUserCascade deletes a user and all related records in dependency order. // Should be called on a repository backed by a transaction (via Transaction callback). // Returns a list of file URLs that need to be deleted from disk after the transaction commits. 
func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) { var fileURLs []string db := r.db // 1. Push notification devices if err := db.Where("user_id = ?", userID).Delete(&models.APNSDevice{}).Error; err != nil { return nil, err } if err := db.Where("user_id = ?", userID).Delete(&models.GCMDevice{}).Error; err != nil { return nil, err } // 2. Notifications if err := db.Where("user_id = ?", userID).Delete(&models.Notification{}).Error; err != nil { return nil, err } // 3. Notification preferences if err := db.Where("user_id = ?", userID).Delete(&models.NotificationPreference{}).Error; err != nil { return nil, err } // 4. Task reminder logs if err := db.Where("user_id = ?", userID).Delete(&models.TaskReminderLog{}).Error; err != nil { return nil, err } // 5. Find residences owned by user var ownedResidences []models.Residence if err := db.Where("owner_id = ?", userID).Find(&ownedResidences).Error; err != nil { return nil, err } for _, residence := range ownedResidences { // Collect file URLs before deleting // Task completion images (via completion_id -> task_id -> residence_id) var completionImageURLs []string db.Model(&models.TaskCompletionImage{}). Joins("JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id"). Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id"). Where("task_task.residence_id = ?", residence.ID). Pluck("task_taskcompletionimage.image_url", &completionImageURLs) fileURLs = append(fileURLs, completionImageURLs...) // Delete task completion images db.Exec(`DELETE FROM task_taskcompletionimage WHERE completion_id IN ( SELECT tc.id FROM task_taskcompletion tc JOIN task_task t ON t.id = tc.task_id WHERE t.residence_id = ? )`, residence.ID) // Delete task completions db.Exec(`DELETE FROM task_taskcompletion WHERE task_id IN ( SELECT id FROM task_task WHERE residence_id = ? 
)`, residence.ID) // Document images (via document_id -> residence_id) var docImageURLs []string db.Model(&models.DocumentImage{}). Joins("JOIN task_document ON task_document.id = task_documentimage.document_id"). Where("task_document.residence_id = ?", residence.ID). Pluck("task_documentimage.image_url", &docImageURLs) fileURLs = append(fileURLs, docImageURLs...) // Delete document images db.Exec(`DELETE FROM task_documentimage WHERE document_id IN ( SELECT id FROM task_document WHERE residence_id = ? )`, residence.ID) // Document file URLs var docFileURLs []string db.Model(&models.Document{}).Where("residence_id = ?", residence.ID).Pluck("file_url", &docFileURLs) fileURLs = append(fileURLs, docFileURLs...) // Delete documents if err := db.Where("residence_id = ?", residence.ID).Delete(&models.Document{}).Error; err != nil { return nil, err } // Delete tasks if err := db.Where("residence_id = ?", residence.ID).Delete(&models.Task{}).Error; err != nil { return nil, err } // Delete contractor specialties (many-to-many join table) db.Exec(`DELETE FROM task_contractor_specialties WHERE contractor_id IN ( SELECT id FROM task_contractor WHERE residence_id = ? )`, residence.ID) // Delete contractors if err := db.Where("residence_id = ?", residence.ID).Delete(&models.Contractor{}).Error; err != nil { return nil, err } // Delete share codes if err := db.Where("residence_id = ?", residence.ID).Delete(&models.ResidenceShareCode{}).Error; err != nil { return nil, err } // Remove residence membership records (many-to-many join table) db.Exec("DELETE FROM residence_residence_users WHERE residence_id = ?", residence.ID) // Delete the residence itself if err := db.Delete(&residence).Error; err != nil { return nil, err } } // 6. Remove user from shared residences they don't own (membership only) db.Exec("DELETE FROM residence_residence_users WHERE user_id = ?", userID) // 7. 
Subscription if err := db.Where("user_id = ?", userID).Delete(&models.UserSubscription{}).Error; err != nil { return nil, err } // 8. User profile if err := db.Where("user_id = ?", userID).Delete(&models.UserProfile{}).Error; err != nil { return nil, err } // 9. User if err := db.Where("id = ?", userID).Delete(&models.User{}).Error; err != nil { return nil, err } // Filter out empty URLs var cleanURLs []string for _, url := range fileURLs { if url != "" { cleanURLs = append(cleanURLs, url) } } return cleanURLs, nil } // WithContext returns a copy of the repository whose underlying *gorm.DB carries // the supplied context. SQL emitted via this copy gets attached to ctx's trace span // (when otelgorm is registered) and respects ctx cancellation/deadlines. func (r *UserRepository) WithContext(ctx context.Context) *UserRepository { return &UserRepository{db: r.db.WithContext(ctx)} }