package repositories

import (
	"errors"
	"strings"
	"time"

	"gorm.io/gorm"

	"github.com/treytartt/mycrib-api/internal/models"
)

// Sentinel errors shared by the repositories in this package. Compare
// against them with errors.Is.
var (
	ErrUserNotFound      = errors.New("user not found")
	ErrUserExists        = errors.New("user already exists")
	ErrInvalidToken      = errors.New("invalid token")
	ErrTokenNotFound     = errors.New("token not found")
	ErrCodeNotFound      = errors.New("code not found")
	ErrCodeExpired       = errors.New("code expired")
	ErrCodeUsed          = errors.New("code already used")
	ErrTooManyAttempts   = errors.New("too many attempts")
	ErrRateLimitExceeded = errors.New("rate limit exceeded")
)

// UserRepository handles user-related database operations: users, auth
// tokens, user profiles, confirmation codes and password reset codes.
type UserRepository struct {
	db *gorm.DB
}

// NewUserRepository creates a new user repository that runs its queries on db.
func NewUserRepository(db *gorm.DB) *UserRepository {
	return &UserRepository{db: db}
}

// firstUser runs query and loads the first matching user. conds are passed
// straight through to gorm's First as inline conditions (e.g. a primary key).
// GORM's record-not-found error is translated to ErrUserNotFound so every
// user finder reports a missing user the same way.
func (r *UserRepository) firstUser(query *gorm.DB, conds ...interface{}) (*models.User, error) {
	var user models.User
	if err := query.First(&user, conds...).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrUserNotFound
		}
		return nil, err
	}
	return &user, nil
}

// FindByID finds a user by primary key. It returns ErrUserNotFound when no
// such user exists.
func (r *UserRepository) FindByID(id uint) (*models.User, error) {
	return r.firstUser(r.db, id)
}

// FindByIDWithProfile finds a user by primary key with the Profile
// association preloaded. It returns ErrUserNotFound when no such user exists.
func (r *UserRepository) FindByIDWithProfile(id uint) (*models.User, error) {
	return r.firstUser(r.db.Preload("Profile"), id)
}

// FindByUsername finds a user by username, compared case-insensitively.
// It returns ErrUserNotFound when no user matches.
func (r *UserRepository) FindByUsername(username string) (*models.User, error) {
	return r.firstUser(r.db.Where("LOWER(username) = LOWER(?)", username))
}

// FindByEmail finds a user by email address, compared case-insensitively.
// It returns ErrUserNotFound when no user matches.
func (r *UserRepository) FindByEmail(email string) (*models.User, error) {
	return r.firstUser(r.db.Where("LOWER(email) = LOWER(?)", email))
}

// FindByUsernameOrEmail finds a user whose username or email equals
// identifier, compared case-insensitively. It returns ErrUserNotFound when no
// user matches.
func (r *UserRepository) FindByUsernameOrEmail(identifier string) (*models.User, error) {
	return r.firstUser(r.db.Where("LOWER(username) = LOWER(?) OR LOWER(email) = LOWER(?)", identifier, identifier))
}

// Create inserts a new user.
func (r *UserRepository) Create(user *models.User) error {
	return r.db.Create(user).Error
}

// Update persists user using gorm's Save.
func (r *UserRepository) Update(user *models.User) error {
	return r.db.Save(user).Error
}

// UpdateLastLogin sets the user's last_login column to the current UTC time.
func (r *UserRepository) UpdateLastLogin(userID uint) error {
	now := time.Now().UTC()
	return r.db.Model(&models.User{}).Where("id = ?", userID).Update("last_login", now).Error
}

// ExistsByUsername reports whether a user with the given username exists,
// compared case-insensitively.
func (r *UserRepository) ExistsByUsername(username string) (bool, error) {
	var count int64
	if err := r.db.Model(&models.User{}).Where("LOWER(username) = LOWER(?)", username).Count(&count).Error; err != nil {
		return false, err
	}
	return count > 0, nil
}

// ExistsByEmail reports whether a user with the given email exists, compared
// case-insensitively.
func (r *UserRepository) ExistsByEmail(email string) (bool, error) {
	var count int64
	if err := r.db.Model(&models.User{}).Where("LOWER(email) = LOWER(?)", email).Count(&count).Error; err != nil {
		return false, err
	}
	return count > 0, nil
}

// --- Auth Token Methods ---

// GetOrCreateToken returns the auth token for userID, creating one (with only
// UserID set by this code) when none exists yet.
//
// NOTE(review): the lookup and the insert are two separate statements with no
// transaction, so concurrent first calls for the same user can both reach
// Create. Whether the loser errors depends on a unique constraint on user_id
// that is not visible here — confirm.
func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error) {
	var token models.AuthToken
	err := r.db.Where("user_id = ?", userID).First(&token).Error
	switch {
	case err == nil:
		return &token, nil
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return nil, err
	}

	token = models.AuthToken{UserID: userID}
	if err := r.db.Create(&token).Error; err != nil {
		return nil, err
	}
	return &token, nil
}

