package repositories import ( "errors" "time" "gorm.io/gorm" "github.com/treytartt/casera-api/internal/models" ) // NotificationRepository handles database operations for notifications type NotificationRepository struct { db *gorm.DB } // NewNotificationRepository creates a new notification repository func NewNotificationRepository(db *gorm.DB) *NotificationRepository { return &NotificationRepository{db: db} } // === Notifications === // FindByID finds a notification by ID func (r *NotificationRepository) FindByID(id uint) (*models.Notification, error) { var notification models.Notification err := r.db.First(¬ification, id).Error if err != nil { return nil, err } return ¬ification, nil } // FindByUser finds all notifications for a user func (r *NotificationRepository) FindByUser(userID uint, limit, offset int) ([]models.Notification, error) { var notifications []models.Notification query := r.db.Where("user_id = ?", userID). Order("created_at DESC") if limit > 0 { query = query.Limit(limit) } if offset > 0 { query = query.Offset(offset) } err := query.Find(¬ifications).Error return notifications, err } // Create creates a new notification func (r *NotificationRepository) Create(notification *models.Notification) error { return r.db.Create(notification).Error } // MarkAsRead marks a notification as read func (r *NotificationRepository) MarkAsRead(id uint) error { now := time.Now().UTC() return r.db.Model(&models.Notification{}). Where("id = ?", id). Updates(map[string]interface{}{ "read": true, "read_at": now, }).Error } // MarkAllAsRead marks all notifications for a user as read func (r *NotificationRepository) MarkAllAsRead(userID uint) error { now := time.Now().UTC() return r.db.Model(&models.Notification{}). Where("user_id = ? AND read = ?", userID, false). Updates(map[string]interface{}{ "read": true, "read_at": now, }).Error } // MarkAsSent marks a notification as sent func (r *NotificationRepository) MarkAsSent(id uint) error { now := time.Now().UTC() return r.db.Model(&models.Notification{}). Where("id = ?", id). Updates(map[string]interface{}{ "sent": true, "sent_at": now, }).Error } // SetError sets an error message on a notification func (r *NotificationRepository) SetError(id uint, errorMsg string) error { return r.db.Model(&models.Notification{}). Where("id = ?", id). Update("error_message", errorMsg).Error } // CountUnread counts unread notifications for a user func (r *NotificationRepository) CountUnread(userID uint) (int64, error) { var count int64 err := r.db.Model(&models.Notification{}). Where("user_id = ? AND read = ?", userID, false). Count(&count).Error return count, err } // GetPendingNotifications gets notifications that need to be sent func (r *NotificationRepository) GetPendingNotifications(limit int) ([]models.Notification, error) { var notifications []models.Notification err := r.db.Where("sent = ?", false). Order("created_at ASC"). Limit(limit). Find(¬ifications).Error return notifications, err } // === Notification Preferences === // FindPreferencesByUser finds notification preferences for a user func (r *NotificationRepository) FindPreferencesByUser(userID uint) (*models.NotificationPreference, error) { var prefs models.NotificationPreference err := r.db.Where("user_id = ?", userID).First(&prefs).Error if err != nil { return nil, err } return &prefs, nil } // CreatePreferences creates notification preferences for a user func (r *NotificationRepository) CreatePreferences(prefs *models.NotificationPreference) error { return r.db.Create(prefs).Error } // UpdatePreferences updates notification preferences func (r *NotificationRepository) UpdatePreferences(prefs *models.NotificationPreference) error { return r.db.Omit("User").Save(prefs).Error } // GetOrCreatePreferences gets or creates notification preferences for a user. // Uses a transaction to avoid TOCTOU race conditions on concurrent requests. func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.NotificationPreference, error) { var prefs models.NotificationPreference err := r.db.Transaction(func(tx *gorm.DB) error { err := tx.Where("user_id = ?", userID).First(&prefs).Error if err == nil { return nil // Found existing preferences } if !errors.Is(err, gorm.ErrRecordNotFound) { return err // Unexpected error } // Record not found -- create with defaults prefs = models.NotificationPreference{ UserID: userID, TaskDueSoon: true, TaskOverdue: true, TaskCompleted: true, TaskAssigned: true, ResidenceShared: true, WarrantyExpiring: true, EmailTaskCompleted: true, } return tx.Create(&prefs).Error }) if err != nil { return nil, err } return &prefs, nil } // === Device Registration === // FindAPNSDeviceByID finds an APNS device by ID func (r *NotificationRepository) FindAPNSDeviceByID(id uint) (*models.APNSDevice, error) { var device models.APNSDevice err := r.db.First(&device, id).Error if err != nil { return nil, err } return &device, nil } // FindGCMDeviceByID finds a GCM device by ID func (r *NotificationRepository) FindGCMDeviceByID(id uint) (*models.GCMDevice, error) { var device models.GCMDevice err := r.db.First(&device, id).Error if err != nil { return nil, err } return &device, nil } // FindAPNSDeviceByToken finds an APNS device by registration token func (r *NotificationRepository) FindAPNSDeviceByToken(token string) (*models.APNSDevice, error) { var device models.APNSDevice err := r.db.Where("registration_id = ?", token).First(&device).Error if err != nil { return nil, err } return &device, nil } // FindAPNSDevicesByUser finds all APNS devices for a user func (r *NotificationRepository) FindAPNSDevicesByUser(userID uint) ([]models.APNSDevice, error) { var devices []models.APNSDevice err := r.db.Where("user_id = ? AND active = ?", userID, true).Find(&devices).Error return devices, err } // CreateAPNSDevice creates a new APNS device func (r *NotificationRepository) CreateAPNSDevice(device *models.APNSDevice) error { return r.db.Create(device).Error } // UpdateAPNSDevice updates an APNS device func (r *NotificationRepository) UpdateAPNSDevice(device *models.APNSDevice) error { return r.db.Save(device).Error } // DeleteAPNSDevice deletes an APNS device func (r *NotificationRepository) DeleteAPNSDevice(id uint) error { return r.db.Delete(&models.APNSDevice{}, id).Error } // DeactivateAPNSDevice deactivates an APNS device func (r *NotificationRepository) DeactivateAPNSDevice(id uint) error { return r.db.Model(&models.APNSDevice{}). Where("id = ?", id). Update("active", false).Error } // FindGCMDeviceByToken finds a GCM device by registration token func (r *NotificationRepository) FindGCMDeviceByToken(token string) (*models.GCMDevice, error) { var device models.GCMDevice err := r.db.Where("registration_id = ?", token).First(&device).Error if err != nil { return nil, err } return &device, nil } // FindGCMDevicesByUser finds all GCM devices for a user func (r *NotificationRepository) FindGCMDevicesByUser(userID uint) ([]models.GCMDevice, error) { var devices []models.GCMDevice err := r.db.Where("user_id = ? AND active = ?", userID, true).Find(&devices).Error return devices, err } // CreateGCMDevice creates a new GCM device func (r *NotificationRepository) CreateGCMDevice(device *models.GCMDevice) error { return r.db.Create(device).Error } // UpdateGCMDevice updates a GCM device func (r *NotificationRepository) UpdateGCMDevice(device *models.GCMDevice) error { return r.db.Save(device).Error } // DeleteGCMDevice deletes a GCM device func (r *NotificationRepository) DeleteGCMDevice(id uint) error { return r.db.Delete(&models.GCMDevice{}, id).Error } // DeactivateGCMDevice deactivates a GCM device func (r *NotificationRepository) DeactivateGCMDevice(id uint) error { return r.db.Model(&models.GCMDevice{}). Where("id = ?", id). Update("active", false).Error } // GetActiveTokensForUser gets all active push tokens for a user func (r *NotificationRepository) GetActiveTokensForUser(userID uint) (iosTokens []string, androidTokens []string, err error) { apnsDevices, err := r.FindAPNSDevicesByUser(userID) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, err } gcmDevices, err := r.FindGCMDevicesByUser(userID) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, err } iosTokens = make([]string, 0, len(apnsDevices)) for _, d := range apnsDevices { iosTokens = append(iosTokens, d.RegistrationID) } androidTokens = make([]string, 0, len(gcmDevices)) for _, d := range gcmDevices { androidTokens = append(androidTokens, d.RegistrationID) } return iosTokens, androidTokens, nil }