Comprehensive security hardening from audit findings: - Add validation tags to all DTO request structs (max lengths, ranges, enums) - Replace unsafe type assertions with MustGetAuthUser helper across all handlers - Remove query-param token auth from admin middleware (prevents URL token leakage) - Add request validation calls in handlers that were missing c.Validate() - Remove goroutines in handlers (timezone update now synchronous) - Add sanitize middleware and path traversal protection (path_utils) - Stop resetting admin passwords on migration restart - Warn on well-known default SECRET_KEY - Add ~30 new test files covering security regressions, auth safety, repos, and services - Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
294 lines
8.7 KiB
Go
294 lines
8.7 KiB
Go
package repositories
|
|
|
|
import (
|
|
"errors"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/treytartt/casera-api/internal/models"
|
|
)
|
|
|
|
// NotificationRepository handles database operations for notifications
|
|
type NotificationRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewNotificationRepository creates a new notification repository
|
|
func NewNotificationRepository(db *gorm.DB) *NotificationRepository {
|
|
return &NotificationRepository{db: db}
|
|
}
|
|
|
|
// === Notifications ===
|
|
|
|
// FindByID finds a notification by ID
|
|
func (r *NotificationRepository) FindByID(id uint) (*models.Notification, error) {
|
|
var notification models.Notification
|
|
err := r.db.First(¬ification, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ¬ification, nil
|
|
}
|
|
|
|
// FindByUser finds all notifications for a user
|
|
func (r *NotificationRepository) FindByUser(userID uint, limit, offset int) ([]models.Notification, error) {
|
|
var notifications []models.Notification
|
|
query := r.db.Where("user_id = ?", userID).
|
|
Order("created_at DESC")
|
|
|
|
if limit > 0 {
|
|
query = query.Limit(limit)
|
|
}
|
|
if offset > 0 {
|
|
query = query.Offset(offset)
|
|
}
|
|
|
|
err := query.Find(¬ifications).Error
|
|
return notifications, err
|
|
}
|
|
|
|
// Create creates a new notification
|
|
func (r *NotificationRepository) Create(notification *models.Notification) error {
|
|
return r.db.Create(notification).Error
|
|
}
|
|
|
|
// MarkAsRead marks a notification as read
|
|
func (r *NotificationRepository) MarkAsRead(id uint) error {
|
|
now := time.Now().UTC()
|
|
return r.db.Model(&models.Notification{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"read": true,
|
|
"read_at": now,
|
|
}).Error
|
|
}
|
|
|
|
// MarkAllAsRead marks all notifications for a user as read
|
|
func (r *NotificationRepository) MarkAllAsRead(userID uint) error {
|
|
now := time.Now().UTC()
|
|
return r.db.Model(&models.Notification{}).
|
|
Where("user_id = ? AND read = ?", userID, false).
|
|
Updates(map[string]interface{}{
|
|
"read": true,
|
|
"read_at": now,
|
|
}).Error
|
|
}
|
|
|
|
// MarkAsSent marks a notification as sent
|
|
func (r *NotificationRepository) MarkAsSent(id uint) error {
|
|
now := time.Now().UTC()
|
|
return r.db.Model(&models.Notification{}).
|
|
Where("id = ?", id).
|
|
Updates(map[string]interface{}{
|
|
"sent": true,
|
|
"sent_at": now,
|
|
}).Error
|
|
}
|
|
|
|
// SetError sets an error message on a notification
|
|
func (r *NotificationRepository) SetError(id uint, errorMsg string) error {
|
|
return r.db.Model(&models.Notification{}).
|
|
Where("id = ?", id).
|
|
Update("error_message", errorMsg).Error
|
|
}
|
|
|
|
// CountUnread counts unread notifications for a user
|
|
func (r *NotificationRepository) CountUnread(userID uint) (int64, error) {
|
|
var count int64
|
|
err := r.db.Model(&models.Notification{}).
|
|
Where("user_id = ? AND read = ?", userID, false).
|
|
Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
// GetPendingNotifications gets notifications that need to be sent
|
|
func (r *NotificationRepository) GetPendingNotifications(limit int) ([]models.Notification, error) {
|
|
var notifications []models.Notification
|
|
err := r.db.Where("sent = ?", false).
|
|
Order("created_at ASC").
|
|
Limit(limit).
|
|
Find(¬ifications).Error
|
|
return notifications, err
|
|
}
|
|
|
|
// === Notification Preferences ===
|
|
|
|
// FindPreferencesByUser finds notification preferences for a user
|
|
func (r *NotificationRepository) FindPreferencesByUser(userID uint) (*models.NotificationPreference, error) {
|
|
var prefs models.NotificationPreference
|
|
err := r.db.Where("user_id = ?", userID).First(&prefs).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &prefs, nil
|
|
}
|
|
|
|
// CreatePreferences creates notification preferences for a user
|
|
func (r *NotificationRepository) CreatePreferences(prefs *models.NotificationPreference) error {
|
|
return r.db.Create(prefs).Error
|
|
}
|
|
|
|
// UpdatePreferences updates notification preferences
|
|
func (r *NotificationRepository) UpdatePreferences(prefs *models.NotificationPreference) error {
|
|
return r.db.Omit("User").Save(prefs).Error
|
|
}
|
|
|
|
// GetOrCreatePreferences gets or creates notification preferences for a user.
|
|
// Uses a transaction to avoid TOCTOU race conditions on concurrent requests.
|
|
func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.NotificationPreference, error) {
|
|
var prefs models.NotificationPreference
|
|
|
|
err := r.db.Transaction(func(tx *gorm.DB) error {
|
|
err := tx.Where("user_id = ?", userID).First(&prefs).Error
|
|
if err == nil {
|
|
return nil // Found existing preferences
|
|
}
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err // Unexpected error
|
|
}
|
|
|
|
// Record not found -- create with defaults
|
|
prefs = models.NotificationPreference{
|
|
UserID: userID,
|
|
TaskDueSoon: true,
|
|
TaskOverdue: true,
|
|
TaskCompleted: true,
|
|
TaskAssigned: true,
|
|
ResidenceShared: true,
|
|
WarrantyExpiring: true,
|
|
EmailTaskCompleted: true,
|
|
}
|
|
return tx.Create(&prefs).Error
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &prefs, nil
|
|
}
|
|
|
|
// === Device Registration ===
|
|
|
|
// FindAPNSDeviceByID finds an APNS device by ID
|
|
func (r *NotificationRepository) FindAPNSDeviceByID(id uint) (*models.APNSDevice, error) {
|
|
var device models.APNSDevice
|
|
err := r.db.First(&device, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &device, nil
|
|
}
|
|
|
|
// FindGCMDeviceByID finds a GCM device by ID
|
|
func (r *NotificationRepository) FindGCMDeviceByID(id uint) (*models.GCMDevice, error) {
|
|
var device models.GCMDevice
|
|
err := r.db.First(&device, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &device, nil
|
|
}
|
|
|
|
// FindAPNSDeviceByToken finds an APNS device by registration token
|
|
func (r *NotificationRepository) FindAPNSDeviceByToken(token string) (*models.APNSDevice, error) {
|
|
var device models.APNSDevice
|
|
err := r.db.Where("registration_id = ?", token).First(&device).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &device, nil
|
|
}
|
|
|
|
// FindAPNSDevicesByUser finds all APNS devices for a user
|
|
func (r *NotificationRepository) FindAPNSDevicesByUser(userID uint) ([]models.APNSDevice, error) {
|
|
var devices []models.APNSDevice
|
|
err := r.db.Where("user_id = ? AND active = ?", userID, true).Find(&devices).Error
|
|
return devices, err
|
|
}
|
|
|
|
// CreateAPNSDevice creates a new APNS device
|
|
func (r *NotificationRepository) CreateAPNSDevice(device *models.APNSDevice) error {
|
|
return r.db.Create(device).Error
|
|
}
|
|
|
|
// UpdateAPNSDevice updates an APNS device
|
|
func (r *NotificationRepository) UpdateAPNSDevice(device *models.APNSDevice) error {
|
|
return r.db.Save(device).Error
|
|
}
|
|
|
|
// DeleteAPNSDevice deletes an APNS device
|
|
func (r *NotificationRepository) DeleteAPNSDevice(id uint) error {
|
|
return r.db.Delete(&models.APNSDevice{}, id).Error
|
|
}
|
|
|
|
// DeactivateAPNSDevice deactivates an APNS device
|
|
func (r *NotificationRepository) DeactivateAPNSDevice(id uint) error {
|
|
return r.db.Model(&models.APNSDevice{}).
|
|
Where("id = ?", id).
|
|
Update("active", false).Error
|
|
}
|
|
|
|
// FindGCMDeviceByToken finds a GCM device by registration token
|
|
func (r *NotificationRepository) FindGCMDeviceByToken(token string) (*models.GCMDevice, error) {
|
|
var device models.GCMDevice
|
|
err := r.db.Where("registration_id = ?", token).First(&device).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &device, nil
|
|
}
|
|
|
|
// FindGCMDevicesByUser finds all GCM devices for a user
|
|
func (r *NotificationRepository) FindGCMDevicesByUser(userID uint) ([]models.GCMDevice, error) {
|
|
var devices []models.GCMDevice
|
|
err := r.db.Where("user_id = ? AND active = ?", userID, true).Find(&devices).Error
|
|
return devices, err
|
|
}
|
|
|
|
// CreateGCMDevice creates a new GCM device
|
|
func (r *NotificationRepository) CreateGCMDevice(device *models.GCMDevice) error {
|
|
return r.db.Create(device).Error
|
|
}
|
|
|
|
// UpdateGCMDevice updates a GCM device
|
|
func (r *NotificationRepository) UpdateGCMDevice(device *models.GCMDevice) error {
|
|
return r.db.Save(device).Error
|
|
}
|
|
|
|
// DeleteGCMDevice deletes a GCM device
|
|
func (r *NotificationRepository) DeleteGCMDevice(id uint) error {
|
|
return r.db.Delete(&models.GCMDevice{}, id).Error
|
|
}
|
|
|
|
// DeactivateGCMDevice deactivates a GCM device
|
|
func (r *NotificationRepository) DeactivateGCMDevice(id uint) error {
|
|
return r.db.Model(&models.GCMDevice{}).
|
|
Where("id = ?", id).
|
|
Update("active", false).Error
|
|
}
|
|
|
|
// GetActiveTokensForUser gets all active push tokens for a user
|
|
func (r *NotificationRepository) GetActiveTokensForUser(userID uint) (iosTokens []string, androidTokens []string, err error) {
|
|
apnsDevices, err := r.FindAPNSDevicesByUser(userID)
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil, err
|
|
}
|
|
|
|
gcmDevices, err := r.FindGCMDevicesByUser(userID)
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil, err
|
|
}
|
|
|
|
iosTokens = make([]string, 0, len(apnsDevices))
|
|
for _, d := range apnsDevices {
|
|
iosTokens = append(iosTokens, d.RegistrationID)
|
|
}
|
|
|
|
androidTokens = make([]string, 0, len(gcmDevices))
|
|
for _, d := range gcmDevices {
|
|
androidTokens = append(androidTokens, d.RegistrationID)
|
|
}
|
|
|
|
return iosTokens, androidTokens, nil
|
|
}
|