// DeleteToken deletes the auth token whose key equals token. It returns
// ErrTokenNotFound when no row was deleted.
//
// NOTE(review): "key" is a reserved word in some SQL dialects (e.g. MySQL);
// this works only if the target database accepts it unquoted — confirm.
func (r *UserRepository) DeleteToken(token string) error {
	result := r.db.Where("key = ?", token).Delete(&models.AuthToken{})
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		return ErrTokenNotFound
	}
	return nil
}

// DeleteTokenByUserID deletes every auth token belonging to userID. Deleting
// nothing is not an error.
func (r *UserRepository) DeleteTokenByUserID(userID uint) error {
	return r.db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error
}

// --- User Profile Methods ---

// GetOrCreateProfile returns the profile for userID, creating an empty one
// (only UserID set) when none exists yet.
//
// NOTE(review): like GetOrCreateToken, the lookup and insert are not atomic.
func (r *UserRepository) GetOrCreateProfile(userID uint) (*models.UserProfile, error) {
	var profile models.UserProfile
	err := r.db.Where("user_id = ?", userID).First(&profile).Error
	switch {
	case err == nil:
		return &profile, nil
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return nil, err
	}

	profile = models.UserProfile{UserID: userID}
	if err := r.db.Create(&profile).Error; err != nil {
		return nil, err
	}
	return &profile, nil
}

// UpdateProfile persists profile using gorm's Save.
func (r *UserRepository) UpdateProfile(profile *models.UserProfile) error {
	return r.db.Save(profile).Error
}

// SetProfileVerified sets the verified column on the profile owned by userID.
func (r *UserRepository) SetProfileVerified(userID uint, verified bool) error {
	return r.db.Model(&models.UserProfile{}).Where("user_id = ?", userID).Update("verified", verified).Error
}

// --- Confirmation Code Methods ---

// CreateConfirmationCode stores a new unused confirmation code for userID
// after marking every previously unused code for that user as used, so only
// the newest code can be redeemed. Both steps run in one transaction: if the
// insert fails, the old codes stay valid, and a failed invalidation is
// reported instead of being ignored.
func (r *UserRepository) CreateConfirmationCode(userID uint, code string, expiresAt time.Time) (*models.ConfirmationCode, error) {
	confirmCode := &models.ConfirmationCode{
		UserID:    userID,
		Code:      code,
		ExpiresAt: expiresAt,
		IsUsed:    false,
	}

	err := r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&models.ConfirmationCode{}).
			Where("user_id = ? AND is_used = ?", userID, false).
			Update("is_used", true).Error; err != nil {
			return err
		}
		return tx.Create(confirmCode).Error
	})
	if err != nil {
		return nil, err
	}
	return confirmCode, nil
}

// FindConfirmationCode finds an unused confirmation code matching userID and
// code. It returns ErrCodeNotFound when no unused code matches. When the
// model's IsValid reports false, it returns ErrCodeUsed or ErrCodeExpired.
func (r *UserRepository) FindConfirmationCode(userID uint, code string) (*models.ConfirmationCode, error) {
	var confirmCode models.ConfirmationCode
	if err := r.db.Where("user_id = ? AND code = ? AND is_used = ?", userID, code, false).
		First(&confirmCode).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrCodeNotFound
		}
		return nil, err
	}

	if !confirmCode.IsValid() {
		// The query already excludes used codes, so this branch is defensive.
		// Every other IsValid failure is reported as expiry (presumably
		// IsValid checks ExpiresAt — confirm against the model).
		if confirmCode.IsUsed {
			return nil, ErrCodeUsed
		}
		return nil, ErrCodeExpired
	}
	return &confirmCode, nil
}

// MarkConfirmationCodeUsed sets is_used on the confirmation code with the
// given ID.
func (r *UserRepository) MarkConfirmationCodeUsed(codeID uint) error {
	return r.db.Model(&models.ConfirmationCode{}).Where("id = ?", codeID).Update("is_used", true).Error
}

// --- Password Reset Code Methods ---

// CreatePasswordResetCode stores a new password reset code for userID after
// marking every previously unused reset code for that user as used. Both
// steps run in one transaction: if the insert fails, the old codes stay
// valid, and a failed invalidation is reported instead of being ignored.
func (r *UserRepository) CreatePasswordResetCode(userID uint, codeHash string, resetToken string, expiresAt time.Time) (*models.PasswordResetCode, error) {
	resetCode := &models.PasswordResetCode{
		UserID:     userID,
		CodeHash:   codeHash,
		ResetToken: resetToken,
		ExpiresAt:  expiresAt,
		Used:       false,
		Attempts:   0,
		// Attempt budget; FindPasswordResetCodeByToken reports
		// ErrTooManyAttempts once Attempts reaches this value.
		MaxAttempts: 5,
	}

	err := r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&models.PasswordResetCode{}).
			Where("user_id = ? AND used = ?", userID, false).
			Update("used", true).Error; err != nil {
			return err
		}
		return tx.Create(resetCode).Error
	})
	if err != nil {
		return nil, err
	}
	return resetCode, nil
}

// FindPasswordResetCodeByEmail looks up the user by email and returns that
// user's most recently created unused password reset code, together with the
// user.
//
// It returns ErrUserNotFound when the email is unknown, and ErrCodeNotFound
// when the user has no unused code. Unlike FindPasswordResetCodeByToken, it
// does NOT call IsValid, so the returned code may be expired or out of
// attempts; callers must validate it themselves.
func (r *UserRepository) FindPasswordResetCodeByEmail(email string) (*models.PasswordResetCode, *models.User, error) {
	user, err := r.FindByEmail(email)
	if err != nil {
		return nil, nil, err
	}

	var resetCode models.PasswordResetCode
	if err := r.db.Where("user_id = ? AND used = ?", user.ID, false).
		Order("created_at DESC").
		First(&resetCode).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil, ErrCodeNotFound
		}
		return nil, nil, err
	}
	return &resetCode, user, nil
}

// FindPasswordResetCodeByToken finds a password reset code by its reset
// token. It returns ErrCodeNotFound when no code has that token. When the
// model's IsValid reports false, it returns ErrCodeUsed, ErrTooManyAttempts
// or ErrCodeExpired, checked in that order.
func (r *UserRepository) FindPasswordResetCodeByToken(resetToken string) (*models.PasswordResetCode, error) {
	var resetCode models.PasswordResetCode
	if err := r.db.Where("reset_token = ?", resetToken).First(&resetCode).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrCodeNotFound
		}
		return nil, err
	}

	if resetCode.IsValid() {
		return &resetCode, nil
	}
	switch {
	case resetCode.Used:
		return nil, ErrCodeUsed
	case resetCode.Attempts >= resetCode.MaxAttempts:
		return nil, ErrTooManyAttempts
	default:
		return nil, ErrCodeExpired
	}
}

// IncrementResetCodeAttempts atomically increments the attempts counter of
// the reset code with the given ID. The increment is done in SQL, so
// concurrent increments are not lost.
func (r *UserRepository) IncrementResetCodeAttempts(codeID uint) error {
	return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).
		Update("attempts", gorm.Expr("attempts + 1")).Error
}

// MarkPasswordResetCodeUsed sets used on the reset code with the given ID.
func (r *UserRepository) MarkPasswordResetCodeUsed(codeID uint) error {
	return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).Update("used", true).Error
}

// CountRecentPasswordResetRequests counts the reset codes created for userID
// during the last hour (relative to the current UTC time).
func (r *UserRepository) CountRecentPasswordResetRequests(userID uint) (int64, error) {
	var count int64
	oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
	if err := r.db.Model(&models.PasswordResetCode{}).
		Where("user_id = ? AND created_at > ?", userID, oneHourAgo).
		Count(&count).Error; err != nil {
		return 0, err
	}
	return count, nil
}

// --- Search Methods ---

// SearchUsers returns one page of users whose username, email, first name or
// last name contains query (case-insensitive substring match), ordered by ID.
// total is the number of matching users before limit/offset are applied.
// LIKE metacharacters in query ('%', '_') are matched literally.
func (r *UserRepository) SearchUsers(query string, limit, offset int) ([]models.User, int64, error) {
	var users []models.User
	var total int64

	// '!' is used as the LIKE escape character instead of '\', because '\'
	// is itself an escape character inside MySQL string literals.
	escaped := strings.NewReplacer("!", "!!", "%", "!%", "_", "!_").Replace(strings.ToLower(query))
	searchQuery := "%" + escaped + "%"

	// Session makes baseQuery safe to reuse: Count and Find below each get
	// their own copy of the statement instead of mutating a shared one.
	baseQuery := r.db.Model(&models.User{}).
		Where("LOWER(username) LIKE ? ESCAPE '!' OR LOWER(email) LIKE ? ESCAPE '!' OR LOWER(first_name) LIKE ? ESCAPE '!' OR LOWER(last_name) LIKE ? ESCAPE '!'",
			searchQuery, searchQuery, searchQuery, searchQuery).
		Session(&gorm.Session{})

	if err := baseQuery.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	// Without ORDER BY, the row order is undefined, so pages could overlap or
	// skip rows.
	if err := baseQuery.Order("id").Offset(offset).Limit(limit).Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}

// ListUsers returns one page of all users, ordered by ID, together with the
// total number of users.
func (r *UserRepository) ListUsers(limit, offset int) ([]models.User, int64, error) {
	var users []models.User
	var total int64

	if err := r.db.Model(&models.User{}).Count(&total).Error; err != nil {
		return nil, 0, err
	}
	// A stable order is required for offset pagination to be deterministic.
	if err := r.db.Order("id").Offset(offset).Limit(limit).Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}