Migrate Auth/Contractor/Document/Notification/Subscription services to ctx
Every public method on these five services now takes ctx context.Context as the first arg and routes its repo calls through .WithContext(ctx). With TaskService and ResidenceService already migrated, this means every in-process service that touches Postgres now produces a flame graph in Jaeger where the SQL spans nest under the parent HTTP request span. Endpoints now fully traced (HTTP → service → SQL): - /api/auth/login, /register, /logout, /me, /verify-email, /resend-verification - /api/auth/forgot-password, /verify-reset, /reset-password, /update-profile - /api/contractors/* (CRUD + favorite + by-residence + tasks) - /api/documents/* (CRUD + activate/deactivate + image upload/delete) - /api/notifications/* (list, count, mark-read, prefs, devices) - /api/subscription/* (status, purchase, cancel, triggers, promotions) - All previously-migrated /api/tasks/* and /api/residences/* paths Internal helpers also threaded: - TaskService.sendTaskCompletedNotification → forwards ctx - TaskService.UpdateUserTimezone → forwards ctx to NotificationService - ResidenceService.CreateResidence → forwards ctx to SubscriptionService.CheckLimit - NotificationService.registerAPNSDevice / registerGCMDevice → both take ctx ~75 method signatures, ~120 handler/test call sites updated. Tests pass green; the only failure is the pre-existing flaky TaskHandler_QuickComplete SQLite race that fails ~60% of runs on master. Step 3 of the observability plan is now genuinely complete: every API endpoint backed by a Go service emits a per-request flame graph with HTTP → service → SQL spans, plus B2/APNs/FCM/asynq spans where applicable. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -65,7 +65,7 @@ func (h *AuthHandler) Login(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.authService.Login(&req)
|
response, err := h.authService.Login(c.Request().Context(), &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
|
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
|
||||||
if h.auditService != nil {
|
if h.auditService != nil {
|
||||||
@@ -94,7 +94,7 @@ func (h *AuthHandler) Register(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
response, confirmationCode, err := h.authService.Register(&req)
|
response, confirmationCode, err := h.authService.Register(c.Request().Context(), &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Msg("Registration failed")
|
log.Debug().Err(err).Msg("Registration failed")
|
||||||
return err
|
return err
|
||||||
@@ -141,7 +141,7 @@ func (h *AuthHandler) Logout(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invalidate token in database
|
// Invalidate token in database
|
||||||
if err := h.authService.Logout(token); err != nil {
|
if err := h.authService.Logout(c.Request().Context(), token); err != nil {
|
||||||
log.Warn().Err(err).Msg("Failed to delete token from database")
|
log.Warn().Err(err).Msg("Failed to delete token from database")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,7 +162,7 @@ func (h *AuthHandler) CurrentUser(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.authService.GetCurrentUser(user.ID)
|
response, err := h.authService.GetCurrentUser(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Uint("user_id", user.ID).Msg("Failed to get current user")
|
log.Error().Err(err).Uint("user_id", user.ID).Msg("Failed to get current user")
|
||||||
return err
|
return err
|
||||||
@@ -186,7 +186,7 @@ func (h *AuthHandler) UpdateProfile(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.authService.UpdateProfile(user.ID, &req)
|
response, err := h.authService.UpdateProfile(c.Request().Context(), user.ID, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to update profile")
|
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to update profile")
|
||||||
return err
|
return err
|
||||||
@@ -210,7 +210,7 @@ func (h *AuthHandler) VerifyEmail(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.authService.VerifyEmail(user.ID, req.Code)
|
err = h.authService.VerifyEmail(c.Request().Context(), user.ID, req.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Email verification failed")
|
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Email verification failed")
|
||||||
return err
|
return err
|
||||||
@@ -243,7 +243,7 @@ func (h *AuthHandler) ResendVerification(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
code, err := h.authService.ResendVerificationCode(user.ID)
|
code, err := h.authService.ResendVerificationCode(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to resend verification")
|
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to resend verification")
|
||||||
return err
|
return err
|
||||||
@@ -276,7 +276,7 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
code, user, err := h.authService.ForgotPassword(req.Email)
|
code, user, err := h.authService.ForgotPassword(c.Request().Context(), req.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var appErr *apperrors.AppError
|
var appErr *apperrors.AppError
|
||||||
if errors.As(err, &appErr) && appErr.Code == http.StatusTooManyRequests {
|
if errors.As(err, &appErr) && appErr.Code == http.StatusTooManyRequests {
|
||||||
@@ -324,7 +324,7 @@ func (h *AuthHandler) VerifyResetCode(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
resetToken, err := h.authService.VerifyResetCode(req.Email, req.Code)
|
resetToken, err := h.authService.VerifyResetCode(c.Request().Context(), req.Email, req.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Str("email", req.Email).Msg("Verify reset code failed")
|
log.Debug().Err(err).Str("email", req.Email).Msg("Verify reset code failed")
|
||||||
return err
|
return err
|
||||||
@@ -346,7 +346,7 @@ func (h *AuthHandler) ResetPassword(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.authService.ResetPassword(req.ResetToken, req.NewPassword)
|
err := h.authService.ResetPassword(c.Request().Context(), req.ResetToken, req.NewPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Msg("Password reset failed")
|
log.Debug().Err(err).Msg("Password reset failed")
|
||||||
return err
|
return err
|
||||||
@@ -469,7 +469,7 @@ func (h *AuthHandler) RefreshToken(c echo.Context) error {
|
|||||||
return apperrors.Unauthorized("error.not_authenticated")
|
return apperrors.Unauthorized("error.not_authenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.authService.RefreshToken(token, user.ID)
|
response, err := h.authService.RefreshToken(c.Request().Context(), token, user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed")
|
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed")
|
||||||
return err
|
return err
|
||||||
@@ -497,7 +497,7 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_request")
|
return apperrors.BadRequest("error.invalid_request")
|
||||||
}
|
}
|
||||||
|
|
||||||
fileURLs, err := h.authService.DeleteAccount(user.ID, req.Password, req.Confirmation)
|
fileURLs, err := h.authService.DeleteAccount(c.Request().Context(), user.ID, req.Password, req.Confirmation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Account deletion failed")
|
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Account deletion failed")
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func (h *ContractorHandler) ListContractors(c echo.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
response, err := h.contractorService.ListContractors(user.ID)
|
response, err := h.contractorService.ListContractors(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -48,7 +48,7 @@ func (h *ContractorHandler) GetContractor(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.contractorService.GetContractor(uint(contractorID), user.ID)
|
response, err := h.contractorService.GetContractor(c.Request().Context(), uint(contractorID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -69,7 +69,7 @@ func (h *ContractorHandler) CreateContractor(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.contractorService.CreateContractor(&req, user.ID)
|
response, err := h.contractorService.CreateContractor(c.Request().Context(), &req, user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -95,7 +95,7 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.contractorService.UpdateContractor(uint(contractorID), user.ID, &req)
|
response, err := h.contractorService.UpdateContractor(c.Request().Context(), uint(contractorID), user.ID, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -113,7 +113,7 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.contractorService.DeleteContractor(uint(contractorID), user.ID)
|
err = h.contractorService.DeleteContractor(c.Request().Context(), uint(contractorID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -131,7 +131,7 @@ func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.contractorService.ToggleFavorite(uint(contractorID), user.ID)
|
response, err := h.contractorService.ToggleFavorite(c.Request().Context(), uint(contractorID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -149,7 +149,7 @@ func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.contractorService.GetContractorTasks(uint(contractorID), user.ID)
|
response, err := h.contractorService.GetContractorTasks(c.Request().Context(), uint(contractorID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -167,7 +167,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_residence_id")
|
return apperrors.BadRequest("error.invalid_residence_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.contractorService.ListContractorsByResidence(uint(residenceID), user.ID)
|
response, err := h.contractorService.ListContractorsByResidence(c.Request().Context(), uint(residenceID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -176,7 +176,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
|
|||||||
|
|
||||||
// GetSpecialties handles GET /api/contractors/specialties/
|
// GetSpecialties handles GET /api/contractors/specialties/
|
||||||
func (h *ContractorHandler) GetSpecialties(c echo.Context) error {
|
func (h *ContractorHandler) GetSpecialties(c echo.Context) error {
|
||||||
specialties, err := h.contractorService.GetSpecialties()
|
specialties, err := h.contractorService.GetSpecialties(c.Request().Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.documentService.ListDocuments(user.ID, filter)
|
response, err := h.documentService.ListDocuments(c.Request().Context(), user.ID, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -88,7 +88,7 @@ func (h *DocumentHandler) GetDocument(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_document_id")
|
return apperrors.BadRequest("error.invalid_document_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.documentService.GetDocument(uint(documentID), user.ID)
|
response, err := h.documentService.GetDocument(c.Request().Context(), uint(documentID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -101,7 +101,7 @@ func (h *DocumentHandler) ListWarranties(c echo.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
response, err := h.documentService.ListWarranties(user.ID)
|
response, err := h.documentService.ListWarranties(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -222,7 +222,7 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.documentService.CreateDocument(&req, user.ID)
|
response, err := h.documentService.CreateDocument(c.Request().Context(), &req, user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -248,7 +248,7 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.documentService.UpdateDocument(uint(documentID), user.ID, &req)
|
response, err := h.documentService.UpdateDocument(c.Request().Context(), uint(documentID), user.ID, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -266,7 +266,7 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_document_id")
|
return apperrors.BadRequest("error.invalid_document_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.documentService.DeleteDocument(uint(documentID), user.ID)
|
err = h.documentService.DeleteDocument(c.Request().Context(), uint(documentID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -284,7 +284,7 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_document_id")
|
return apperrors.BadRequest("error.invalid_document_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.documentService.ActivateDocument(uint(documentID), user.ID)
|
response, err := h.documentService.ActivateDocument(c.Request().Context(), uint(documentID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -302,7 +302,7 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_document_id")
|
return apperrors.BadRequest("error.invalid_document_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.documentService.DeactivateDocument(uint(documentID), user.ID)
|
response, err := h.documentService.DeactivateDocument(c.Request().Context(), uint(documentID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -349,7 +349,7 @@ func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
|
|||||||
|
|
||||||
caption := c.FormValue("caption")
|
caption := c.FormValue("caption")
|
||||||
|
|
||||||
response, err := h.documentService.UploadDocumentImage(uint(documentID), user.ID, result.URL, caption)
|
response, err := h.documentService.UploadDocumentImage(c.Request().Context(), uint(documentID), user.ID, result.URL, caption)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -372,7 +372,7 @@ func (h *DocumentHandler) DeleteDocumentImage(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_image_id")
|
return apperrors.BadRequest("error.invalid_image_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.documentService.DeleteDocumentImage(uint(documentID), uint(imageID), user.ID)
|
response, err := h.documentService.DeleteDocumentImage(c.Request().Context(), uint(documentID), uint(imageID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
notifications, err := h.notificationService.GetNotifications(user.ID, limit, offset)
|
notifications, err := h.notificationService.GetNotifications(c.Request().Context(), user.ID, limit, offset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -64,7 +64,7 @@ func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := h.notificationService.GetUnreadCount(user.ID)
|
count, err := h.notificationService.GetUnreadCount(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -84,7 +84,7 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_notification_id")
|
return apperrors.BadRequest("error.invalid_notification_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.notificationService.MarkAsRead(uint(notificationID), user.ID)
|
err = h.notificationService.MarkAsRead(c.Request().Context(), uint(notificationID), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -99,7 +99,7 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.notificationService.MarkAllAsRead(user.ID)
|
err = h.notificationService.MarkAllAsRead(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -114,7 +114,7 @@ func (h *NotificationHandler) GetPreferences(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
prefs, err := h.notificationService.GetPreferences(user.ID)
|
prefs, err := h.notificationService.GetPreferences(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -137,7 +137,7 @@ func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
prefs, err := h.notificationService.UpdatePreferences(user.ID, &req)
|
prefs, err := h.notificationService.UpdatePreferences(c.Request().Context(), user.ID, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -160,7 +160,7 @@ func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
device, err := h.notificationService.RegisterDevice(user.ID, &req)
|
device, err := h.notificationService.RegisterDevice(c.Request().Context(), user.ID, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -175,7 +175,7 @@ func (h *NotificationHandler) ListDevices(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
devices, err := h.notificationService.ListDevices(user.ID)
|
devices, err := h.notificationService.ListDevices(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -208,7 +208,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_platform")
|
return apperrors.BadRequest("error.invalid_platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
|
err = h.notificationService.UnregisterDevice(c.Request().Context(), req.RegistrationID, req.Platform, user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -236,7 +236,7 @@ func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.invalid_platform")
|
return apperrors.BadRequest("error.invalid_platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.notificationService.DeleteDevice(uint(deviceID), platform, user.ID)
|
err = h.notificationService.DeleteDevice(c.Request().Context(), uint(deviceID), platform, user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ func (h *StaticDataHandler) GetStaticData(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
contractorSpecialties, err := h.contractorService.GetSpecialties()
|
contractorSpecialties, err := h.contractorService.GetSpecialties(c.Request().Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
subscription, err := h.subscriptionService.GetSubscription(user.ID)
|
subscription, err := h.subscriptionService.GetSubscription(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -47,7 +47,7 @@ func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
status, err := h.subscriptionService.GetSubscriptionStatus(user.ID)
|
status, err := h.subscriptionService.GetSubscriptionStatus(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -59,7 +59,7 @@ func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error {
|
|||||||
func (h *SubscriptionHandler) GetUpgradeTrigger(c echo.Context) error {
|
func (h *SubscriptionHandler) GetUpgradeTrigger(c echo.Context) error {
|
||||||
key := c.Param("key")
|
key := c.Param("key")
|
||||||
|
|
||||||
trigger, err := h.subscriptionService.GetUpgradeTrigger(key)
|
trigger, err := h.subscriptionService.GetUpgradeTrigger(c.Request().Context(), key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -69,7 +69,7 @@ func (h *SubscriptionHandler) GetUpgradeTrigger(c echo.Context) error {
|
|||||||
|
|
||||||
// GetAllUpgradeTriggers handles GET /api/subscription/upgrade-triggers/
|
// GetAllUpgradeTriggers handles GET /api/subscription/upgrade-triggers/
|
||||||
func (h *SubscriptionHandler) GetAllUpgradeTriggers(c echo.Context) error {
|
func (h *SubscriptionHandler) GetAllUpgradeTriggers(c echo.Context) error {
|
||||||
triggers, err := h.subscriptionService.GetAllUpgradeTriggers()
|
triggers, err := h.subscriptionService.GetAllUpgradeTriggers(c.Request().Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -79,7 +79,7 @@ func (h *SubscriptionHandler) GetAllUpgradeTriggers(c echo.Context) error {
|
|||||||
|
|
||||||
// GetFeatureBenefits handles GET /api/subscription/features/
|
// GetFeatureBenefits handles GET /api/subscription/features/
|
||||||
func (h *SubscriptionHandler) GetFeatureBenefits(c echo.Context) error {
|
func (h *SubscriptionHandler) GetFeatureBenefits(c echo.Context) error {
|
||||||
benefits, err := h.subscriptionService.GetFeatureBenefits()
|
benefits, err := h.subscriptionService.GetFeatureBenefits(c.Request().Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -94,7 +94,7 @@ func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
promotions, err := h.subscriptionService.GetActivePromotions(user.ID)
|
promotions, err := h.subscriptionService.GetActivePromotions(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -125,12 +125,12 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
|
|||||||
if req.TransactionID == "" && req.ReceiptData == "" {
|
if req.TransactionID == "" && req.ReceiptData == "" {
|
||||||
return apperrors.BadRequest("error.receipt_data_required")
|
return apperrors.BadRequest("error.receipt_data_required")
|
||||||
}
|
}
|
||||||
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID)
|
subscription, err = h.subscriptionService.ProcessApplePurchase(c.Request().Context(), user.ID, req.ReceiptData, req.TransactionID)
|
||||||
case "android":
|
case "android":
|
||||||
if req.PurchaseToken == "" {
|
if req.PurchaseToken == "" {
|
||||||
return apperrors.BadRequest("error.purchase_token_required")
|
return apperrors.BadRequest("error.purchase_token_required")
|
||||||
}
|
}
|
||||||
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID)
|
subscription, err = h.subscriptionService.ProcessGooglePurchase(c.Request().Context(), user.ID, req.PurchaseToken, req.ProductID)
|
||||||
default:
|
default:
|
||||||
return apperrors.BadRequest("error.invalid_platform")
|
return apperrors.BadRequest("error.invalid_platform")
|
||||||
}
|
}
|
||||||
@@ -152,7 +152,7 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
subscription, err := h.subscriptionService.CancelSubscription(user.ID)
|
subscription, err := h.subscriptionService.CancelSubscription(c.Request().Context(), user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -187,12 +187,12 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
|
|||||||
if req.ReceiptData == "" && req.TransactionID == "" {
|
if req.ReceiptData == "" && req.TransactionID == "" {
|
||||||
return apperrors.BadRequest("error.receipt_data_required")
|
return apperrors.BadRequest("error.receipt_data_required")
|
||||||
}
|
}
|
||||||
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID)
|
subscription, err = h.subscriptionService.ProcessApplePurchase(c.Request().Context(), user.ID, req.ReceiptData, req.TransactionID)
|
||||||
case "android":
|
case "android":
|
||||||
if req.PurchaseToken == "" {
|
if req.PurchaseToken == "" {
|
||||||
return apperrors.BadRequest("error.purchase_token_required")
|
return apperrors.BadRequest("error.purchase_token_required")
|
||||||
}
|
}
|
||||||
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID)
|
subscription, err = h.subscriptionService.ProcessGooglePurchase(c.Request().Context(), user.ID, req.PurchaseToken, req.ProductID)
|
||||||
default:
|
default:
|
||||||
return apperrors.BadRequest("error.invalid_platform")
|
return apperrors.BadRequest("error.invalid_platform")
|
||||||
}
|
}
|
||||||
@@ -220,7 +220,7 @@ func (h *SubscriptionHandler) CreateCheckoutSession(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if already Pro from another platform
|
// Check if already Pro from another platform
|
||||||
alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(user.ID, "stripe")
|
alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(c.Request().Context(), user.ID, "stripe")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (h *TaskHandler) ListTasks(c echo.Context) error {
|
|||||||
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
|
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
|
||||||
cachedTZ, _ := c.Get("user_timezone").(string)
|
cachedTZ, _ := c.Get("user_timezone").(string)
|
||||||
if cachedTZ != tzHeader {
|
if cachedTZ != tzHeader {
|
||||||
h.taskService.UpdateUserTimezone(user.ID, tzHeader)
|
h.taskService.UpdateUserTimezone(c.Request().Context(), user.ID, tzHeader)
|
||||||
c.Set("user_timezone", tzHeader)
|
c.Set("user_timezone", tzHeader)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
|
|||||||
|
|
||||||
svc := newTestAuthService(db)
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
resp, err := svc.RefreshToken(token.Key, user.ID)
|
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token")
|
assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token")
|
||||||
assert.Contains(t, resp.Message, "still valid")
|
assert.Contains(t, resp.Message, "still valid")
|
||||||
@@ -87,7 +88,7 @@ func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
|
|||||||
|
|
||||||
svc := newTestAuthService(db)
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
resp, err := svc.RefreshToken(token.Key, user.ID)
|
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
|
assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
|
||||||
assert.Contains(t, resp.Message, "refreshed")
|
assert.Contains(t, resp.Message, "refreshed")
|
||||||
@@ -114,7 +115,7 @@ func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
|
|||||||
|
|
||||||
svc := newTestAuthService(db)
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
resp, err := svc.RefreshToken(token.Key, user.ID)
|
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Nil(t, resp)
|
assert.Nil(t, resp)
|
||||||
assert.Contains(t, err.Error(), "error.token_expired")
|
assert.Contains(t, err.Error(), "error.token_expired")
|
||||||
@@ -129,7 +130,7 @@ func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
|
|||||||
|
|
||||||
svc := newTestAuthService(db)
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
resp, err := svc.RefreshToken(token.Key, user.ID)
|
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed")
|
assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed")
|
||||||
}
|
}
|
||||||
@@ -140,7 +141,7 @@ func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
|
|||||||
|
|
||||||
svc := newTestAuthService(db)
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
resp, err := svc.RefreshToken("nonexistent-token-key", user.ID)
|
resp, err := svc.RefreshToken(context.Background(), "nonexistent-token-key", user.ID)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Nil(t, resp)
|
assert.Nil(t, resp)
|
||||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||||
@@ -154,7 +155,7 @@ func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
|
|||||||
svc := newTestAuthService(db)
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
// Try to refresh with a different user ID
|
// Try to refresh with a different user ID
|
||||||
resp, err := svc.RefreshToken(token.Key, user.ID+999)
|
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID+999)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Nil(t, resp)
|
assert.Nil(t, resp)
|
||||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||||
@@ -167,7 +168,7 @@ func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
|
|||||||
|
|
||||||
svc := newTestAuthService(db)
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
resp, err := svc.RefreshToken(token.Key, user.ID)
|
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed")
|
assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,14 +57,14 @@ func (s *AuthService) SetNotificationRepository(notificationRepo *repositories.N
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Login authenticates a user and returns a token
|
// Login authenticates a user and returns a token
|
||||||
func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginResponse, error) {
|
func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest) (*responses.LoginResponse, error) {
|
||||||
// Find user by username or email
|
// Find user by username or email
|
||||||
identifier := req.Username
|
identifier := req.Username
|
||||||
if identifier == "" {
|
if identifier == "" {
|
||||||
identifier = req.Email
|
identifier = req.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.userRepo.FindByUsernameOrEmail(identifier)
|
user, err := s.userRepo.WithContext(ctx).FindByUsernameOrEmail(identifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, repositories.ErrUserNotFound) {
|
if errors.Is(err, repositories.ErrUserNotFound) {
|
||||||
return nil, apperrors.Unauthorized("error.invalid_credentials")
|
return nil, apperrors.Unauthorized("error.invalid_credentials")
|
||||||
@@ -83,13 +83,13 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get or create auth token
|
// Get or create auth token
|
||||||
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last login
|
// Update last login
|
||||||
if err := s.userRepo.UpdateLastLogin(user.ID); err != nil {
|
if err := s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID); err != nil {
|
||||||
// Log error but don't fail the login
|
// Log error but don't fail the login
|
||||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login")
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login")
|
||||||
}
|
}
|
||||||
@@ -103,9 +103,9 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
|
|||||||
// Register creates a new user account.
|
// Register creates a new user account.
|
||||||
// F-10: User creation, profile creation, notification preferences, and confirmation code
|
// F-10: User creation, profile creation, notification preferences, and confirmation code
|
||||||
// are wrapped in a transaction for atomicity.
|
// are wrapped in a transaction for atomicity.
|
||||||
func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
|
func (s *AuthService) Register(ctx context.Context, req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
|
||||||
// Check if username exists
|
// Check if username exists
|
||||||
exists, err := s.userRepo.ExistsByUsername(req.Username)
|
exists, err := s.userRepo.WithContext(ctx).ExistsByUsername(req.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", apperrors.Internal(err)
|
return nil, "", apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -114,7 +114,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if email exists
|
// Check if email exists
|
||||||
exists, err = s.userRepo.ExistsByEmail(req.Email)
|
exists, err = s.userRepo.WithContext(ctx).ExistsByEmail(req.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", apperrors.Internal(err)
|
return nil, "", apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -146,7 +146,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
|
|||||||
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
||||||
|
|
||||||
// Wrap user creation + profile + notification preferences + confirmation code in a transaction
|
// Wrap user creation + profile + notification preferences + confirmation code in a transaction
|
||||||
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error {
|
txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error {
|
||||||
// Save user
|
// Save user
|
||||||
if err := txRepo.Create(user); err != nil {
|
if err := txRepo.Create(user); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -159,7 +159,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
|
|||||||
|
|
||||||
// Create notification preferences with all options enabled
|
// Create notification preferences with all options enabled
|
||||||
if s.notificationRepo != nil {
|
if s.notificationRepo != nil {
|
||||||
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil {
|
||||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration")
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -176,7 +176,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create auth token (outside transaction since token generation is idempotent)
|
// Create auth token (outside transaction since token generation is idempotent)
|
||||||
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", apperrors.Internal(err)
|
return nil, "", apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -192,7 +192,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
|
|||||||
// - If token is expired (> expiryDays old), returns error (must re-login).
|
// - If token is expired (> expiryDays old), returns error (must re-login).
|
||||||
// - If token is in the renewal window (> refreshDays old), generates a new token.
|
// - If token is in the renewal window (> refreshDays old), generates a new token.
|
||||||
// - If token is still fresh (< refreshDays old), returns the existing token (no-op).
|
// - If token is still fresh (< refreshDays old), returns the existing token (no-op).
|
||||||
func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) {
|
func (s *AuthService) RefreshToken(ctx context.Context, tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) {
|
||||||
expiryDays := s.cfg.Security.TokenExpiryDays
|
expiryDays := s.cfg.Security.TokenExpiryDays
|
||||||
if expiryDays <= 0 {
|
if expiryDays <= 0 {
|
||||||
expiryDays = 90
|
expiryDays = 90
|
||||||
@@ -203,7 +203,7 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Look up the token
|
// Look up the token
|
||||||
authToken, err := s.userRepo.FindTokenByKey(tokenKey)
|
authToken, err := s.userRepo.WithContext(ctx).FindTokenByKey(tokenKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Unauthorized("error.invalid_token")
|
return nil, apperrors.Unauthorized("error.invalid_token")
|
||||||
}
|
}
|
||||||
@@ -232,12 +232,12 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref
|
|||||||
|
|
||||||
// Token is in the renewal window — generate a new one
|
// Token is in the renewal window — generate a new one
|
||||||
// Delete the old token
|
// Delete the old token
|
||||||
if err := s.userRepo.DeleteToken(tokenKey); err != nil {
|
if err := s.userRepo.WithContext(ctx).DeleteToken(tokenKey); err != nil {
|
||||||
log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh")
|
log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new token
|
// Create a new token
|
||||||
newToken, err := s.userRepo.CreateToken(userID)
|
newToken, err := s.userRepo.WithContext(ctx).CreateToken(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -249,18 +249,18 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Logout invalidates a user's token
|
// Logout invalidates a user's token
|
||||||
func (s *AuthService) Logout(token string) error {
|
func (s *AuthService) Logout(ctx context.Context, token string) error {
|
||||||
return s.userRepo.DeleteToken(token)
|
return s.userRepo.WithContext(ctx).DeleteToken(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCurrentUser returns the current authenticated user with profile
|
// GetCurrentUser returns the current authenticated user with profile
|
||||||
func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserResponse, error) {
|
func (s *AuthService) GetCurrentUser(ctx context.Context, userID uint) (*responses.CurrentUserResponse, error) {
|
||||||
user, err := s.userRepo.FindByIDWithProfile(userID)
|
user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
authProvider, err := s.userRepo.FindAuthProvider(userID)
|
authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log but don't fail - default to "email"
|
// Log but don't fail - default to "email"
|
||||||
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider")
|
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider")
|
||||||
@@ -275,9 +275,9 @@ func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserRespons
|
|||||||
// For email auth users, password verification is required.
|
// For email auth users, password verification is required.
|
||||||
// For social auth users, confirmation string "DELETE" is required.
|
// For social auth users, confirmation string "DELETE" is required.
|
||||||
// Returns a list of file URLs that need to be deleted from disk.
|
// Returns a list of file URLs that need to be deleted from disk.
|
||||||
func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string) ([]string, error) {
|
func (s *AuthService) DeleteAccount(ctx context.Context, userID uint, password, confirmation *string) ([]string, error) {
|
||||||
// Fetch user
|
// Fetch user
|
||||||
user, err := s.userRepo.FindByID(userID)
|
user, err := s.userRepo.WithContext(ctx).FindByID(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, repositories.ErrUserNotFound) {
|
if errors.Is(err, repositories.ErrUserNotFound) {
|
||||||
return nil, apperrors.NotFound("error.user_not_found")
|
return nil, apperrors.NotFound("error.user_not_found")
|
||||||
@@ -286,7 +286,7 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Determine auth provider
|
// Determine auth provider
|
||||||
authProvider, err := s.userRepo.FindAuthProvider(userID)
|
authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -308,7 +308,7 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string)
|
|||||||
|
|
||||||
// Start transaction and cascade delete
|
// Start transaction and cascade delete
|
||||||
var fileURLs []string
|
var fileURLs []string
|
||||||
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error {
|
txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error {
|
||||||
urls, err := txRepo.DeleteUserCascade(userID)
|
urls, err := txRepo.DeleteUserCascade(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -324,15 +324,15 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProfile updates a user's profile
|
// UpdateProfile updates a user's profile
|
||||||
func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) {
|
func (s *AuthService) UpdateProfile(ctx context.Context, userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) {
|
||||||
user, err := s.userRepo.FindByID(userID)
|
user, err := s.userRepo.WithContext(ctx).FindByID(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if new email is taken (if email is being changed)
|
// Check if new email is taken (if email is being changed)
|
||||||
if req.Email != nil && *req.Email != user.Email {
|
if req.Email != nil && *req.Email != user.Email {
|
||||||
exists, err := s.userRepo.ExistsByEmail(*req.Email)
|
exists, err := s.userRepo.WithContext(ctx).ExistsByEmail(*req.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -349,17 +349,17 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ
|
|||||||
user.LastName = *req.LastName
|
user.LastName = *req.LastName
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.userRepo.Update(user); err != nil {
|
if err := s.userRepo.WithContext(ctx).Update(user); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload with profile
|
// Reload with profile
|
||||||
user, err = s.userRepo.FindByIDWithProfile(userID)
|
user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
authProvider, err := s.userRepo.FindAuthProvider(userID)
|
authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider")
|
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider")
|
||||||
authProvider = "email"
|
authProvider = "email"
|
||||||
@@ -370,9 +370,9 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ
|
|||||||
}
|
}
|
||||||
|
|
||||||
// VerifyEmail verifies a user's email with a confirmation code
|
// VerifyEmail verifies a user's email with a confirmation code
|
||||||
func (s *AuthService) VerifyEmail(userID uint, code string) error {
|
func (s *AuthService) VerifyEmail(ctx context.Context, userID uint, code string) error {
|
||||||
// Get user profile
|
// Get user profile
|
||||||
profile, err := s.userRepo.GetOrCreateProfile(userID)
|
profile, err := s.userRepo.WithContext(ctx).GetOrCreateProfile(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -384,14 +384,14 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
|
|||||||
|
|
||||||
// Check for test code when DEBUG_FIXED_CODES is enabled
|
// Check for test code when DEBUG_FIXED_CODES is enabled
|
||||||
if s.cfg.Server.DebugFixedCodes && code == "123456" {
|
if s.cfg.Server.DebugFixedCodes && code == "123456" {
|
||||||
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
|
if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find and validate confirmation code
|
// Find and validate confirmation code
|
||||||
confirmCode, err := s.userRepo.FindConfirmationCode(userID, code)
|
confirmCode, err := s.userRepo.WithContext(ctx).FindConfirmationCode(userID, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, repositories.ErrCodeNotFound) {
|
if errors.Is(err, repositories.ErrCodeNotFound) {
|
||||||
return apperrors.BadRequest("error.invalid_verification_code")
|
return apperrors.BadRequest("error.invalid_verification_code")
|
||||||
@@ -403,12 +403,12 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Mark code as used
|
// Mark code as used
|
||||||
if err := s.userRepo.MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
|
if err := s.userRepo.WithContext(ctx).MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set profile as verified
|
// Set profile as verified
|
||||||
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
|
if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -416,9 +416,9 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResendVerificationCode creates and returns a new verification code
|
// ResendVerificationCode creates and returns a new verification code
|
||||||
func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
|
func (s *AuthService) ResendVerificationCode(ctx context.Context, userID uint) (string, error) {
|
||||||
// Get user profile
|
// Get user profile
|
||||||
profile, err := s.userRepo.GetOrCreateProfile(userID)
|
profile, err := s.userRepo.WithContext(ctx).GetOrCreateProfile(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", apperrors.Internal(err)
|
return "", apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -437,7 +437,7 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
|
|||||||
}
|
}
|
||||||
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
||||||
|
|
||||||
if _, err := s.userRepo.CreateConfirmationCode(userID, code, expiresAt); err != nil {
|
if _, err := s.userRepo.WithContext(ctx).CreateConfirmationCode(userID, code, expiresAt); err != nil {
|
||||||
return "", apperrors.Internal(err)
|
return "", apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -445,9 +445,9 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ForgotPassword initiates the password reset process
|
// ForgotPassword initiates the password reset process
|
||||||
func (s *AuthService) ForgotPassword(email string) (string, *models.User, error) {
|
func (s *AuthService) ForgotPassword(ctx context.Context, email string) (string, *models.User, error) {
|
||||||
// Find user by email
|
// Find user by email
|
||||||
user, err := s.userRepo.FindByEmail(email)
|
user, err := s.userRepo.WithContext(ctx).FindByEmail(email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, repositories.ErrUserNotFound) {
|
if errors.Is(err, repositories.ErrUserNotFound) {
|
||||||
// Don't reveal that the email doesn't exist
|
// Don't reveal that the email doesn't exist
|
||||||
@@ -457,7 +457,7 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check rate limit
|
// Check rate limit
|
||||||
count, err := s.userRepo.CountRecentPasswordResetRequests(user.ID)
|
count, err := s.userRepo.WithContext(ctx).CountRecentPasswordResetRequests(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, apperrors.Internal(err)
|
return "", nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -481,7 +481,7 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
|
|||||||
return "", nil, apperrors.Internal(err)
|
return "", nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.userRepo.CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil {
|
if _, err := s.userRepo.WithContext(ctx).CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil {
|
||||||
return "", nil, apperrors.Internal(err)
|
return "", nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -489,9 +489,9 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// VerifyResetCode verifies a password reset code and returns a reset token
|
// VerifyResetCode verifies a password reset code and returns a reset token
|
||||||
func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
|
func (s *AuthService) VerifyResetCode(ctx context.Context, email, code string) (string, error) {
|
||||||
// Find the reset code
|
// Find the reset code
|
||||||
resetCode, user, err := s.userRepo.FindPasswordResetCodeByEmail(email)
|
resetCode, user, err := s.userRepo.WithContext(ctx).FindPasswordResetCodeByEmail(email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, repositories.ErrUserNotFound) || errors.Is(err, repositories.ErrCodeNotFound) {
|
if errors.Is(err, repositories.ErrUserNotFound) || errors.Is(err, repositories.ErrCodeNotFound) {
|
||||||
return "", apperrors.BadRequest("error.invalid_verification_code")
|
return "", apperrors.BadRequest("error.invalid_verification_code")
|
||||||
@@ -507,7 +507,7 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
|
|||||||
// Verify the code
|
// Verify the code
|
||||||
if !resetCode.CheckCode(code) {
|
if !resetCode.CheckCode(code) {
|
||||||
// Increment attempts
|
// Increment attempts
|
||||||
s.userRepo.IncrementResetCodeAttempts(resetCode.ID)
|
s.userRepo.WithContext(ctx).IncrementResetCodeAttempts(resetCode.ID)
|
||||||
return "", apperrors.BadRequest("error.invalid_verification_code")
|
return "", apperrors.BadRequest("error.invalid_verification_code")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -528,9 +528,9 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResetPassword resets the user's password using a reset token
|
// ResetPassword resets the user's password using a reset token
|
||||||
func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
|
func (s *AuthService) ResetPassword(ctx context.Context, resetToken, newPassword string) error {
|
||||||
// Find the reset code by token
|
// Find the reset code by token
|
||||||
resetCode, err := s.userRepo.FindPasswordResetCodeByToken(resetToken)
|
resetCode, err := s.userRepo.WithContext(ctx).FindPasswordResetCodeByToken(resetToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, repositories.ErrCodeNotFound) || errors.Is(err, repositories.ErrCodeExpired) {
|
if errors.Is(err, repositories.ErrCodeNotFound) || errors.Is(err, repositories.ErrCodeExpired) {
|
||||||
return apperrors.BadRequest("error.invalid_reset_token")
|
return apperrors.BadRequest("error.invalid_reset_token")
|
||||||
@@ -539,7 +539,7 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get the user
|
// Get the user
|
||||||
user, err := s.userRepo.FindByID(resetCode.UserID)
|
user, err := s.userRepo.WithContext(ctx).FindByID(resetCode.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -549,18 +549,18 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
|
|||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.userRepo.Update(user); err != nil {
|
if err := s.userRepo.WithContext(ctx).Update(user); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark reset code as used
|
// Mark reset code as used
|
||||||
if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
|
if err := s.userRepo.WithContext(ctx).MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
|
||||||
// Log error but don't fail
|
// Log error but don't fail
|
||||||
log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used")
|
log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invalidate all existing tokens for this user (security measure)
|
// Invalidate all existing tokens for this user (security measure)
|
||||||
if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil {
|
if err := s.userRepo.WithContext(ctx).DeleteTokenByUserID(user.ID); err != nil {
|
||||||
// Log error but don't fail
|
// Log error but don't fail
|
||||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset")
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset")
|
||||||
}
|
}
|
||||||
@@ -583,10 +583,10 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. Check if this Apple ID is already linked to an account
|
// 2. Check if this Apple ID is already linked to an account
|
||||||
existingAuth, err := s.userRepo.FindByAppleID(appleID)
|
existingAuth, err := s.userRepo.WithContext(ctx).FindByAppleID(appleID)
|
||||||
if err == nil && existingAuth != nil {
|
if err == nil && existingAuth != nil {
|
||||||
// User already linked with this Apple ID - log them in
|
// User already linked with this Apple ID - log them in
|
||||||
user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID)
|
user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(existingAuth.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -596,13 +596,13 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get or create token
|
// Get or create token
|
||||||
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last login
|
// Update last login
|
||||||
_ = s.userRepo.UpdateLastLogin(user.ID)
|
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
|
||||||
|
|
||||||
return &responses.AppleSignInResponse{
|
return &responses.AppleSignInResponse{
|
||||||
Token: token.Key,
|
Token: token.Key,
|
||||||
@@ -614,7 +614,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
|||||||
// 3. Check if email matches an existing user (for account linking)
|
// 3. Check if email matches an existing user (for account linking)
|
||||||
email := getEmailFromRequest(req.Email, claims.Email)
|
email := getEmailFromRequest(req.Email, claims.Email)
|
||||||
if email != "" {
|
if email != "" {
|
||||||
existingUser, err := s.userRepo.FindByEmail(email)
|
existingUser, err := s.userRepo.WithContext(ctx).FindByEmail(email)
|
||||||
if err == nil && existingUser != nil {
|
if err == nil && existingUser != nil {
|
||||||
// S-06: Log auto-linking of social account to existing user
|
// S-06: Log auto-linking of social account to existing user
|
||||||
log.Warn().
|
log.Warn().
|
||||||
@@ -630,24 +630,24 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
|||||||
Email: email,
|
Email: email,
|
||||||
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
|
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
|
||||||
}
|
}
|
||||||
if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil {
|
if err := s.userRepo.WithContext(ctx).CreateAppleSocialAuth(appleAuthRecord); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark as verified since Apple verified the email
|
// Mark as verified since Apple verified the email
|
||||||
_ = s.userRepo.SetProfileVerified(existingUser.ID, true)
|
_ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true)
|
||||||
|
|
||||||
// Get or create token
|
// Get or create token
|
||||||
token, err := s.userRepo.GetOrCreateToken(existingUser.ID)
|
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last login
|
// Update last login
|
||||||
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
|
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(existingUser.ID)
|
||||||
|
|
||||||
// B-08: Check error from FindByIDWithProfile
|
// B-08: Check error from FindByIDWithProfile
|
||||||
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
existingUser, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(existingUser.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -675,19 +675,19 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
|||||||
randomPassword := generateResetToken()
|
randomPassword := generateResetToken()
|
||||||
_ = user.SetPassword(randomPassword)
|
_ = user.SetPassword(randomPassword)
|
||||||
|
|
||||||
if err := s.userRepo.Create(user); err != nil {
|
if err := s.userRepo.WithContext(ctx).Create(user); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create profile (already verified since Apple verified)
|
// Create profile (already verified since Apple verified)
|
||||||
profile, _ := s.userRepo.GetOrCreateProfile(user.ID)
|
profile, _ := s.userRepo.WithContext(ctx).GetOrCreateProfile(user.ID)
|
||||||
if profile != nil {
|
if profile != nil {
|
||||||
_ = s.userRepo.SetProfileVerified(user.ID, true)
|
_ = s.userRepo.WithContext(ctx).SetProfileVerified(user.ID, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create notification preferences with all options enabled
|
// Create notification preferences with all options enabled
|
||||||
if s.notificationRepo != nil {
|
if s.notificationRepo != nil {
|
||||||
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil {
|
||||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user")
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -699,18 +699,18 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
|||||||
Email: getEmailOrDefault(email),
|
Email: getEmailOrDefault(email),
|
||||||
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
|
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
|
||||||
}
|
}
|
||||||
if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil {
|
if err := s.userRepo.WithContext(ctx).CreateAppleSocialAuth(appleAuthRecord); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create token
|
// Create token
|
||||||
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// B-08: Check error from FindByIDWithProfile
|
// B-08: Check error from FindByIDWithProfile
|
||||||
user, err = s.userRepo.FindByIDWithProfile(user.ID)
|
user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -736,10 +736,10 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. Check if this Google ID is already linked to an account
|
// 2. Check if this Google ID is already linked to an account
|
||||||
existingAuth, err := s.userRepo.FindByGoogleID(googleID)
|
existingAuth, err := s.userRepo.WithContext(ctx).FindByGoogleID(googleID)
|
||||||
if err == nil && existingAuth != nil {
|
if err == nil && existingAuth != nil {
|
||||||
// User already linked with this Google ID - log them in
|
// User already linked with this Google ID - log them in
|
||||||
user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID)
|
user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(existingAuth.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -749,13 +749,13 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get or create token
|
// Get or create token
|
||||||
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last login
|
// Update last login
|
||||||
_ = s.userRepo.UpdateLastLogin(user.ID)
|
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
|
||||||
|
|
||||||
return &responses.GoogleSignInResponse{
|
return &responses.GoogleSignInResponse{
|
||||||
Token: token.Key,
|
Token: token.Key,
|
||||||
@@ -767,7 +767,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
|||||||
// 3. Check if email matches an existing user (for account linking)
|
// 3. Check if email matches an existing user (for account linking)
|
||||||
email := tokenInfo.Email
|
email := tokenInfo.Email
|
||||||
if email != "" {
|
if email != "" {
|
||||||
existingUser, err := s.userRepo.FindByEmail(email)
|
existingUser, err := s.userRepo.WithContext(ctx).FindByEmail(email)
|
||||||
if err == nil && existingUser != nil {
|
if err == nil && existingUser != nil {
|
||||||
// S-06: Log auto-linking of social account to existing user
|
// S-06: Log auto-linking of social account to existing user
|
||||||
log.Warn().
|
log.Warn().
|
||||||
@@ -784,26 +784,26 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
|||||||
Name: tokenInfo.Name,
|
Name: tokenInfo.Name,
|
||||||
Picture: tokenInfo.Picture,
|
Picture: tokenInfo.Picture,
|
||||||
}
|
}
|
||||||
if err := s.userRepo.CreateGoogleSocialAuth(googleAuthRecord); err != nil {
|
if err := s.userRepo.WithContext(ctx).CreateGoogleSocialAuth(googleAuthRecord); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark as verified since Google verified the email
|
// Mark as verified since Google verified the email
|
||||||
if tokenInfo.IsEmailVerified() {
|
if tokenInfo.IsEmailVerified() {
|
||||||
_ = s.userRepo.SetProfileVerified(existingUser.ID, true)
|
_ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get or create token
|
// Get or create token
|
||||||
token, err := s.userRepo.GetOrCreateToken(existingUser.ID)
|
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last login
|
// Update last login
|
||||||
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
|
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(existingUser.ID)
|
||||||
|
|
||||||
// B-08: Check error from FindByIDWithProfile
|
// B-08: Check error from FindByIDWithProfile
|
||||||
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
existingUser, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(existingUser.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -831,19 +831,19 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
|||||||
randomPassword := generateResetToken()
|
randomPassword := generateResetToken()
|
||||||
_ = user.SetPassword(randomPassword)
|
_ = user.SetPassword(randomPassword)
|
||||||
|
|
||||||
if err := s.userRepo.Create(user); err != nil {
|
if err := s.userRepo.WithContext(ctx).Create(user); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create profile (already verified if Google verified email)
|
// Create profile (already verified if Google verified email)
|
||||||
profile, _ := s.userRepo.GetOrCreateProfile(user.ID)
|
profile, _ := s.userRepo.WithContext(ctx).GetOrCreateProfile(user.ID)
|
||||||
if profile != nil && tokenInfo.IsEmailVerified() {
|
if profile != nil && tokenInfo.IsEmailVerified() {
|
||||||
_ = s.userRepo.SetProfileVerified(user.ID, true)
|
_ = s.userRepo.WithContext(ctx).SetProfileVerified(user.ID, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create notification preferences with all options enabled
|
// Create notification preferences with all options enabled
|
||||||
if s.notificationRepo != nil {
|
if s.notificationRepo != nil {
|
||||||
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil {
|
||||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user")
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -856,18 +856,18 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
|||||||
Name: tokenInfo.Name,
|
Name: tokenInfo.Name,
|
||||||
Picture: tokenInfo.Picture,
|
Picture: tokenInfo.Picture,
|
||||||
}
|
}
|
||||||
if err := s.userRepo.CreateGoogleSocialAuth(googleAuthRecord); err != nil {
|
if err := s.userRepo.WithContext(ctx).CreateGoogleSocialAuth(googleAuthRecord); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create token
|
// Create token
|
||||||
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// B-08: Check error from FindByIDWithProfile
|
// B-08: Check error from FindByIDWithProfile
|
||||||
user, err = s.userRepo.FindByIDWithProfile(user.ID)
|
user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -53,7 +54,7 @@ func TestAuthService_Login(t *testing.T) {
|
|||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.Login(req)
|
resp, err := service.Login(context.Background(), req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, resp.Token)
|
assert.NotEmpty(t, resp.Token)
|
||||||
assert.Equal(t, "testuser", resp.User.Username)
|
assert.Equal(t, "testuser", resp.User.Username)
|
||||||
@@ -74,7 +75,7 @@ func TestAuthService_Login_ByEmail(t *testing.T) {
|
|||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.Login(req)
|
resp, err := service.Login(context.Background(), req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, resp.Token)
|
assert.NotEmpty(t, resp.Token)
|
||||||
}
|
}
|
||||||
@@ -94,7 +95,7 @@ func TestAuthService_Login_InvalidCredentials(t *testing.T) {
|
|||||||
Password: "WrongPassword1",
|
Password: "WrongPassword1",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.Login(req)
|
_, err := service.Login(context.Background(), req)
|
||||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,7 +112,7 @@ func TestAuthService_Login_UserNotFound(t *testing.T) {
|
|||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.Login(req)
|
_, err := service.Login(context.Background(), req)
|
||||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,7 +134,7 @@ func TestAuthService_Login_InactiveUser(t *testing.T) {
|
|||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.Login(req)
|
_, err := service.Login(context.Background(), req)
|
||||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive")
|
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,7 +149,7 @@ func TestAuthService_Register(t *testing.T) {
|
|||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, code, err := service.Register(req)
|
resp, code, err := service.Register(context.Background(), req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, resp.Token)
|
assert.NotEmpty(t, resp.Token)
|
||||||
assert.Equal(t, "newuser", resp.User.Username)
|
assert.Equal(t, "newuser", resp.User.Username)
|
||||||
@@ -172,7 +173,7 @@ func TestAuthService_Register_DuplicateUsername(t *testing.T) {
|
|||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err := service.Register(req)
|
_, _, err := service.Register(context.Background(), req)
|
||||||
testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken")
|
testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,7 +194,7 @@ func TestAuthService_Register_DuplicateEmail(t *testing.T) {
|
|||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err := service.Register(req)
|
_, _, err := service.Register(context.Background(), req)
|
||||||
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken")
|
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,7 +212,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
|
|||||||
// Create profile
|
// Create profile
|
||||||
userRepo.GetOrCreateProfile(user.ID)
|
userRepo.GetOrCreateProfile(user.ID)
|
||||||
|
|
||||||
resp, err := service.GetCurrentUser(user.ID)
|
resp, err := service.GetCurrentUser(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "testuser", resp.Username)
|
assert.Equal(t, "testuser", resp.Username)
|
||||||
assert.Equal(t, "test@test.com", resp.Email)
|
assert.Equal(t, "test@test.com", resp.Email)
|
||||||
@@ -238,7 +239,7 @@ func TestAuthService_UpdateProfile(t *testing.T) {
|
|||||||
LastName: &newLast,
|
LastName: &newLast,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdateProfile(user.ID, req)
|
resp, err := service.UpdateProfile(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "John", resp.FirstName)
|
assert.Equal(t, "John", resp.FirstName)
|
||||||
assert.Equal(t, "Doe", resp.LastName)
|
assert.Equal(t, "Doe", resp.LastName)
|
||||||
@@ -261,7 +262,7 @@ func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
|
|||||||
Email: &takenEmail,
|
Email: &takenEmail,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.UpdateProfile(user2.ID, req)
|
_, err := service.UpdateProfile(context.Background(), user2.ID, req)
|
||||||
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_already_taken")
|
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_already_taken")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -282,7 +283,7 @@ func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Same email should not trigger duplicate error
|
// Same email should not trigger duplicate error
|
||||||
resp, err := service.UpdateProfile(user.ID, req)
|
resp, err := service.UpdateProfile(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "test@test.com", resp.Email)
|
assert.Equal(t, "test@test.com", resp.Email)
|
||||||
}
|
}
|
||||||
@@ -298,7 +299,7 @@ func TestAuthService_VerifyEmail(t *testing.T) {
|
|||||||
Email: "new@test.com",
|
Email: "new@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(req)
|
_, _, err := service.Register(context.Background(), req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Get the user ID
|
// Get the user ID
|
||||||
@@ -306,11 +307,11 @@ func TestAuthService_VerifyEmail(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify with the debug code
|
// Verify with the debug code
|
||||||
err = service.VerifyEmail(user.ID, "123456")
|
err = service.VerifyEmail(context.Background(), user.ID, "123456")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify again — should get already verified error
|
// Verify again — should get already verified error
|
||||||
err = service.VerifyEmail(user.ID, "123456")
|
err = service.VerifyEmail(context.Background(), user.ID, "123456")
|
||||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
|
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,7 +324,7 @@ func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
|
|||||||
Email: "new@test.com",
|
Email: "new@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(req)
|
_, _, err := service.Register(context.Background(), req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||||
@@ -331,7 +332,7 @@ func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
|
|||||||
|
|
||||||
// Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup,
|
// Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup,
|
||||||
// but a wrong code should use the normal path
|
// but a wrong code should use the normal path
|
||||||
err = service.VerifyEmail(user.ID, "000000")
|
err = service.VerifyEmail(context.Background(), user.ID, "000000")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -346,13 +347,13 @@ func TestAuthService_ResendVerificationCode(t *testing.T) {
|
|||||||
Email: "new@test.com",
|
Email: "new@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(req)
|
_, _, err := service.Register(context.Background(), req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
code, err := service.ResendVerificationCode(user.ID)
|
code, err := service.ResendVerificationCode(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "123456", code) // DebugFixedCodes
|
assert.Equal(t, "123456", code) // DebugFixedCodes
|
||||||
}
|
}
|
||||||
@@ -366,16 +367,16 @@ func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) {
|
|||||||
Email: "new@test.com",
|
Email: "new@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(req)
|
_, _, err := service.Register(context.Background(), req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = service.VerifyEmail(user.ID, "123456")
|
err = service.VerifyEmail(context.Background(), user.ID, "123456")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = service.ResendVerificationCode(user.ID)
|
_, err = service.ResendVerificationCode(context.Background(), user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
|
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -390,10 +391,10 @@ func TestAuthService_ForgotPassword(t *testing.T) {
|
|||||||
Email: "test@test.com",
|
Email: "test@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(registerReq)
|
_, _, err := service.Register(context.Background(), registerReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
code, user, err := service.ForgotPassword("test@test.com")
|
code, user, err := service.ForgotPassword(context.Background(), "test@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "123456", code) // DebugFixedCodes
|
assert.Equal(t, "123456", code) // DebugFixedCodes
|
||||||
assert.NotNil(t, user)
|
assert.NotNil(t, user)
|
||||||
@@ -404,7 +405,7 @@ func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) {
|
|||||||
service, _ := setupAuthService(t)
|
service, _ := setupAuthService(t)
|
||||||
|
|
||||||
// Should not reveal that email doesn't exist
|
// Should not reveal that email doesn't exist
|
||||||
code, user, err := service.ForgotPassword("nonexistent@test.com")
|
code, user, err := service.ForgotPassword(context.Background(), "nonexistent@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, code)
|
assert.Empty(t, code)
|
||||||
assert.Nil(t, user)
|
assert.Nil(t, user)
|
||||||
@@ -421,20 +422,20 @@ func TestAuthService_ResetPassword(t *testing.T) {
|
|||||||
Email: "test@test.com",
|
Email: "test@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(registerReq)
|
_, _, err := service.Register(context.Background(), registerReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Forgot password
|
// Forgot password
|
||||||
_, _, err = service.ForgotPassword("test@test.com")
|
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify reset code to get the token
|
// Verify reset code to get the token
|
||||||
resetToken, err := service.VerifyResetCode("test@test.com", "123456")
|
resetToken, err := service.VerifyResetCode(context.Background(), "test@test.com", "123456")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, resetToken)
|
assert.NotEmpty(t, resetToken)
|
||||||
|
|
||||||
// Reset password
|
// Reset password
|
||||||
err = service.ResetPassword(resetToken, "NewPassword123")
|
err = service.ResetPassword(context.Background(), resetToken, "NewPassword123")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Login with new password
|
// Login with new password
|
||||||
@@ -442,7 +443,7 @@ func TestAuthService_ResetPassword(t *testing.T) {
|
|||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "NewPassword123",
|
Password: "NewPassword123",
|
||||||
}
|
}
|
||||||
loginResp, err := service.Login(loginReq)
|
loginResp, err := service.Login(context.Background(), loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, loginResp.Token)
|
assert.NotEmpty(t, loginResp.Token)
|
||||||
}
|
}
|
||||||
@@ -450,7 +451,7 @@ func TestAuthService_ResetPassword(t *testing.T) {
|
|||||||
func TestAuthService_ResetPassword_InvalidToken(t *testing.T) {
|
func TestAuthService_ResetPassword_InvalidToken(t *testing.T) {
|
||||||
service, _ := setupAuthService(t)
|
service, _ := setupAuthService(t)
|
||||||
|
|
||||||
err := service.ResetPassword("invalid-token", "NewPassword123")
|
err := service.ResetPassword(context.Background(), "invalid-token", "NewPassword123")
|
||||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token")
|
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -471,15 +472,15 @@ func TestAuthService_Logout(t *testing.T) {
|
|||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
loginResp, err := service.Login(loginReq)
|
loginResp, err := service.Login(context.Background(), loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Logout
|
// Logout
|
||||||
err = service.Logout(loginResp.Token)
|
err = service.Logout(context.Background(), loginResp.Token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Token should be deleted — refreshing should fail
|
// Token should be deleted — refreshing should fail
|
||||||
_, err = service.RefreshToken(loginResp.Token, user.ID)
|
_, err = service.RefreshToken(context.Background(), loginResp.Token, user.ID)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -494,14 +495,14 @@ func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) {
|
|||||||
Email: "test@test.com",
|
Email: "test@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(registerReq)
|
_, _, err := service.Register(context.Background(), registerReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
password := "Password123"
|
password := "Password123"
|
||||||
_, err = service.DeleteAccount(user.ID, &password, nil)
|
_, err = service.DeleteAccount(context.Background(), user.ID, &password, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -513,14 +514,14 @@ func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) {
|
|||||||
Email: "test@test.com",
|
Email: "test@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(registerReq)
|
_, _, err := service.Register(context.Background(), registerReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
wrongPassword := "WrongPassword1"
|
wrongPassword := "WrongPassword1"
|
||||||
_, err = service.DeleteAccount(user.ID, &wrongPassword, nil)
|
_, err = service.DeleteAccount(context.Background(), user.ID, &wrongPassword, nil)
|
||||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -532,13 +533,13 @@ func TestAuthService_DeleteAccount_NoPassword(t *testing.T) {
|
|||||||
Email: "test@test.com",
|
Email: "test@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(registerReq)
|
_, _, err := service.Register(context.Background(), registerReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = service.DeleteAccount(user.ID, nil, nil)
|
_, err = service.DeleteAccount(context.Background(), user.ID, nil, nil)
|
||||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
|
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -546,7 +547,7 @@ func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
|
|||||||
service, _ := setupAuthService(t)
|
service, _ := setupAuthService(t)
|
||||||
|
|
||||||
password := "Password123"
|
password := "Password123"
|
||||||
_, err := service.DeleteAccount(99999, &password, nil)
|
_, err := service.DeleteAccount(context.Background(), 99999, &password, nil)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -658,7 +659,7 @@ func TestAuthService_Login_EmptyPassword(t *testing.T) {
|
|||||||
Password: "",
|
Password: "",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.Login(req)
|
_, err := service.Login(context.Background(), req)
|
||||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -672,17 +673,17 @@ func TestAuthService_ForgotPassword_RateLimit(t *testing.T) {
|
|||||||
Email: "test@test.com",
|
Email: "test@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(registerReq)
|
_, _, err := service.Register(context.Background(), registerReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Make max allowed reset requests (3 based on setup)
|
// Make max allowed reset requests (3 based on setup)
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
_, _, err := service.ForgotPassword("test@test.com")
|
_, _, err := service.ForgotPassword(context.Background(), "test@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The 4th should be rate limited
|
// The 4th should be rate limited
|
||||||
_, _, err = service.ForgotPassword("test@test.com")
|
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -696,14 +697,14 @@ func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
|
|||||||
Email: "test@test.com",
|
Email: "test@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(registerReq)
|
_, _, err := service.Register(context.Background(), registerReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, _, err = service.ForgotPassword("test@test.com")
|
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Wrong code but with debug mode, "123456" works, "000000" should fail
|
// Wrong code but with debug mode, "123456" works, "000000" should fail
|
||||||
_, err = service.VerifyResetCode("test@test.com", "000000")
|
_, err = service.VerifyResetCode(context.Background(), "test@test.com", "000000")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -712,7 +713,7 @@ func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
|
|||||||
func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) {
|
func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) {
|
||||||
service, _ := setupAuthService(t)
|
service, _ := setupAuthService(t)
|
||||||
|
|
||||||
_, err := service.VerifyResetCode("nonexistent@test.com", "123456")
|
_, err := service.VerifyResetCode(context.Background(), "nonexistent@test.com", "123456")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -734,7 +735,7 @@ func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
|
|||||||
Email: &newEmail,
|
Email: &newEmail,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdateProfile(user.ID, req)
|
resp, err := service.UpdateProfile(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "newemail@test.com", resp.Email)
|
assert.Equal(t, "newemail@test.com", resp.Email)
|
||||||
}
|
}
|
||||||
@@ -749,14 +750,14 @@ func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) {
|
|||||||
Email: "test@test.com",
|
Email: "test@test.com",
|
||||||
Password: "Password123",
|
Password: "Password123",
|
||||||
}
|
}
|
||||||
_, _, err := service.Register(registerReq)
|
_, _, err := service.Register(context.Background(), registerReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
emptyPw := ""
|
emptyPw := ""
|
||||||
_, err = service.DeleteAccount(user.ID, &emptyPw, nil)
|
_, err = service.DeleteAccount(context.Background(), user.ID, &emptyPw, nil)
|
||||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
|
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -789,7 +790,7 @@ func TestAuthService_Register_CreatesProfile(t *testing.T) {
|
|||||||
LastName: "Doe",
|
LastName: "Doe",
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, _, err := service.Register(req)
|
resp, _, err := service.Register(context.Background(), req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "profileuser", resp.User.Username)
|
assert.Equal(t, "profileuser", resp.User.Username)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -33,8 +34,8 @@ func NewContractorService(contractorRepo *repositories.ContractorRepository, res
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetContractor gets a contractor by ID with access check
|
// GetContractor gets a contractor by ID with access check
|
||||||
func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses.ContractorResponse, error) {
|
func (s *ContractorService) GetContractor(ctx context.Context, contractorID, userID uint) (*responses.ContractorResponse, error) {
|
||||||
contractor, err := s.contractorRepo.FindByID(contractorID)
|
contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.contractor_not_found")
|
return nil, apperrors.NotFound("error.contractor_not_found")
|
||||||
@@ -43,7 +44,7 @@ func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access
|
// Check access
|
||||||
if !s.hasContractorAccess(contractor, userID) {
|
if !s.hasContractorAccess(ctx, contractor, userID) {
|
||||||
return nil, apperrors.Forbidden("error.contractor_access_denied")
|
return nil, apperrors.Forbidden("error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,14 +56,14 @@ func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses
|
|||||||
// Access rules:
|
// Access rules:
|
||||||
// - If contractor has no residence: only the creator has access
|
// - If contractor has no residence: only the creator has access
|
||||||
// - If contractor has a residence: all users with access to that residence
|
// - If contractor has a residence: all users with access to that residence
|
||||||
func (s *ContractorService) hasContractorAccess(contractor *models.Contractor, userID uint) bool {
|
func (s *ContractorService) hasContractorAccess(ctx context.Context, contractor *models.Contractor, userID uint) bool {
|
||||||
if contractor.ResidenceID == nil {
|
if contractor.ResidenceID == nil {
|
||||||
// Personal contractor - only creator has access
|
// Personal contractor - only creator has access
|
||||||
return contractor.CreatedByID == userID
|
return contractor.CreatedByID == userID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Residence contractor - check residence access
|
// Residence contractor - check residence access
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(*contractor.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*contractor.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -70,15 +71,15 @@ func (s *ContractorService) hasContractorAccess(contractor *models.Contractor, u
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListContractors lists all contractors accessible to a user
|
// ListContractors lists all contractors accessible to a user
|
||||||
func (s *ContractorService) ListContractors(userID uint) ([]responses.ContractorResponse, error) {
|
func (s *ContractorService) ListContractors(ctx context.Context, userID uint) ([]responses.ContractorResponse, error) {
|
||||||
// Get residence IDs (lightweight - no preloads)
|
// Get residence IDs (lightweight - no preloads)
|
||||||
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
|
residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindByUser now handles both personal and residence contractors
|
// FindByUser now handles both personal and residence contractors
|
||||||
contractors, err := s.contractorRepo.FindByUser(userID, residenceIDs)
|
contractors, err := s.contractorRepo.WithContext(ctx).FindByUser(userID, residenceIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -87,10 +88,10 @@ func (s *ContractorService) ListContractors(userID uint) ([]responses.Contractor
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateContractor creates a new contractor
|
// CreateContractor creates a new contractor
|
||||||
func (s *ContractorService) CreateContractor(req *requests.CreateContractorRequest, userID uint) (*responses.ContractorResponse, error) {
|
func (s *ContractorService) CreateContractor(ctx context.Context, req *requests.CreateContractorRequest, userID uint) (*responses.ContractorResponse, error) {
|
||||||
// If residence is provided, check access
|
// If residence is provided, check access
|
||||||
if req.ResidenceID != nil {
|
if req.ResidenceID != nil {
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*req.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -122,19 +123,19 @@ func (s *ContractorService) CreateContractor(req *requests.CreateContractorReque
|
|||||||
IsActive: true,
|
IsActive: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.contractorRepo.Create(contractor); err != nil {
|
if err := s.contractorRepo.WithContext(ctx).Create(contractor); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set specialties if provided
|
// Set specialties if provided
|
||||||
if len(req.SpecialtyIDs) > 0 {
|
if len(req.SpecialtyIDs) > 0 {
|
||||||
if err := s.contractorRepo.SetSpecialties(contractor.ID, req.SpecialtyIDs); err != nil {
|
if err := s.contractorRepo.WithContext(ctx).SetSpecialties(contractor.ID, req.SpecialtyIDs); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload with relations
|
// Reload with relations
|
||||||
contractor, reloadErr := s.contractorRepo.FindByID(contractor.ID)
|
contractor, reloadErr := s.contractorRepo.WithContext(ctx).FindByID(contractor.ID)
|
||||||
if reloadErr != nil {
|
if reloadErr != nil {
|
||||||
return nil, apperrors.Internal(reloadErr)
|
return nil, apperrors.Internal(reloadErr)
|
||||||
}
|
}
|
||||||
@@ -144,8 +145,8 @@ func (s *ContractorService) CreateContractor(req *requests.CreateContractorReque
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateContractor updates a contractor
|
// UpdateContractor updates a contractor
|
||||||
func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *requests.UpdateContractorRequest) (*responses.ContractorResponse, error) {
|
func (s *ContractorService) UpdateContractor(ctx context.Context, contractorID, userID uint, req *requests.UpdateContractorRequest) (*responses.ContractorResponse, error) {
|
||||||
contractor, err := s.contractorRepo.FindByID(contractorID)
|
contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.contractor_not_found")
|
return nil, apperrors.NotFound("error.contractor_not_found")
|
||||||
@@ -154,7 +155,7 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access
|
// Check access
|
||||||
if !s.hasContractorAccess(contractor, userID) {
|
if !s.hasContractorAccess(ctx, contractor, userID) {
|
||||||
return nil, apperrors.Forbidden("error.contractor_access_denied")
|
return nil, apperrors.Forbidden("error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,7 +199,7 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
|
|||||||
// If residence_id is provided, verify the user has access to the NEW residence.
|
// If residence_id is provided, verify the user has access to the NEW residence.
|
||||||
// This prevents an attacker from reassigning a contractor to someone else's residence.
|
// This prevents an attacker from reassigning a contractor to someone else's residence.
|
||||||
if req.ResidenceID != nil {
|
if req.ResidenceID != nil {
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*req.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -211,19 +212,19 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
|
|||||||
// removed the residence association - contractor becomes personal
|
// removed the residence association - contractor becomes personal
|
||||||
contractor.ResidenceID = req.ResidenceID
|
contractor.ResidenceID = req.ResidenceID
|
||||||
|
|
||||||
if err := s.contractorRepo.Update(contractor); err != nil {
|
if err := s.contractorRepo.WithContext(ctx).Update(contractor); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update specialties if provided
|
// Update specialties if provided
|
||||||
if req.SpecialtyIDs != nil {
|
if req.SpecialtyIDs != nil {
|
||||||
if err := s.contractorRepo.SetSpecialties(contractorID, req.SpecialtyIDs); err != nil {
|
if err := s.contractorRepo.WithContext(ctx).SetSpecialties(contractorID, req.SpecialtyIDs); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload
|
// Reload
|
||||||
contractor, err = s.contractorRepo.FindByID(contractorID)
|
contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -233,8 +234,8 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteContractor soft-deletes a contractor
|
// DeleteContractor soft-deletes a contractor
|
||||||
func (s *ContractorService) DeleteContractor(contractorID, userID uint) error {
|
func (s *ContractorService) DeleteContractor(ctx context.Context, contractorID, userID uint) error {
|
||||||
contractor, err := s.contractorRepo.FindByID(contractorID)
|
contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return apperrors.NotFound("error.contractor_not_found")
|
return apperrors.NotFound("error.contractor_not_found")
|
||||||
@@ -243,11 +244,11 @@ func (s *ContractorService) DeleteContractor(contractorID, userID uint) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access
|
// Check access
|
||||||
if !s.hasContractorAccess(contractor, userID) {
|
if !s.hasContractorAccess(ctx, contractor, userID) {
|
||||||
return apperrors.Forbidden("error.contractor_access_denied")
|
return apperrors.Forbidden("error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.contractorRepo.Delete(contractorID); err != nil {
|
if err := s.contractorRepo.WithContext(ctx).Delete(contractorID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -255,8 +256,8 @@ func (s *ContractorService) DeleteContractor(contractorID, userID uint) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ToggleFavorite toggles the favorite status of a contractor and returns the updated contractor
|
// ToggleFavorite toggles the favorite status of a contractor and returns the updated contractor
|
||||||
func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*responses.ContractorResponse, error) {
|
func (s *ContractorService) ToggleFavorite(ctx context.Context, contractorID, userID uint) (*responses.ContractorResponse, error) {
|
||||||
contractor, err := s.contractorRepo.FindByID(contractorID)
|
contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.contractor_not_found")
|
return nil, apperrors.NotFound("error.contractor_not_found")
|
||||||
@@ -265,17 +266,17 @@ func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*response
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access
|
// Check access
|
||||||
if !s.hasContractorAccess(contractor, userID) {
|
if !s.hasContractorAccess(ctx, contractor, userID) {
|
||||||
return nil, apperrors.Forbidden("error.contractor_access_denied")
|
return nil, apperrors.Forbidden("error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.contractorRepo.ToggleFavorite(contractorID)
|
_, err = s.contractorRepo.WithContext(ctx).ToggleFavorite(contractorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Re-fetch the contractor to get the updated state with all relations
|
// Re-fetch the contractor to get the updated state with all relations
|
||||||
contractor, err = s.contractorRepo.FindByID(contractorID)
|
contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -285,8 +286,8 @@ func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*response
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetContractorTasks gets all tasks for a contractor
|
// GetContractorTasks gets all tasks for a contractor
|
||||||
func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]responses.TaskResponse, error) {
|
func (s *ContractorService) GetContractorTasks(ctx context.Context, contractorID, userID uint) ([]responses.TaskResponse, error) {
|
||||||
contractor, err := s.contractorRepo.FindByID(contractorID)
|
contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.contractor_not_found")
|
return nil, apperrors.NotFound("error.contractor_not_found")
|
||||||
@@ -295,11 +296,11 @@ func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]res
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access
|
// Check access
|
||||||
if !s.hasContractorAccess(contractor, userID) {
|
if !s.hasContractorAccess(ctx, contractor, userID) {
|
||||||
return nil, apperrors.Forbidden("error.contractor_access_denied")
|
return nil, apperrors.Forbidden("error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks, err := s.contractorRepo.GetTasksForContractor(contractorID)
|
tasks, err := s.contractorRepo.WithContext(ctx).GetTasksForContractor(contractorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -308,9 +309,9 @@ func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]res
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListContractorsByResidence lists all contractors for a specific residence
|
// ListContractorsByResidence lists all contractors for a specific residence
|
||||||
func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint) ([]responses.ContractorResponse, error) {
|
func (s *ContractorService) ListContractorsByResidence(ctx context.Context, residenceID, userID uint) ([]responses.ContractorResponse, error) {
|
||||||
// Check user has access to the residence
|
// Check user has access to the residence
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(residenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -318,7 +319,7 @@ func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint)
|
|||||||
return nil, apperrors.Forbidden("error.residence_access_denied")
|
return nil, apperrors.Forbidden("error.residence_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
contractors, err := s.contractorRepo.FindByResidence(residenceID)
|
contractors, err := s.contractorRepo.WithContext(ctx).FindByResidence(residenceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -327,8 +328,8 @@ func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetSpecialties returns all contractor specialties
|
// GetSpecialties returns all contractor specialties
|
||||||
func (s *ContractorService) GetSpecialties() ([]responses.ContractorSpecialtyResponse, error) {
|
func (s *ContractorService) GetSpecialties(ctx context.Context) ([]responses.ContractorSpecialtyResponse, error) {
|
||||||
specialties, err := s.contractorRepo.GetAllSpecialties()
|
specialties, err := s.contractorRepo.WithContext(ctx).GetAllSpecialties()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -41,7 +42,7 @@ func TestContractorService_CreateContractor(t *testing.T) {
|
|||||||
Email: "bob@plumbing.com",
|
Email: "bob@plumbing.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.CreateContractor(req, user.ID)
|
resp, err := service.CreateContractor(context.Background(), req, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, resp)
|
assert.NotNil(t, resp)
|
||||||
assert.Equal(t, "Bob's Plumbing", resp.Name)
|
assert.Equal(t, "Bob's Plumbing", resp.Name)
|
||||||
@@ -63,7 +64,7 @@ func TestContractorService_CreateContractor_Personal(t *testing.T) {
|
|||||||
Name: "Personal Handyman",
|
Name: "Personal Handyman",
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.CreateContractor(req, user.ID)
|
resp, err := service.CreateContractor(context.Background(), req, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "Personal Handyman", resp.Name)
|
assert.Equal(t, "Personal Handyman", resp.Name)
|
||||||
}
|
}
|
||||||
@@ -84,7 +85,7 @@ func TestContractorService_CreateContractor_AccessDenied(t *testing.T) {
|
|||||||
Name: "Unauthorized Contractor",
|
Name: "Unauthorized Contractor",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.CreateContractor(req, other.ID)
|
_, err := service.CreateContractor(context.Background(), req, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,7 +106,7 @@ func TestContractorService_CreateContractor_WithFavorite(t *testing.T) {
|
|||||||
IsFavorite: &isFav,
|
IsFavorite: &isFav,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.CreateContractor(req, user.ID)
|
resp, err := service.CreateContractor(context.Background(), req, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, resp.IsFavorite)
|
assert.True(t, resp.IsFavorite)
|
||||||
}
|
}
|
||||||
@@ -123,7 +124,7 @@ func TestContractorService_GetContractor(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
||||||
|
|
||||||
resp, err := service.GetContractor(contractor.ID, user.ID)
|
resp, err := service.GetContractor(context.Background(), contractor.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, contractor.ID, resp.ID)
|
assert.Equal(t, contractor.ID, resp.ID)
|
||||||
assert.Equal(t, "Test Contractor", resp.Name)
|
assert.Equal(t, "Test Contractor", resp.Name)
|
||||||
@@ -138,7 +139,7 @@ func TestContractorService_GetContractor_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
_, err := service.GetContractor(9999, user.ID)
|
_, err := service.GetContractor(context.Background(), 9999, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,7 +155,7 @@ func TestContractorService_GetContractor_AccessDenied(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
||||||
|
|
||||||
_, err := service.GetContractor(contractor.ID, other.ID)
|
_, err := service.GetContractor(context.Background(), contractor.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,7 +172,7 @@ func TestContractorService_GetContractor_SharedUserHasAccess(t *testing.T) {
|
|||||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Shared Contractor")
|
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Shared Contractor")
|
||||||
|
|
||||||
resp, err := service.GetContractor(contractor.ID, shared.ID)
|
resp, err := service.GetContractor(context.Background(), contractor.ID, shared.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "Shared Contractor", resp.Name)
|
assert.Equal(t, "Shared Contractor", resp.Name)
|
||||||
}
|
}
|
||||||
@@ -190,7 +191,7 @@ func TestContractorService_ListContractors(t *testing.T) {
|
|||||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 1")
|
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 1")
|
||||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 2")
|
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 2")
|
||||||
|
|
||||||
resp, err := service.ListContractors(user.ID)
|
resp, err := service.ListContractors(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, resp, 2)
|
assert.Len(t, resp, 2)
|
||||||
}
|
}
|
||||||
@@ -208,11 +209,11 @@ func TestContractorService_DeleteContractor(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "To Delete")
|
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "To Delete")
|
||||||
|
|
||||||
err := service.DeleteContractor(contractor.ID, user.ID)
|
err := service.DeleteContractor(context.Background(), contractor.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Should not be found after deletion
|
// Should not be found after deletion
|
||||||
_, err = service.GetContractor(contractor.ID, user.ID)
|
_, err = service.GetContractor(context.Background(), contractor.ID, user.ID)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,7 +226,7 @@ func TestContractorService_DeleteContractor_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
err := service.DeleteContractor(9999, user.ID)
|
err := service.DeleteContractor(context.Background(), 9999, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -241,7 +242,7 @@ func TestContractorService_DeleteContractor_AccessDenied(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
||||||
|
|
||||||
err := service.DeleteContractor(contractor.ID, other.ID)
|
err := service.DeleteContractor(context.Background(), contractor.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,17 +260,17 @@ func TestContractorService_ToggleFavorite(t *testing.T) {
|
|||||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
||||||
|
|
||||||
// Initially not favorite
|
// Initially not favorite
|
||||||
resp, err := service.GetContractor(contractor.ID, user.ID)
|
resp, err := service.GetContractor(context.Background(), contractor.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, resp.IsFavorite)
|
assert.False(t, resp.IsFavorite)
|
||||||
|
|
||||||
// Toggle to favorite
|
// Toggle to favorite
|
||||||
resp, err = service.ToggleFavorite(contractor.ID, user.ID)
|
resp, err = service.ToggleFavorite(context.Background(), contractor.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, resp.IsFavorite)
|
assert.True(t, resp.IsFavorite)
|
||||||
|
|
||||||
// Toggle back
|
// Toggle back
|
||||||
resp, err = service.ToggleFavorite(contractor.ID, user.ID)
|
resp, err = service.ToggleFavorite(context.Background(), contractor.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, resp.IsFavorite)
|
assert.False(t, resp.IsFavorite)
|
||||||
}
|
}
|
||||||
@@ -283,7 +284,7 @@ func TestContractorService_ToggleFavorite_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
_, err := service.ToggleFavorite(9999, user.ID)
|
_, err := service.ToggleFavorite(context.Background(), 9999, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -299,7 +300,7 @@ func TestContractorService_ToggleFavorite_AccessDenied(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
||||||
|
|
||||||
_, err := service.ToggleFavorite(contractor.ID, other.ID)
|
_, err := service.ToggleFavorite(context.Background(), contractor.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -317,7 +318,7 @@ func TestContractorService_ListContractorsByResidence(t *testing.T) {
|
|||||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor A")
|
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor A")
|
||||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor B")
|
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor B")
|
||||||
|
|
||||||
resp, err := service.ListContractorsByResidence(residence.ID, user.ID)
|
resp, err := service.ListContractorsByResidence(context.Background(), residence.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, resp, 2)
|
assert.Len(t, resp, 2)
|
||||||
}
|
}
|
||||||
@@ -333,7 +334,7 @@ func TestContractorService_ListContractorsByResidence_AccessDenied(t *testing.T)
|
|||||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||||
|
|
||||||
_, err := service.ListContractorsByResidence(residence.ID, other.ID)
|
_, err := service.ListContractorsByResidence(context.Background(), residence.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -348,7 +349,7 @@ func TestContractorService_GetContractorTasks_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
_, err := service.GetContractorTasks(9999, user.ID)
|
_, err := service.GetContractorTasks(context.Background(), 9999, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -364,7 +365,7 @@ func TestContractorService_GetContractorTasks_AccessDenied(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
||||||
|
|
||||||
_, err := service.GetContractorTasks(contractor.ID, other.ID)
|
_, err := service.GetContractorTasks(context.Background(), contractor.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,7 +380,7 @@ func TestContractorService_GetContractorTasks_Empty(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
||||||
|
|
||||||
resp, err := service.GetContractorTasks(contractor.ID, user.ID)
|
resp, err := service.GetContractorTasks(context.Background(), contractor.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, resp)
|
assert.Empty(t, resp)
|
||||||
}
|
}
|
||||||
@@ -393,7 +394,7 @@ func TestContractorService_GetSpecialties(t *testing.T) {
|
|||||||
residenceRepo := repositories.NewResidenceRepository(db)
|
residenceRepo := repositories.NewResidenceRepository(db)
|
||||||
service := NewContractorService(contractorRepo, residenceRepo)
|
service := NewContractorService(contractorRepo, residenceRepo)
|
||||||
|
|
||||||
resp, err := service.GetSpecialties()
|
resp, err := service.GetSpecialties(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// SeedLookupData creates 4 specialties
|
// SeedLookupData creates 4 specialties
|
||||||
assert.Len(t, resp, 4)
|
assert.Len(t, resp, 4)
|
||||||
@@ -413,7 +414,7 @@ func TestContractorService_UpdateContractor_NotFound(t *testing.T) {
|
|||||||
newName := "Won't Work"
|
newName := "Won't Work"
|
||||||
req := &requests.UpdateContractorRequest{Name: &newName}
|
req := &requests.UpdateContractorRequest{Name: &newName}
|
||||||
|
|
||||||
_, err := service.UpdateContractor(9999, user.ID, req)
|
_, err := service.UpdateContractor(context.Background(), 9999, user.ID, req)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -432,7 +433,7 @@ func TestContractorService_UpdateContractor_AccessDenied(t *testing.T) {
|
|||||||
newName := "Hacked"
|
newName := "Hacked"
|
||||||
req := &requests.UpdateContractorRequest{Name: &newName}
|
req := &requests.UpdateContractorRequest{Name: &newName}
|
||||||
|
|
||||||
_, err := service.UpdateContractor(contractor.ID, other.ID, req)
|
_, err := service.UpdateContractor(context.Background(), contractor.ID, other.ID, req)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -461,7 +462,7 @@ func TestUpdateContractor_CrossResidence_Returns403(t *testing.T) {
|
|||||||
ResidenceID: &newResidenceID,
|
ResidenceID: &newResidenceID,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.UpdateContractor(contractor.ID, attacker.ID, req)
|
_, err := service.UpdateContractor(context.Background(), contractor.ID, attacker.ID, req)
|
||||||
require.Error(t, err, "should not allow reassigning contractor to a residence the user has no access to")
|
require.Error(t, err, "should not allow reassigning contractor to a residence the user has no access to")
|
||||||
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
|
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
|
||||||
}
|
}
|
||||||
@@ -486,7 +487,7 @@ func TestUpdateContractor_SameResidence_Succeeds(t *testing.T) {
|
|||||||
ResidenceID: &newResidenceID,
|
ResidenceID: &newResidenceID,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdateContractor(contractor.ID, owner.ID, req)
|
resp, err := service.UpdateContractor(context.Background(), contractor.ID, owner.ID, req)
|
||||||
require.NoError(t, err, "should allow reassigning contractor to a residence the user owns")
|
require.NoError(t, err, "should allow reassigning contractor to a residence the user owns")
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
require.Equal(t, "Updated Contractor", resp.Name)
|
require.Equal(t, "Updated Contractor", resp.Name)
|
||||||
@@ -508,7 +509,7 @@ func TestUpdateContractor_RemoveResidence_Succeeds(t *testing.T) {
|
|||||||
ResidenceID: nil,
|
ResidenceID: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdateContractor(contractor.ID, owner.ID, req)
|
resp, err := service.UpdateContractor(context.Background(), contractor.ID, owner.ID, req)
|
||||||
require.NoError(t, err, "should allow removing residence association")
|
require.NoError(t, err, "should allow removing residence association")
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
}
|
}
|
||||||
@@ -555,7 +556,7 @@ func TestContractorService_UpdateContractor_PartialUpdate(t *testing.T) {
|
|||||||
ResidenceID: &residence.ID,
|
ResidenceID: &residence.ID,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdateContractor(contractor.ID, user.ID, req)
|
resp, err := service.UpdateContractor(context.Background(), contractor.ID, user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "Updated Plumber", resp.Name)
|
assert.Equal(t, "Updated Plumber", resp.Name)
|
||||||
assert.Equal(t, "555-9999", resp.Phone)
|
assert.Equal(t, "555-9999", resp.Phone)
|
||||||
@@ -588,7 +589,7 @@ func TestContractorService_UpdateContractor_WithSpecialties(t *testing.T) {
|
|||||||
ResidenceID: &residence.ID,
|
ResidenceID: &residence.ID,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdateContractor(contractor.ID, user.ID, req)
|
resp, err := service.UpdateContractor(context.Background(), contractor.ID, user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, resp)
|
assert.NotNil(t, resp)
|
||||||
}
|
}
|
||||||
@@ -615,7 +616,7 @@ func TestContractorService_CreateContractor_WithSpecialties(t *testing.T) {
|
|||||||
SpecialtyIDs: []uint{specialties[0].ID},
|
SpecialtyIDs: []uint{specialties[0].ID},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.CreateContractor(req, user.ID)
|
resp, err := service.CreateContractor(context.Background(), req, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "Specialized Plumber", resp.Name)
|
assert.Equal(t, "Specialized Plumber", resp.Name)
|
||||||
}
|
}
|
||||||
@@ -631,7 +632,7 @@ func TestContractorService_ListContractors_Empty(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
// No residence, no contractors
|
// No residence, no contractors
|
||||||
resp, err := service.ListContractors(user.ID)
|
resp, err := service.ListContractors(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, resp)
|
assert.Empty(t, resp)
|
||||||
}
|
}
|
||||||
@@ -648,7 +649,7 @@ func TestContractorService_ListContractorsByResidence_Empty(t *testing.T) {
|
|||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Empty House")
|
residence := testutil.CreateTestResidence(t, db, user.ID, "Empty House")
|
||||||
|
|
||||||
resp, err := service.ListContractorsByResidence(residence.ID, user.ID)
|
resp, err := service.ListContractorsByResidence(context.Background(), residence.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, resp)
|
assert.Empty(t, resp)
|
||||||
}
|
}
|
||||||
@@ -669,14 +670,14 @@ func TestContractorService_PersonalContractor_OnlyCreatorAccess(t *testing.T) {
|
|||||||
req := &requests.CreateContractorRequest{
|
req := &requests.CreateContractorRequest{
|
||||||
Name: "Personal Plumber",
|
Name: "Personal Plumber",
|
||||||
}
|
}
|
||||||
resp, err := service.CreateContractor(req, creator.ID)
|
resp, err := service.CreateContractor(context.Background(), req, creator.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Creator can access
|
// Creator can access
|
||||||
_, err = service.GetContractor(resp.ID, creator.ID)
|
_, err = service.GetContractor(context.Background(), resp.ID, creator.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Other user cannot
|
// Other user cannot
|
||||||
_, err = service.GetContractor(resp.ID, other.ID)
|
_, err = service.GetContractor(context.Background(), resp.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -34,8 +35,8 @@ func NewDocumentService(documentRepo *repositories.DocumentRepository, residence
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetDocument gets a document by ID with access check
|
// GetDocument gets a document by ID with access check
|
||||||
func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.DocumentResponse, error) {
|
func (s *DocumentService) GetDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) {
|
||||||
document, err := s.documentRepo.FindByID(documentID)
|
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.document_not_found")
|
return nil, apperrors.NotFound("error.document_not_found")
|
||||||
@@ -44,7 +45,7 @@ func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.Docum
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access via residence
|
// Check access via residence
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -57,9 +58,9 @@ func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.Docum
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListDocuments lists all documents accessible to a user, with optional filters.
|
// ListDocuments lists all documents accessible to a user, with optional filters.
|
||||||
func (s *DocumentService) ListDocuments(userID uint, filter *repositories.DocumentFilter) ([]responses.DocumentResponse, error) {
|
func (s *DocumentService) ListDocuments(ctx context.Context, userID uint, filter *repositories.DocumentFilter) ([]responses.DocumentResponse, error) {
|
||||||
// Get residence IDs (lightweight - no preloads)
|
// Get residence IDs (lightweight - no preloads)
|
||||||
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
|
residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -83,7 +84,7 @@ func (s *DocumentService) ListDocuments(userID uint, filter *repositories.Docume
|
|||||||
residenceIDs = []uint{*filter.ResidenceID}
|
residenceIDs = []uint{*filter.ResidenceID}
|
||||||
}
|
}
|
||||||
|
|
||||||
documents, err := s.documentRepo.FindByUserFiltered(residenceIDs, filter)
|
documents, err := s.documentRepo.WithContext(ctx).FindByUserFiltered(residenceIDs, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -92,9 +93,9 @@ func (s *DocumentService) ListDocuments(userID uint, filter *repositories.Docume
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListWarranties lists all warranty documents
|
// ListWarranties lists all warranty documents
|
||||||
func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentResponse, error) {
|
func (s *DocumentService) ListWarranties(ctx context.Context, userID uint) ([]responses.DocumentResponse, error) {
|
||||||
// Get residence IDs (lightweight - no preloads)
|
// Get residence IDs (lightweight - no preloads)
|
||||||
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
|
residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -103,7 +104,7 @@ func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentRespo
|
|||||||
return []responses.DocumentResponse{}, nil
|
return []responses.DocumentResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
documents, err := s.documentRepo.FindWarranties(residenceIDs)
|
documents, err := s.documentRepo.WithContext(ctx).FindWarranties(residenceIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -112,9 +113,9 @@ func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentRespo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateDocument creates a new document
|
// CreateDocument creates a new document
|
||||||
func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, userID uint) (*responses.DocumentResponse, error) {
|
func (s *DocumentService) CreateDocument(ctx context.Context, req *requests.CreateDocumentRequest, userID uint) (*responses.DocumentResponse, error) {
|
||||||
// Check residence access
|
// Check residence access
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(req.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(req.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -147,7 +148,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
|
|||||||
IsActive: true,
|
IsActive: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.documentRepo.Create(document); err != nil {
|
if err := s.documentRepo.WithContext(ctx).Create(document); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,7 +159,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
|
|||||||
DocumentID: document.ID,
|
DocumentID: document.ID,
|
||||||
ImageURL: imageURL,
|
ImageURL: imageURL,
|
||||||
}
|
}
|
||||||
if err := s.documentRepo.CreateDocumentImage(img); err != nil {
|
if err := s.documentRepo.WithContext(ctx).CreateDocumentImage(img); err != nil {
|
||||||
// Log but don't fail the whole operation
|
// Log but don't fail the whole operation
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -166,7 +167,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reload with relations
|
// Reload with relations
|
||||||
document, err = s.documentRepo.FindByID(document.ID)
|
document, err = s.documentRepo.WithContext(ctx).FindByID(document.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -176,8 +177,8 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDocument updates a document
|
// UpdateDocument updates a document
|
||||||
func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.UpdateDocumentRequest) (*responses.DocumentResponse, error) {
|
func (s *DocumentService) UpdateDocument(ctx context.Context, documentID, userID uint, req *requests.UpdateDocumentRequest) (*responses.DocumentResponse, error) {
|
||||||
document, err := s.documentRepo.FindByID(documentID)
|
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.document_not_found")
|
return nil, apperrors.NotFound("error.document_not_found")
|
||||||
@@ -186,7 +187,7 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access
|
// Check access
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -238,12 +239,12 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
|
|||||||
document.TaskID = req.TaskID
|
document.TaskID = req.TaskID
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.documentRepo.Update(document); err != nil {
|
if err := s.documentRepo.WithContext(ctx).Update(document); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload
|
// Reload
|
||||||
document, err = s.documentRepo.FindByID(documentID)
|
document, err = s.documentRepo.WithContext(ctx).FindByID(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -253,8 +254,8 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDocument soft-deletes a document
|
// DeleteDocument soft-deletes a document
|
||||||
func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
|
func (s *DocumentService) DeleteDocument(ctx context.Context, documentID, userID uint) error {
|
||||||
document, err := s.documentRepo.FindByID(documentID)
|
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return apperrors.NotFound("error.document_not_found")
|
return apperrors.NotFound("error.document_not_found")
|
||||||
@@ -263,7 +264,7 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access
|
// Check access
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -271,7 +272,7 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
|
|||||||
return apperrors.Forbidden("error.document_access_denied")
|
return apperrors.Forbidden("error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.documentRepo.Delete(documentID); err != nil {
|
if err := s.documentRepo.WithContext(ctx).Delete(documentID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,15 +280,15 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ActivateDocument activates a document
|
// ActivateDocument activates a document
|
||||||
func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.DocumentResponse, error) {
|
func (s *DocumentService) ActivateDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) {
|
||||||
// First check if document exists (even if inactive)
|
// First check if document exists (even if inactive)
|
||||||
var document models.Document
|
var document models.Document
|
||||||
if err := s.documentRepo.FindByIDIncludingInactive(documentID, &document); err != nil {
|
if err := s.documentRepo.WithContext(ctx).FindByIDIncludingInactive(documentID, &document); err != nil {
|
||||||
return nil, apperrors.NotFound("error.document_not_found")
|
return nil, apperrors.NotFound("error.document_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check access
|
// Check access
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -295,12 +296,12 @@ func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.
|
|||||||
return nil, apperrors.Forbidden("error.document_access_denied")
|
return nil, apperrors.Forbidden("error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.documentRepo.Activate(documentID); err != nil {
|
if err := s.documentRepo.WithContext(ctx).Activate(documentID); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload
|
// Reload
|
||||||
doc, err := s.documentRepo.FindByID(documentID)
|
doc, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -310,8 +311,8 @@ func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeactivateDocument deactivates a document
|
// DeactivateDocument deactivates a document
|
||||||
func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*responses.DocumentResponse, error) {
|
func (s *DocumentService) DeactivateDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) {
|
||||||
document, err := s.documentRepo.FindByID(documentID)
|
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.document_not_found")
|
return nil, apperrors.NotFound("error.document_not_found")
|
||||||
@@ -320,7 +321,7 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access
|
// Check access
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -328,7 +329,7 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response
|
|||||||
return nil, apperrors.Forbidden("error.document_access_denied")
|
return nil, apperrors.Forbidden("error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.documentRepo.Deactivate(documentID); err != nil {
|
if err := s.documentRepo.WithContext(ctx).Deactivate(documentID); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,8 +339,8 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UploadDocumentImage adds an image to an existing document
|
// UploadDocumentImage adds an image to an existing document
|
||||||
func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL, caption string) (*responses.DocumentResponse, error) {
|
func (s *DocumentService) UploadDocumentImage(ctx context.Context, documentID, userID uint, imageURL, caption string) (*responses.DocumentResponse, error) {
|
||||||
document, err := s.documentRepo.FindByID(documentID)
|
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.document_not_found")
|
return nil, apperrors.NotFound("error.document_not_found")
|
||||||
@@ -348,7 +349,7 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access via residence
|
// Check access via residence
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -361,12 +362,12 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL,
|
|||||||
ImageURL: imageURL,
|
ImageURL: imageURL,
|
||||||
Caption: caption,
|
Caption: caption,
|
||||||
}
|
}
|
||||||
if err := s.documentRepo.CreateDocumentImage(img); err != nil {
|
if err := s.documentRepo.WithContext(ctx).CreateDocumentImage(img); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload with relations
|
// Reload with relations
|
||||||
document, err = s.documentRepo.FindByID(documentID)
|
document, err = s.documentRepo.WithContext(ctx).FindByID(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -376,9 +377,9 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDocumentImage removes an image from a document
|
// DeleteDocumentImage removes an image from a document
|
||||||
func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint) (*responses.DocumentResponse, error) {
|
func (s *DocumentService) DeleteDocumentImage(ctx context.Context, documentID, imageID, userID uint) (*responses.DocumentResponse, error) {
|
||||||
// Find the image first
|
// Find the image first
|
||||||
image, err := s.documentRepo.FindImageByID(imageID)
|
image, err := s.documentRepo.WithContext(ctx).FindImageByID(imageID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.document_image_not_found")
|
return nil, apperrors.NotFound("error.document_image_not_found")
|
||||||
@@ -392,7 +393,7 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Find parent document to check access
|
// Find parent document to check access
|
||||||
document, err := s.documentRepo.FindByID(documentID)
|
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.document_not_found")
|
return nil, apperrors.NotFound("error.document_not_found")
|
||||||
@@ -401,7 +402,7 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check access via residence
|
// Check access via residence
|
||||||
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
|
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -409,12 +410,12 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint)
|
|||||||
return nil, apperrors.Forbidden("error.document_access_denied")
|
return nil, apperrors.Forbidden("error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.documentRepo.DeleteDocumentImage(imageID); err != nil {
|
if err := s.documentRepo.WithContext(ctx).DeleteDocumentImage(imageID); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload with relations
|
// Reload with relations
|
||||||
document, err = s.documentRepo.FindByID(documentID)
|
document, err = s.documentRepo.WithContext(ctx).FindByID(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -42,7 +43,7 @@ func TestDocumentService_CreateDocument(t *testing.T) {
|
|||||||
FileName: "manual.pdf",
|
FileName: "manual.pdf",
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.CreateDocument(req, user.ID)
|
resp, err := service.CreateDocument(context.Background(), req, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, resp)
|
assert.NotNil(t, resp)
|
||||||
assert.Equal(t, "Furnace Manual", resp.Title)
|
assert.Equal(t, "Furnace Manual", resp.Title)
|
||||||
@@ -64,7 +65,7 @@ func TestDocumentService_CreateDocument_DefaultType(t *testing.T) {
|
|||||||
// DocumentType not set — should default to "general"
|
// DocumentType not set — should default to "general"
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.CreateDocument(req, user.ID)
|
resp, err := service.CreateDocument(context.Background(), req, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, models.DocumentTypeGeneral, resp.DocumentType)
|
assert.Equal(t, models.DocumentTypeGeneral, resp.DocumentType)
|
||||||
}
|
}
|
||||||
@@ -84,7 +85,7 @@ func TestDocumentService_CreateDocument_WithImages(t *testing.T) {
|
|||||||
ImageURLs: []string{"https://example.com/img1.jpg", "https://example.com/img2.jpg"},
|
ImageURLs: []string{"https://example.com/img1.jpg", "https://example.com/img2.jpg"},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.CreateDocument(req, user.ID)
|
resp, err := service.CreateDocument(context.Background(), req, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, resp)
|
assert.NotNil(t, resp)
|
||||||
assert.Equal(t, "Receipt with photos", resp.Title)
|
assert.Equal(t, "Receipt with photos", resp.Title)
|
||||||
@@ -105,7 +106,7 @@ func TestDocumentService_CreateDocument_AccessDenied(t *testing.T) {
|
|||||||
Title: "Unauthorized Doc",
|
Title: "Unauthorized Doc",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.CreateDocument(req, other.ID)
|
_, err := service.CreateDocument(context.Background(), req, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +122,7 @@ func TestDocumentService_GetDocument(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
||||||
|
|
||||||
resp, err := service.GetDocument(doc.ID, user.ID)
|
resp, err := service.GetDocument(context.Background(), doc.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, doc.ID, resp.ID)
|
assert.Equal(t, doc.ID, resp.ID)
|
||||||
assert.Equal(t, "Test Doc", resp.Title)
|
assert.Equal(t, "Test Doc", resp.Title)
|
||||||
@@ -135,7 +136,7 @@ func TestDocumentService_GetDocument_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
_, err := service.GetDocument(9999, user.ID)
|
_, err := service.GetDocument(context.Background(), 9999, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,7 +151,7 @@ func TestDocumentService_GetDocument_AccessDenied(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||||
|
|
||||||
_, err := service.GetDocument(doc.ID, other.ID)
|
_, err := service.GetDocument(context.Background(), doc.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,7 +174,7 @@ func TestDocumentService_UpdateDocument(t *testing.T) {
|
|||||||
Description: &newDesc,
|
Description: &newDesc,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
|
resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "Updated Title", resp.Title)
|
assert.Equal(t, "Updated Title", resp.Title)
|
||||||
assert.Equal(t, "Updated description", resp.Description)
|
assert.Equal(t, "Updated description", resp.Description)
|
||||||
@@ -190,7 +191,7 @@ func TestDocumentService_UpdateDocument_NotFound(t *testing.T) {
|
|||||||
newTitle := "Won't Work"
|
newTitle := "Won't Work"
|
||||||
req := &requests.UpdateDocumentRequest{Title: &newTitle}
|
req := &requests.UpdateDocumentRequest{Title: &newTitle}
|
||||||
|
|
||||||
_, err := service.UpdateDocument(9999, user.ID, req)
|
_, err := service.UpdateDocument(context.Background(), 9999, user.ID, req)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,7 +209,7 @@ func TestDocumentService_UpdateDocument_AccessDenied(t *testing.T) {
|
|||||||
newTitle := "Hacked"
|
newTitle := "Hacked"
|
||||||
req := &requests.UpdateDocumentRequest{Title: &newTitle}
|
req := &requests.UpdateDocumentRequest{Title: &newTitle}
|
||||||
|
|
||||||
_, err := service.UpdateDocument(doc.ID, other.ID, req)
|
_, err := service.UpdateDocument(context.Background(), doc.ID, other.ID, req)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,7 +226,7 @@ func TestDocumentService_UpdateDocument_ChangeType(t *testing.T) {
|
|||||||
newType := models.DocumentTypeWarranty
|
newType := models.DocumentTypeWarranty
|
||||||
req := &requests.UpdateDocumentRequest{DocumentType: &newType}
|
req := &requests.UpdateDocumentRequest{DocumentType: &newType}
|
||||||
|
|
||||||
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
|
resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, models.DocumentTypeWarranty, resp.DocumentType)
|
assert.Equal(t, models.DocumentTypeWarranty, resp.DocumentType)
|
||||||
}
|
}
|
||||||
@@ -242,11 +243,11 @@ func TestDocumentService_DeleteDocument(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Delete")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Delete")
|
||||||
|
|
||||||
err := service.DeleteDocument(doc.ID, user.ID)
|
err := service.DeleteDocument(context.Background(), doc.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Should not be found after deletion
|
// Should not be found after deletion
|
||||||
_, err = service.GetDocument(doc.ID, user.ID)
|
_, err = service.GetDocument(context.Background(), doc.ID, user.ID)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -258,7 +259,7 @@ func TestDocumentService_DeleteDocument_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
err := service.DeleteDocument(9999, user.ID)
|
err := service.DeleteDocument(context.Background(), 9999, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -273,7 +274,7 @@ func TestDocumentService_DeleteDocument_AccessDenied(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||||
|
|
||||||
err := service.DeleteDocument(doc.ID, other.ID)
|
err := service.DeleteDocument(context.Background(), doc.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -290,7 +291,7 @@ func TestDocumentService_ListDocuments(t *testing.T) {
|
|||||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
|
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
|
||||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
|
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
|
||||||
|
|
||||||
resp, err := service.ListDocuments(user.ID, nil)
|
resp, err := service.ListDocuments(context.Background(), user.ID, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, resp, 2)
|
assert.Len(t, resp, 2)
|
||||||
}
|
}
|
||||||
@@ -303,7 +304,7 @@ func TestDocumentService_ListDocuments_NoResidences(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
|
||||||
|
|
||||||
resp, err := service.ListDocuments(user.ID, nil)
|
resp, err := service.ListDocuments(context.Background(), user.ID, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, resp)
|
assert.Empty(t, resp)
|
||||||
}
|
}
|
||||||
@@ -321,7 +322,7 @@ func TestDocumentService_ListDocuments_FilterByResidence(t *testing.T) {
|
|||||||
testutil.CreateTestDocument(t, db, residence2.ID, user.ID, "Doc B")
|
testutil.CreateTestDocument(t, db, residence2.ID, user.ID, "Doc B")
|
||||||
|
|
||||||
filter := &repositories.DocumentFilter{ResidenceID: &residence1.ID}
|
filter := &repositories.DocumentFilter{ResidenceID: &residence1.ID}
|
||||||
resp, err := service.ListDocuments(user.ID, filter)
|
resp, err := service.ListDocuments(context.Background(), user.ID, filter)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, resp, 1)
|
assert.Len(t, resp, 1)
|
||||||
assert.Equal(t, "Doc A", resp[0].Title)
|
assert.Equal(t, "Doc A", resp[0].Title)
|
||||||
@@ -340,7 +341,7 @@ func TestDocumentService_ListDocuments_FilterByResidence_AccessDenied(t *testing
|
|||||||
testutil.CreateTestResidence(t, db, other.ID, "Other House")
|
testutil.CreateTestResidence(t, db, other.ID, "Other House")
|
||||||
|
|
||||||
filter := &repositories.DocumentFilter{ResidenceID: &residence.ID}
|
filter := &repositories.DocumentFilter{ResidenceID: &residence.ID}
|
||||||
_, err := service.ListDocuments(other.ID, filter)
|
_, err := service.ListDocuments(context.Background(), other.ID, filter)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -369,7 +370,7 @@ func TestDocumentService_ListWarranties(t *testing.T) {
|
|||||||
// Create a general doc
|
// Create a general doc
|
||||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
|
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
|
||||||
|
|
||||||
resp, err := service.ListWarranties(user.ID)
|
resp, err := service.ListWarranties(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, resp, 1)
|
assert.Len(t, resp, 1)
|
||||||
assert.Equal(t, "HVAC Warranty", resp[0].Title)
|
assert.Equal(t, "HVAC Warranty", resp[0].Title)
|
||||||
@@ -383,7 +384,7 @@ func TestDocumentService_ListWarranties_NoResidences(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
|
||||||
|
|
||||||
resp, err := service.ListWarranties(user.ID)
|
resp, err := service.ListWarranties(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, resp)
|
assert.Empty(t, resp)
|
||||||
}
|
}
|
||||||
@@ -400,7 +401,7 @@ func TestDocumentService_DeactivateDocument(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Deactivate")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Deactivate")
|
||||||
|
|
||||||
resp, err := service.DeactivateDocument(doc.ID, user.ID)
|
resp, err := service.DeactivateDocument(context.Background(), doc.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, resp.IsActive)
|
assert.False(t, resp.IsActive)
|
||||||
}
|
}
|
||||||
@@ -413,7 +414,7 @@ func TestDocumentService_DeactivateDocument_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
_, err := service.DeactivateDocument(9999, user.ID)
|
_, err := service.DeactivateDocument(context.Background(), 9999, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -428,7 +429,7 @@ func TestDocumentService_DeactivateDocument_AccessDenied(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||||
|
|
||||||
_, err := service.DeactivateDocument(doc.ID, other.ID)
|
_, err := service.DeactivateDocument(context.Background(), doc.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -444,7 +445,7 @@ func TestDocumentService_UploadDocumentImage(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
||||||
|
|
||||||
resp, err := service.UploadDocumentImage(doc.ID, user.ID, "https://example.com/photo.jpg", "Front view")
|
resp, err := service.UploadDocumentImage(context.Background(), doc.ID, user.ID, "https://example.com/photo.jpg", "Front view")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, resp)
|
assert.NotNil(t, resp)
|
||||||
}
|
}
|
||||||
@@ -457,7 +458,7 @@ func TestDocumentService_UploadDocumentImage_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
_, err := service.UploadDocumentImage(9999, user.ID, "https://example.com/photo.jpg", "")
|
_, err := service.UploadDocumentImage(context.Background(), 9999, user.ID, "https://example.com/photo.jpg", "")
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -472,7 +473,7 @@ func TestDocumentService_UploadDocumentImage_AccessDenied(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||||
|
|
||||||
_, err := service.UploadDocumentImage(doc.ID, other.ID, "https://example.com/photo.jpg", "")
|
_, err := service.UploadDocumentImage(context.Background(), doc.ID, other.ID, "https://example.com/photo.jpg", "")
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -496,7 +497,7 @@ func TestDocumentService_DeleteDocumentImage(t *testing.T) {
|
|||||||
err := db.Create(img).Error
|
err := db.Create(img).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
resp, err := service.DeleteDocumentImage(doc.ID, img.ID, user.ID)
|
resp, err := service.DeleteDocumentImage(context.Background(), doc.ID, img.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, resp)
|
assert.NotNil(t, resp)
|
||||||
}
|
}
|
||||||
@@ -511,7 +512,7 @@ func TestDocumentService_DeleteDocumentImage_NotFound(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
||||||
|
|
||||||
_, err := service.DeleteDocumentImage(doc.ID, 9999, user.ID)
|
_, err := service.DeleteDocumentImage(context.Background(), doc.ID, 9999, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -535,7 +536,7 @@ func TestDocumentService_DeleteDocumentImage_WrongDocument(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Try to delete the image specifying doc2
|
// Try to delete the image specifying doc2
|
||||||
_, err = service.DeleteDocumentImage(doc2.ID, img.ID, user.ID)
|
_, err = service.DeleteDocumentImage(context.Background(), doc2.ID, img.ID, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -557,7 +558,7 @@ func TestDocumentService_DeleteDocumentImage_AccessDenied(t *testing.T) {
|
|||||||
err := db.Create(img).Error
|
err := db.Create(img).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = service.DeleteDocumentImage(doc.ID, img.ID, other.ID)
|
_, err = service.DeleteDocumentImage(context.Background(), doc.ID, img.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -575,7 +576,7 @@ func TestDocumentService_GetDocument_SharedUserHasAccess(t *testing.T) {
|
|||||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
|
||||||
|
|
||||||
resp, err := service.GetDocument(doc.ID, shared.ID)
|
resp, err := service.GetDocument(context.Background(), doc.ID, shared.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "Shared Doc", resp.Title)
|
assert.Equal(t, "Shared Doc", resp.Title)
|
||||||
}
|
}
|
||||||
@@ -593,11 +594,11 @@ func TestDocumentService_ActivateDocument(t *testing.T) {
|
|||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Activate")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Activate")
|
||||||
|
|
||||||
// Deactivate first
|
// Deactivate first
|
||||||
_, err := service.DeactivateDocument(doc.ID, user.ID)
|
_, err := service.DeactivateDocument(context.Background(), doc.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Now activate
|
// Now activate
|
||||||
resp, err := service.ActivateDocument(doc.ID, user.ID)
|
resp, err := service.ActivateDocument(context.Background(), doc.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, resp.IsActive)
|
assert.True(t, resp.IsActive)
|
||||||
}
|
}
|
||||||
@@ -610,7 +611,7 @@ func TestDocumentService_ActivateDocument_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
_, err := service.ActivateDocument(9999, user.ID)
|
_, err := service.ActivateDocument(context.Background(), 9999, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -625,7 +626,7 @@ func TestDocumentService_ActivateDocument_AccessDenied(t *testing.T) {
|
|||||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||||
|
|
||||||
_, err := service.ActivateDocument(doc.ID, other.ID)
|
_, err := service.ActivateDocument(context.Background(), doc.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -646,7 +647,7 @@ func TestDocumentService_CreateDocument_WithEmptyImageURL(t *testing.T) {
|
|||||||
ImageURLs: []string{"", "https://example.com/img.jpg", ""},
|
ImageURLs: []string{"", "https://example.com/img.jpg", ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.CreateDocument(req, user.ID)
|
resp, err := service.CreateDocument(context.Background(), req, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, resp)
|
assert.NotNil(t, resp)
|
||||||
}
|
}
|
||||||
@@ -687,7 +688,7 @@ func TestDocumentService_UpdateDocument_AllFields(t *testing.T) {
|
|||||||
ModelNumber: &newModel,
|
ModelNumber: &newModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
|
resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "Updated", resp.Title)
|
assert.Equal(t, "Updated", resp.Title)
|
||||||
assert.Equal(t, "New description", resp.Description)
|
assert.Equal(t, "New description", resp.Description)
|
||||||
@@ -720,7 +721,7 @@ func TestDocumentService_ListDocuments_FilterByType(t *testing.T) {
|
|||||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
|
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
|
||||||
|
|
||||||
filter := &repositories.DocumentFilter{DocumentType: string(models.DocumentTypeWarranty)}
|
filter := &repositories.DocumentFilter{DocumentType: string(models.DocumentTypeWarranty)}
|
||||||
resp, err := service.ListDocuments(user.ID, filter)
|
resp, err := service.ListDocuments(context.Background(), user.ID, filter)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, resp, 1)
|
assert.Len(t, resp, 1)
|
||||||
assert.Equal(t, "Warranty Doc", resp[0].Title)
|
assert.Equal(t, "Warranty Doc", resp[0].Title)
|
||||||
@@ -742,7 +743,7 @@ func TestDocumentService_SharedUser_CanUpdate(t *testing.T) {
|
|||||||
|
|
||||||
newTitle := "Updated by shared user"
|
newTitle := "Updated by shared user"
|
||||||
req := &requests.UpdateDocumentRequest{Title: &newTitle}
|
req := &requests.UpdateDocumentRequest{Title: &newTitle}
|
||||||
resp, err := service.UpdateDocument(doc.ID, shared.ID, req)
|
resp, err := service.UpdateDocument(context.Background(), doc.ID, shared.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "Updated by shared user", resp.Title)
|
assert.Equal(t, "Updated by shared user", resp.Title)
|
||||||
}
|
}
|
||||||
@@ -759,6 +760,6 @@ func TestDocumentService_SharedUser_CanDelete(t *testing.T) {
|
|||||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
|
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
|
||||||
|
|
||||||
err := service.DeleteDocument(doc.ID, shared.ID)
|
err := service.DeleteDocument(context.Background(), doc.ID, shared.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,8 +44,8 @@ func NewNotificationService(notificationRepo *repositories.NotificationRepositor
|
|||||||
// === Notifications ===
|
// === Notifications ===
|
||||||
|
|
||||||
// GetNotifications gets notifications for a user
|
// GetNotifications gets notifications for a user
|
||||||
func (s *NotificationService) GetNotifications(userID uint, limit, offset int) ([]NotificationResponse, error) {
|
func (s *NotificationService) GetNotifications(ctx context.Context, userID uint, limit, offset int) ([]NotificationResponse, error) {
|
||||||
notifications, err := s.notificationRepo.FindByUser(userID, limit, offset)
|
notifications, err := s.notificationRepo.WithContext(ctx).FindByUser(userID, limit, offset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -58,8 +58,8 @@ func (s *NotificationService) GetNotifications(userID uint, limit, offset int) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUnreadCount gets the count of unread notifications
|
// GetUnreadCount gets the count of unread notifications
|
||||||
func (s *NotificationService) GetUnreadCount(userID uint) (int64, error) {
|
func (s *NotificationService) GetUnreadCount(ctx context.Context, userID uint) (int64, error) {
|
||||||
count, err := s.notificationRepo.CountUnread(userID)
|
count, err := s.notificationRepo.WithContext(ctx).CountUnread(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, apperrors.Internal(err)
|
return 0, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -67,8 +67,8 @@ func (s *NotificationService) GetUnreadCount(userID uint) (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkAsRead marks a notification as read
|
// MarkAsRead marks a notification as read
|
||||||
func (s *NotificationService) MarkAsRead(notificationID, userID uint) error {
|
func (s *NotificationService) MarkAsRead(ctx context.Context, notificationID, userID uint) error {
|
||||||
notification, err := s.notificationRepo.FindByID(notificationID)
|
notification, err := s.notificationRepo.WithContext(ctx).FindByID(notificationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return apperrors.NotFound("error.notification_not_found")
|
return apperrors.NotFound("error.notification_not_found")
|
||||||
@@ -80,15 +80,15 @@ func (s *NotificationService) MarkAsRead(notificationID, userID uint) error {
|
|||||||
return apperrors.NotFound("error.notification_not_found")
|
return apperrors.NotFound("error.notification_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.notificationRepo.MarkAsRead(notificationID); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).MarkAsRead(notificationID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkAllAsRead marks all notifications as read
|
// MarkAllAsRead marks all notifications as read
|
||||||
func (s *NotificationService) MarkAllAsRead(userID uint) error {
|
func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID uint) error {
|
||||||
if err := s.notificationRepo.MarkAllAsRead(userID); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).MarkAllAsRead(userID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -97,7 +97,7 @@ func (s *NotificationService) MarkAllAsRead(userID uint) error {
|
|||||||
// CreateAndSendNotification creates a notification and sends it via push
|
// CreateAndSendNotification creates a notification and sends it via push
|
||||||
func (s *NotificationService) CreateAndSendNotification(ctx context.Context, userID uint, notificationType models.NotificationType, title, body string, data map[string]interface{}) error {
|
func (s *NotificationService) CreateAndSendNotification(ctx context.Context, userID uint, notificationType models.NotificationType, title, body string, data map[string]interface{}) error {
|
||||||
// Check user preferences
|
// Check user preferences
|
||||||
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
|
prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -117,12 +117,12 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use
|
|||||||
Data: string(dataJSON),
|
Data: string(dataJSON),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.notificationRepo.Create(notification); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).Create(notification); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get device tokens
|
// Get device tokens
|
||||||
iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID)
|
iosTokens, androidTokens, err := s.notificationRepo.WithContext(ctx).GetActiveTokensForUser(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -144,12 +144,12 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use
|
|||||||
if s.pushClient != nil {
|
if s.pushClient != nil {
|
||||||
err = s.pushClient.SendToAll(ctx, iosTokens, androidTokens, title, body, pushData)
|
err = s.pushClient.SendToAll(ctx, iosTokens, androidTokens, title, body, pushData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.notificationRepo.SetError(notification.ID, err.Error())
|
s.notificationRepo.WithContext(ctx).SetError(notification.ID, err.Error())
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.notificationRepo.MarkAsSent(notification.ID); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).MarkAsSent(notification.ID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -178,8 +178,8 @@ func (s *NotificationService) isNotificationEnabled(prefs *models.NotificationPr
|
|||||||
// === Notification Preferences ===
|
// === Notification Preferences ===
|
||||||
|
|
||||||
// GetPreferences gets notification preferences for a user
|
// GetPreferences gets notification preferences for a user
|
||||||
func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferencesResponse, error) {
|
func (s *NotificationService) GetPreferences(ctx context.Context, userID uint) (*NotificationPreferencesResponse, error) {
|
||||||
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
|
prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -196,7 +196,7 @@ func validateHourField(val *int, fieldName string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePreferences updates notification preferences
|
// UpdatePreferences updates notification preferences
|
||||||
func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
|
func (s *NotificationService) UpdatePreferences(ctx context.Context, userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
|
||||||
// B-12: Validate hour fields are in range 0-23
|
// B-12: Validate hour fields are in range 0-23
|
||||||
hourFields := []struct {
|
hourFields := []struct {
|
||||||
value *int
|
value *int
|
||||||
@@ -213,7 +213,7 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
|
prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -258,7 +258,7 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen
|
|||||||
prefs.DailyDigestHour = req.DailyDigestHour
|
prefs.DailyDigestHour = req.DailyDigestHour
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).UpdatePreferences(prefs); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,14 +268,14 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen
|
|||||||
// UpdateUserTimezone updates the user's timezone for background job calculations.
|
// UpdateUserTimezone updates the user's timezone for background job calculations.
|
||||||
// This is called automatically when the user makes API calls (e.g., fetching tasks).
|
// This is called automatically when the user makes API calls (e.g., fetching tasks).
|
||||||
// The timezone should be an IANA timezone name (e.g., "America/Los_Angeles").
|
// The timezone should be an IANA timezone name (e.g., "America/Los_Angeles").
|
||||||
func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
|
func (s *NotificationService) UpdateUserTimezone(ctx context.Context, userID uint, timezone string) {
|
||||||
// Validate timezone is a valid IANA name
|
// Validate timezone is a valid IANA name
|
||||||
if _, err := time.LoadLocation(timezone); err != nil {
|
if _, err := time.LoadLocation(timezone); err != nil {
|
||||||
return // Invalid timezone, skip silently
|
return // Invalid timezone, skip silently
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get or create preferences and update timezone
|
// Get or create preferences and update timezone
|
||||||
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
|
prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return // Skip silently on error
|
return // Skip silently on error
|
||||||
}
|
}
|
||||||
@@ -283,7 +283,7 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
|
|||||||
// Only update if timezone changed (avoid unnecessary DB writes)
|
// Only update if timezone changed (avoid unnecessary DB writes)
|
||||||
if prefs.Timezone == nil || *prefs.Timezone != timezone {
|
if prefs.Timezone == nil || *prefs.Timezone != timezone {
|
||||||
prefs.Timezone = &timezone
|
prefs.Timezone = &timezone
|
||||||
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).UpdatePreferences(prefs); err != nil {
|
||||||
log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone).
|
log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone).
|
||||||
Msg("Failed to update user timezone in notification preferences")
|
Msg("Failed to update user timezone in notification preferences")
|
||||||
}
|
}
|
||||||
@@ -293,27 +293,27 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
|
|||||||
// === Device Registration ===
|
// === Device Registration ===
|
||||||
|
|
||||||
// RegisterDevice registers a device for push notifications
|
// RegisterDevice registers a device for push notifications
|
||||||
func (s *NotificationService) RegisterDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
|
func (s *NotificationService) RegisterDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
|
||||||
switch req.Platform {
|
switch req.Platform {
|
||||||
case push.PlatformIOS:
|
case push.PlatformIOS:
|
||||||
return s.registerAPNSDevice(userID, req)
|
return s.registerAPNSDevice(ctx, userID, req)
|
||||||
case push.PlatformAndroid:
|
case push.PlatformAndroid:
|
||||||
return s.registerGCMDevice(userID, req)
|
return s.registerGCMDevice(ctx, userID, req)
|
||||||
default:
|
default:
|
||||||
return nil, apperrors.BadRequest("error.invalid_platform")
|
return nil, apperrors.BadRequest("error.invalid_platform")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
|
func (s *NotificationService) registerAPNSDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
|
||||||
// Check if device exists
|
// Check if device exists
|
||||||
existing, err := s.notificationRepo.FindAPNSDeviceByToken(req.RegistrationID)
|
existing, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(req.RegistrationID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// Update existing device
|
// Update existing device
|
||||||
existing.UserID = &userID
|
existing.UserID = &userID
|
||||||
existing.Active = true
|
existing.Active = true
|
||||||
existing.Name = req.Name
|
existing.Name = req.Name
|
||||||
existing.DeviceID = req.DeviceID
|
existing.DeviceID = req.DeviceID
|
||||||
if err := s.notificationRepo.UpdateAPNSDevice(existing); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).UpdateAPNSDevice(existing); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
return NewAPNSDeviceResponse(existing), nil
|
return NewAPNSDeviceResponse(existing), nil
|
||||||
@@ -327,22 +327,22 @@ func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDevic
|
|||||||
RegistrationID: req.RegistrationID,
|
RegistrationID: req.RegistrationID,
|
||||||
Active: true,
|
Active: true,
|
||||||
}
|
}
|
||||||
if err := s.notificationRepo.CreateAPNSDevice(device); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).CreateAPNSDevice(device); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
return NewAPNSDeviceResponse(device), nil
|
return NewAPNSDeviceResponse(device), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
|
func (s *NotificationService) registerGCMDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
|
||||||
// Check if device exists
|
// Check if device exists
|
||||||
existing, err := s.notificationRepo.FindGCMDeviceByToken(req.RegistrationID)
|
existing, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(req.RegistrationID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// Update existing device
|
// Update existing device
|
||||||
existing.UserID = &userID
|
existing.UserID = &userID
|
||||||
existing.Active = true
|
existing.Active = true
|
||||||
existing.Name = req.Name
|
existing.Name = req.Name
|
||||||
existing.DeviceID = req.DeviceID
|
existing.DeviceID = req.DeviceID
|
||||||
if err := s.notificationRepo.UpdateGCMDevice(existing); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).UpdateGCMDevice(existing); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
return NewGCMDeviceResponse(existing), nil
|
return NewGCMDeviceResponse(existing), nil
|
||||||
@@ -357,20 +357,20 @@ func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDevice
|
|||||||
CloudMessageType: "FCM",
|
CloudMessageType: "FCM",
|
||||||
Active: true,
|
Active: true,
|
||||||
}
|
}
|
||||||
if err := s.notificationRepo.CreateGCMDevice(device); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).CreateGCMDevice(device); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
return NewGCMDeviceResponse(device), nil
|
return NewGCMDeviceResponse(device), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListDevices lists all devices for a user
|
// ListDevices lists all devices for a user
|
||||||
func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error) {
|
func (s *NotificationService) ListDevices(ctx context.Context, userID uint) ([]DeviceResponse, error) {
|
||||||
iosDevices, err := s.notificationRepo.FindAPNSDevicesByUser(userID)
|
iosDevices, err := s.notificationRepo.WithContext(ctx).FindAPNSDevicesByUser(userID)
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
androidDevices, err := s.notificationRepo.FindGCMDevicesByUser(userID)
|
androidDevices, err := s.notificationRepo.WithContext(ctx).FindGCMDevicesByUser(userID)
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -387,10 +387,10 @@ func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error)
|
|||||||
|
|
||||||
// DeleteDevice deactivates a device after verifying it belongs to the requesting user.
|
// DeleteDevice deactivates a device after verifying it belongs to the requesting user.
|
||||||
// Without ownership verification, an attacker could deactivate push notifications for other users.
|
// Without ownership verification, an attacker could deactivate push notifications for other users.
|
||||||
func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userID uint) error {
|
func (s *NotificationService) DeleteDevice(ctx context.Context, deviceID uint, platform string, userID uint) error {
|
||||||
switch platform {
|
switch platform {
|
||||||
case push.PlatformIOS:
|
case push.PlatformIOS:
|
||||||
device, err := s.notificationRepo.FindAPNSDeviceByID(deviceID)
|
device, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByID(deviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return apperrors.NotFound("error.device_not_found")
|
return apperrors.NotFound("error.device_not_found")
|
||||||
@@ -401,11 +401,11 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI
|
|||||||
if device.UserID == nil || *device.UserID != userID {
|
if device.UserID == nil || *device.UserID != userID {
|
||||||
return apperrors.Forbidden("error.device_access_denied")
|
return apperrors.Forbidden("error.device_access_denied")
|
||||||
}
|
}
|
||||||
if err := s.notificationRepo.DeactivateAPNSDevice(deviceID); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).DeactivateAPNSDevice(deviceID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
case push.PlatformAndroid:
|
case push.PlatformAndroid:
|
||||||
device, err := s.notificationRepo.FindGCMDeviceByID(deviceID)
|
device, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByID(deviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return apperrors.NotFound("error.device_not_found")
|
return apperrors.NotFound("error.device_not_found")
|
||||||
@@ -416,7 +416,7 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI
|
|||||||
if device.UserID == nil || *device.UserID != userID {
|
if device.UserID == nil || *device.UserID != userID {
|
||||||
return apperrors.Forbidden("error.device_access_denied")
|
return apperrors.Forbidden("error.device_access_denied")
|
||||||
}
|
}
|
||||||
if err := s.notificationRepo.DeactivateGCMDevice(deviceID); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).DeactivateGCMDevice(deviceID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -426,10 +426,10 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UnregisterDevice deactivates a device by its registration token
|
// UnregisterDevice deactivates a device by its registration token
|
||||||
func (s *NotificationService) UnregisterDevice(registrationID, platform string, userID uint) error {
|
func (s *NotificationService) UnregisterDevice(ctx context.Context, registrationID, platform string, userID uint) error {
|
||||||
switch platform {
|
switch platform {
|
||||||
case push.PlatformIOS:
|
case push.PlatformIOS:
|
||||||
device, err := s.notificationRepo.FindAPNSDeviceByToken(registrationID)
|
device, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(registrationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.NotFound("error.device_not_found")
|
return apperrors.NotFound("error.device_not_found")
|
||||||
}
|
}
|
||||||
@@ -437,11 +437,11 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string,
|
|||||||
if device.UserID == nil || *device.UserID != userID {
|
if device.UserID == nil || *device.UserID != userID {
|
||||||
return apperrors.NotFound("error.device_not_found")
|
return apperrors.NotFound("error.device_not_found")
|
||||||
}
|
}
|
||||||
if err := s.notificationRepo.DeactivateAPNSDevice(device.ID); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).DeactivateAPNSDevice(device.ID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
case push.PlatformAndroid:
|
case push.PlatformAndroid:
|
||||||
device, err := s.notificationRepo.FindGCMDeviceByToken(registrationID)
|
device, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(registrationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.NotFound("error.device_not_found")
|
return apperrors.NotFound("error.device_not_found")
|
||||||
}
|
}
|
||||||
@@ -449,7 +449,7 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string,
|
|||||||
if device.UserID == nil || *device.UserID != userID {
|
if device.UserID == nil || *device.UserID != userID {
|
||||||
return apperrors.NotFound("error.device_not_found")
|
return apperrors.NotFound("error.device_not_found")
|
||||||
}
|
}
|
||||||
if err := s.notificationRepo.DeactivateGCMDevice(device.ID); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).DeactivateGCMDevice(device.ID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -624,7 +624,7 @@ func (s *NotificationService) CreateAndSendTaskNotification(
|
|||||||
task *models.Task,
|
task *models.Task,
|
||||||
) error {
|
) error {
|
||||||
// Check user notification preferences
|
// Check user notification preferences
|
||||||
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
|
prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -662,12 +662,12 @@ func (s *NotificationService) CreateAndSendTaskNotification(
|
|||||||
TaskID: &task.ID,
|
TaskID: &task.ID,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.notificationRepo.Create(notification); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).Create(notification); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get device tokens
|
// Get device tokens
|
||||||
iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID)
|
iosTokens, androidTokens, err := s.notificationRepo.WithContext(ctx).GetActiveTokensForUser(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -691,12 +691,12 @@ func (s *NotificationService) CreateAndSendTaskNotification(
|
|||||||
if s.pushClient != nil {
|
if s.pushClient != nil {
|
||||||
err = s.pushClient.SendActionableNotification(ctx, iosTokens, androidTokens, title, body, pushData, iosCategoryID)
|
err = s.pushClient.SendActionableNotification(ctx, iosTokens, androidTokens, title, body, pushData, iosCategoryID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.notificationRepo.SetError(notification.ID, err.Error())
|
s.notificationRepo.WithContext(ctx).SetError(notification.ID, err.Error())
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.notificationRepo.MarkAsSent(notification.ID); err != nil {
|
if err := s.notificationRepo.WithContext(ctx).MarkAsSent(notification.ID); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func TestNotificationService_GetNotifications(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.GetNotifications(user.ID, 10, 0)
|
resp, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, resp, 3)
|
assert.Len(t, resp, 3)
|
||||||
}
|
}
|
||||||
@@ -56,7 +56,7 @@ func TestNotificationService_GetNotifications_Empty(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
resp, err := service.GetNotifications(user.ID, 10, 0)
|
resp, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, resp)
|
assert.Empty(t, resp)
|
||||||
}
|
}
|
||||||
@@ -80,7 +80,7 @@ func TestNotificationService_GetNotifications_Pagination(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get first 2
|
// Get first 2
|
||||||
resp, err := service.GetNotifications(user.ID, 2, 0)
|
resp, err := service.GetNotifications(context.Background(), user.ID, 2, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, resp, 2)
|
assert.Len(t, resp, 2)
|
||||||
}
|
}
|
||||||
@@ -107,7 +107,7 @@ func TestNotificationService_GetUnreadCount(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := service.GetUnreadCount(user.ID)
|
count, err := service.GetUnreadCount(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, int64(3), count)
|
assert.Equal(t, int64(3), count)
|
||||||
}
|
}
|
||||||
@@ -119,7 +119,7 @@ func TestNotificationService_GetUnreadCount_Zero(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
count, err := service.GetUnreadCount(user.ID)
|
count, err := service.GetUnreadCount(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, int64(0), count)
|
assert.Equal(t, int64(0), count)
|
||||||
}
|
}
|
||||||
@@ -142,11 +142,11 @@ func TestNotificationService_MarkAsRead(t *testing.T) {
|
|||||||
err := db.Create(notif).Error
|
err := db.Create(notif).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = service.MarkAsRead(notif.ID, user.ID)
|
err = service.MarkAsRead(context.Background(), notif.ID, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify unread count is 0
|
// Verify unread count is 0
|
||||||
count, err := service.GetUnreadCount(user.ID)
|
count, err := service.GetUnreadCount(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, int64(0), count)
|
assert.Equal(t, int64(0), count)
|
||||||
}
|
}
|
||||||
@@ -168,7 +168,7 @@ func TestNotificationService_MarkAsRead_WrongUser(t *testing.T) {
|
|||||||
err := db.Create(notif).Error
|
err := db.Create(notif).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = service.MarkAsRead(notif.ID, other.ID)
|
err = service.MarkAsRead(context.Background(), notif.ID, other.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,7 +179,7 @@ func TestNotificationService_MarkAsRead_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
err := service.MarkAsRead(9999, user.ID)
|
err := service.MarkAsRead(context.Background(), 9999, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,10 +204,10 @@ func TestNotificationService_MarkAllAsRead(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.MarkAllAsRead(user.ID)
|
err := service.MarkAllAsRead(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
count, err := service.GetUnreadCount(user.ID)
|
count, err := service.GetUnreadCount(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, int64(0), count)
|
assert.Equal(t, int64(0), count)
|
||||||
}
|
}
|
||||||
@@ -229,7 +229,7 @@ func TestNotificationService_CreateAndSendNotification(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify notification was created
|
// Verify notification was created
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, notifs, 1)
|
assert.Len(t, notifs, 1)
|
||||||
assert.Equal(t, "Due Soon", notifs[0].Title)
|
assert.Equal(t, "Due Soon", notifs[0].Title)
|
||||||
@@ -254,7 +254,7 @@ func TestNotificationService_CreateAndSendNotification_DisabledPreference(t *tes
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify no notification was created (silently skipped)
|
// Verify no notification was created (silently skipped)
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, notifs)
|
assert.Empty(t, notifs)
|
||||||
}
|
}
|
||||||
@@ -268,7 +268,7 @@ func TestNotificationService_GetPreferences(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
resp, err := service.GetPreferences(user.ID)
|
resp, err := service.GetPreferences(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, resp)
|
assert.NotNil(t, resp)
|
||||||
// Defaults should all be true
|
// Defaults should all be true
|
||||||
@@ -289,7 +289,7 @@ func TestNotificationService_UpdatePreferences(t *testing.T) {
|
|||||||
TaskDueSoon: &falseVal,
|
TaskDueSoon: &falseVal,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdatePreferences(user.ID, req)
|
resp, err := service.UpdatePreferences(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, resp.TaskDueSoon)
|
assert.False(t, resp.TaskDueSoon)
|
||||||
assert.True(t, resp.TaskOverdue) // unchanged
|
assert.True(t, resp.TaskOverdue) // unchanged
|
||||||
@@ -307,7 +307,7 @@ func TestNotificationService_UpdatePreferences_InvalidHour(t *testing.T) {
|
|||||||
TaskDueSoonHour: &invalidHour,
|
TaskDueSoonHour: &invalidHour,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.UpdatePreferences(user.ID, req)
|
_, err := service.UpdatePreferences(context.Background(), user.ID, req)
|
||||||
testutil.AssertAppErrorCode(t, err, http.StatusBadRequest)
|
testutil.AssertAppErrorCode(t, err, http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,7 +323,7 @@ func TestNotificationService_UpdatePreferences_ValidHour(t *testing.T) {
|
|||||||
TaskDueSoonHour: &hour,
|
TaskDueSoonHour: &hour,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdatePreferences(user.ID, req)
|
resp, err := service.UpdatePreferences(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 9, *resp.TaskDueSoonHour)
|
assert.Equal(t, 9, *resp.TaskDueSoonHour)
|
||||||
}
|
}
|
||||||
@@ -344,7 +344,7 @@ func TestNotificationService_RegisterDevice_iOS(t *testing.T) {
|
|||||||
Platform: push.PlatformIOS,
|
Platform: push.PlatformIOS,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.RegisterDevice(user.ID, req)
|
resp, err := service.RegisterDevice(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "iPhone 15", resp.Name)
|
assert.Equal(t, "iPhone 15", resp.Name)
|
||||||
assert.Equal(t, push.PlatformIOS, resp.Platform)
|
assert.Equal(t, push.PlatformIOS, resp.Platform)
|
||||||
@@ -365,7 +365,7 @@ func TestNotificationService_RegisterDevice_Android(t *testing.T) {
|
|||||||
Platform: push.PlatformAndroid,
|
Platform: push.PlatformAndroid,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.RegisterDevice(user.ID, req)
|
resp, err := service.RegisterDevice(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "Pixel 8", resp.Name)
|
assert.Equal(t, "Pixel 8", resp.Name)
|
||||||
assert.Equal(t, push.PlatformAndroid, resp.Platform)
|
assert.Equal(t, push.PlatformAndroid, resp.Platform)
|
||||||
@@ -386,7 +386,7 @@ func TestNotificationService_RegisterDevice_InvalidPlatform(t *testing.T) {
|
|||||||
Platform: "windows",
|
Platform: "windows",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.RegisterDevice(user.ID, req)
|
_, err := service.RegisterDevice(context.Background(), user.ID, req)
|
||||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
|
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -404,12 +404,12 @@ func TestNotificationService_RegisterDevice_UpdateExisting(t *testing.T) {
|
|||||||
RegistrationID: "token-xyz",
|
RegistrationID: "token-xyz",
|
||||||
Platform: push.PlatformIOS,
|
Platform: push.PlatformIOS,
|
||||||
}
|
}
|
||||||
_, err := service.RegisterDevice(user.ID, req)
|
_, err := service.RegisterDevice(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Re-register with same token (should update, not duplicate)
|
// Re-register with same token (should update, not duplicate)
|
||||||
req.Name = "iPhone 15 Pro"
|
req.Name = "iPhone 15 Pro"
|
||||||
resp, err := service.RegisterDevice(user.ID, req)
|
resp, err := service.RegisterDevice(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "iPhone 15 Pro", resp.Name)
|
assert.Equal(t, "iPhone 15 Pro", resp.Name)
|
||||||
}
|
}
|
||||||
@@ -445,7 +445,7 @@ func TestNotificationService_ListDevices(t *testing.T) {
|
|||||||
err = db.Create(androidDevice).Error
|
err = db.Create(androidDevice).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
resp, err := service.ListDevices(user.ID)
|
resp, err := service.ListDevices(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, resp, 2)
|
assert.Len(t, resp, 2)
|
||||||
}
|
}
|
||||||
@@ -457,7 +457,7 @@ func TestNotificationService_ListDevices_Empty(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
resp, err := service.ListDevices(user.ID)
|
resp, err := service.ListDevices(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, resp)
|
assert.Empty(t, resp)
|
||||||
}
|
}
|
||||||
@@ -471,7 +471,7 @@ func TestDeleteDevice_InvalidPlatform_Returns400(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
err := service.DeleteDevice(1, "windows", user.ID)
|
err := service.DeleteDevice(context.Background(), 1, "windows", user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
|
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -494,7 +494,7 @@ func TestNotificationService_UnregisterDevice_iOS(t *testing.T) {
|
|||||||
err := db.Create(device).Error
|
err := db.Create(device).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = service.UnregisterDevice("reg-token-ios", push.PlatformIOS, user.ID)
|
err = service.UnregisterDevice(context.Background(), "reg-token-ios", push.PlatformIOS, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify device is deactivated
|
// Verify device is deactivated
|
||||||
@@ -522,7 +522,7 @@ func TestNotificationService_UnregisterDevice_Android(t *testing.T) {
|
|||||||
err := db.Create(device).Error
|
err := db.Create(device).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = service.UnregisterDevice("reg-token-android", push.PlatformAndroid, user.ID)
|
err = service.UnregisterDevice(context.Background(), "reg-token-android", push.PlatformAndroid, user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var found models.GCMDevice
|
var found models.GCMDevice
|
||||||
@@ -538,7 +538,7 @@ func TestNotificationService_UnregisterDevice_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
err := service.UnregisterDevice("nonexistent-token", push.PlatformIOS, user.ID)
|
err := service.UnregisterDevice(context.Background(), "nonexistent-token", push.PlatformIOS, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -560,7 +560,7 @@ func TestNotificationService_UnregisterDevice_WrongUser(t *testing.T) {
|
|||||||
err := db.Create(device).Error
|
err := db.Create(device).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = service.UnregisterDevice("owner-token", push.PlatformIOS, attacker.ID)
|
err = service.UnregisterDevice(context.Background(), "owner-token", push.PlatformIOS, attacker.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -571,7 +571,7 @@ func TestNotificationService_UnregisterDevice_InvalidPlatform(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
err := service.UnregisterDevice("some-token", "windows", user.ID)
|
err := service.UnregisterDevice(context.Background(), "some-token", "windows", user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
|
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -585,7 +585,7 @@ func TestNotificationService_UpdateUserTimezone(t *testing.T) {
|
|||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
// Should not panic, just silently update
|
// Should not panic, just silently update
|
||||||
service.UpdateUserTimezone(user.ID, "America/Los_Angeles")
|
service.UpdateUserTimezone(context.Background(), user.ID, "America/Los_Angeles")
|
||||||
|
|
||||||
// Verify timezone was stored
|
// Verify timezone was stored
|
||||||
prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
|
prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
|
||||||
@@ -602,7 +602,7 @@ func TestNotificationService_UpdateUserTimezone_Invalid(t *testing.T) {
|
|||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
// Invalid timezone should be silently ignored
|
// Invalid timezone should be silently ignored
|
||||||
service.UpdateUserTimezone(user.ID, "Invalid/Timezone")
|
service.UpdateUserTimezone(context.Background(), user.ID, "Invalid/Timezone")
|
||||||
|
|
||||||
prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
|
prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -617,10 +617,10 @@ func TestNotificationService_UpdateUserTimezone_NoChangeSkipsWrite(t *testing.T)
|
|||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
// Set timezone
|
// Set timezone
|
||||||
service.UpdateUserTimezone(user.ID, "America/New_York")
|
service.UpdateUserTimezone(context.Background(), user.ID, "America/New_York")
|
||||||
|
|
||||||
// Set same timezone again — should be a no-op
|
// Set same timezone again — should be a no-op
|
||||||
service.UpdateUserTimezone(user.ID, "America/New_York")
|
service.UpdateUserTimezone(context.Background(), user.ID, "America/New_York")
|
||||||
|
|
||||||
prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
|
prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -648,7 +648,7 @@ func TestDeleteDevice_WrongUser_Returns403(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Attacker tries to deactivate the owner's device
|
// Attacker tries to deactivate the owner's device
|
||||||
err = service.DeleteDevice(device.ID, push.PlatformIOS, attacker.ID)
|
err = service.DeleteDevice(context.Background(), device.ID, push.PlatformIOS, attacker.ID)
|
||||||
require.Error(t, err, "should not allow deleting another user's device")
|
require.Error(t, err, "should not allow deleting another user's device")
|
||||||
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
|
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
|
||||||
|
|
||||||
@@ -678,7 +678,7 @@ func TestDeleteDevice_CorrectUser_Succeeds(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Owner deactivates their own device
|
// Owner deactivates their own device
|
||||||
err = service.DeleteDevice(device.ID, push.PlatformIOS, owner.ID)
|
err = service.DeleteDevice(context.Background(), device.ID, push.PlatformIOS, owner.ID)
|
||||||
require.NoError(t, err, "owner should be able to deactivate their own device")
|
require.NoError(t, err, "owner should be able to deactivate their own device")
|
||||||
|
|
||||||
// Verify the device is now inactive
|
// Verify the device is now inactive
|
||||||
@@ -709,7 +709,7 @@ func TestDeleteDevice_WrongUser_Android_Returns403(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Attacker tries to deactivate the owner's Android device
|
// Attacker tries to deactivate the owner's Android device
|
||||||
err = service.DeleteDevice(device.ID, push.PlatformAndroid, attacker.ID)
|
err = service.DeleteDevice(context.Background(), device.ID, push.PlatformAndroid, attacker.ID)
|
||||||
require.Error(t, err, "should not allow deleting another user's Android device")
|
require.Error(t, err, "should not allow deleting another user's Android device")
|
||||||
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
|
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
|
||||||
|
|
||||||
@@ -727,7 +727,7 @@ func TestDeleteDevice_NonExistent_Returns404(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||||
|
|
||||||
err := service.DeleteDevice(99999, push.PlatformIOS, user.ID)
|
err := service.DeleteDevice(context.Background(), 99999, push.PlatformIOS, user.ID)
|
||||||
require.Error(t, err, "should return error for non-existent device")
|
require.Error(t, err, "should return error for non-existent device")
|
||||||
testutil.AssertAppErrorCode(t, err, http.StatusNotFound)
|
testutil.AssertAppErrorCode(t, err, http.StatusNotFound)
|
||||||
}
|
}
|
||||||
@@ -744,7 +744,7 @@ func TestNotificationService_CreateAndSend_TaskOverdue(t *testing.T) {
|
|||||||
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Task is overdue", nil)
|
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Task is overdue", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, notifs, 1)
|
assert.Len(t, notifs, 1)
|
||||||
assert.Equal(t, "Overdue", notifs[0].Title)
|
assert.Equal(t, "Overdue", notifs[0].Title)
|
||||||
@@ -760,7 +760,7 @@ func TestNotificationService_CreateAndSend_TaskCompleted(t *testing.T) {
|
|||||||
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil)
|
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, notifs, 1)
|
assert.Len(t, notifs, 1)
|
||||||
}
|
}
|
||||||
@@ -775,7 +775,7 @@ func TestNotificationService_CreateAndSend_TaskAssigned(t *testing.T) {
|
|||||||
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned to you", nil)
|
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned to you", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, notifs, 1)
|
assert.Len(t, notifs, 1)
|
||||||
}
|
}
|
||||||
@@ -790,7 +790,7 @@ func TestNotificationService_CreateAndSend_ResidenceShared(t *testing.T) {
|
|||||||
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Someone shared a home", nil)
|
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Someone shared a home", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, notifs, 1)
|
assert.Len(t, notifs, 1)
|
||||||
}
|
}
|
||||||
@@ -805,7 +805,7 @@ func TestNotificationService_CreateAndSend_WarrantyExpiring(t *testing.T) {
|
|||||||
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Expiring", "Warranty expiring soon", nil)
|
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Expiring", "Warranty expiring soon", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, notifs, 1)
|
assert.Len(t, notifs, 1)
|
||||||
}
|
}
|
||||||
@@ -828,7 +828,7 @@ func TestNotificationService_DisabledPrefs_TaskOverdue(t *testing.T) {
|
|||||||
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Overdue task", nil)
|
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Overdue task", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, notifs)
|
assert.Empty(t, notifs)
|
||||||
}
|
}
|
||||||
@@ -849,7 +849,7 @@ func TestNotificationService_DisabledPrefs_TaskCompleted(t *testing.T) {
|
|||||||
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil)
|
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, notifs)
|
assert.Empty(t, notifs)
|
||||||
}
|
}
|
||||||
@@ -870,7 +870,7 @@ func TestNotificationService_DisabledPrefs_TaskAssigned(t *testing.T) {
|
|||||||
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned", nil)
|
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, notifs)
|
assert.Empty(t, notifs)
|
||||||
}
|
}
|
||||||
@@ -891,7 +891,7 @@ func TestNotificationService_DisabledPrefs_ResidenceShared(t *testing.T) {
|
|||||||
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Home shared", nil)
|
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Home shared", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, notifs)
|
assert.Empty(t, notifs)
|
||||||
}
|
}
|
||||||
@@ -912,7 +912,7 @@ func TestNotificationService_DisabledPrefs_WarrantyExpiring(t *testing.T) {
|
|||||||
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Warranty", "Expiring", nil)
|
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Warranty", "Expiring", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, notifs)
|
assert.Empty(t, notifs)
|
||||||
}
|
}
|
||||||
@@ -929,7 +929,7 @@ func TestNotificationService_CreateAndSend_UnknownTypeDefaultsEnabled(t *testing
|
|||||||
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationType("unknown_type"), "Unknown", "Unknown notification", nil)
|
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationType("unknown_type"), "Unknown", "Unknown notification", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, notifs, 1)
|
assert.Len(t, notifs, 1)
|
||||||
}
|
}
|
||||||
@@ -953,7 +953,7 @@ func TestNotificationService_CreateAndSend_WithMixedDataTypes(t *testing.T) {
|
|||||||
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskDueSoon, "Due Soon", "Fix faucet", data)
|
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskDueSoon, "Due Soon", "Fix faucet", data)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
notifs, err := service.GetNotifications(user.ID, 10, 0)
|
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, notifs, 1)
|
assert.Len(t, notifs, 1)
|
||||||
assert.NotNil(t, notifs[0].Data)
|
assert.NotNil(t, notifs[0].Data)
|
||||||
@@ -985,7 +985,7 @@ func TestNotificationService_UpdatePreferences_MultipleFields(t *testing.T) {
|
|||||||
TaskOverdueHour: &hour14,
|
TaskOverdueHour: &hour14,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := service.UpdatePreferences(user.ID, req)
|
resp, err := service.UpdatePreferences(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, resp.TaskDueSoon)
|
assert.False(t, resp.TaskDueSoon)
|
||||||
assert.False(t, resp.TaskOverdue)
|
assert.False(t, resp.TaskOverdue)
|
||||||
@@ -1013,7 +1013,7 @@ func TestNotificationService_UpdatePreferences_NegativeHour(t *testing.T) {
|
|||||||
TaskOverdueHour: &negHour,
|
TaskOverdueHour: &negHour,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := service.UpdatePreferences(user.ID, req)
|
_, err := service.UpdatePreferences(context.Background(), user.ID, req)
|
||||||
testutil.AssertAppErrorCode(t, err, http.StatusBadRequest)
|
testutil.AssertAppErrorCode(t, err, http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1032,12 +1032,12 @@ func TestNotificationService_RegisterDevice_UpdateExistingAndroid(t *testing.T)
|
|||||||
RegistrationID: "token-android-1",
|
RegistrationID: "token-android-1",
|
||||||
Platform: push.PlatformAndroid,
|
Platform: push.PlatformAndroid,
|
||||||
}
|
}
|
||||||
_, err := service.RegisterDevice(user.ID, req)
|
_, err := service.RegisterDevice(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Re-register with same token but new name
|
// Re-register with same token but new name
|
||||||
req.Name = "Pixel 8 Pro"
|
req.Name = "Pixel 8 Pro"
|
||||||
resp, err := service.RegisterDevice(user.ID, req)
|
resp, err := service.RegisterDevice(context.Background(), user.ID, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "Pixel 8 Pro", resp.Name)
|
assert.Equal(t, "Pixel 8 Pro", resp.Name)
|
||||||
assert.Equal(t, push.PlatformAndroid, resp.Platform)
|
assert.Equal(t, push.PlatformAndroid, resp.Platform)
|
||||||
@@ -1052,7 +1052,7 @@ func TestDeleteDevice_AndroidNotFound_Returns404(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
err := service.DeleteDevice(99999, push.PlatformAndroid, user.ID)
|
err := service.DeleteDevice(context.Background(), 99999, push.PlatformAndroid, user.ID)
|
||||||
testutil.AssertAppErrorCode(t, err, http.StatusNotFound)
|
testutil.AssertAppErrorCode(t, err, http.StatusNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1076,7 +1076,7 @@ func TestDeleteDevice_CorrectUser_Android_Succeeds(t *testing.T) {
|
|||||||
err := db.Create(device).Error
|
err := db.Create(device).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = service.DeleteDevice(device.ID, push.PlatformAndroid, owner.ID)
|
err = service.DeleteDevice(context.Background(), device.ID, push.PlatformAndroid, owner.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var found models.GCMDevice
|
var found models.GCMDevice
|
||||||
@@ -1106,7 +1106,7 @@ func TestNotificationService_UnregisterDevice_WrongUser_Android(t *testing.T) {
|
|||||||
err := db.Create(device).Error
|
err := db.Create(device).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = service.UnregisterDevice("owner-android-token", push.PlatformAndroid, attacker.ID)
|
err = service.UnregisterDevice(context.Background(), "owner-android-token", push.PlatformAndroid, attacker.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1119,7 +1119,7 @@ func TestNotificationService_UnregisterDevice_AndroidNotFound(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
err := service.UnregisterDevice("nonexistent-android", push.PlatformAndroid, user.ID)
|
err := service.UnregisterDevice(context.Background(), "nonexistent-android", push.PlatformAndroid, user.ID)
|
||||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
|
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ func (s *ResidenceService) getSummaryForUser(_ uint) responses.TotalSummary {
|
|||||||
func (s *ResidenceService) CreateResidence(ctx context.Context, req *requests.CreateResidenceRequest, ownerID uint) (*responses.ResidenceWithSummaryResponse, error) {
|
func (s *ResidenceService) CreateResidence(ctx context.Context, req *requests.CreateResidenceRequest, ownerID uint) (*responses.ResidenceWithSummaryResponse, error) {
|
||||||
// Check subscription tier limits (if subscription service is wired up)
|
// Check subscription tier limits (if subscription service is wired up)
|
||||||
if s.subscriptionService != nil {
|
if s.subscriptionService != nil {
|
||||||
if err := s.subscriptionService.CheckLimit(ownerID, "properties"); err != nil {
|
if err := s.subscriptionService.CheckLimit(ctx, ownerID, "properties"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,8 +98,8 @@ func NewSubscriptionService(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetSubscription gets the subscription for a user
|
// GetSubscription gets the subscription for a user
|
||||||
func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionResponse, error) {
|
func (s *SubscriptionService) GetSubscription(ctx context.Context, userID uint) (*SubscriptionResponse, error) {
|
||||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -107,13 +107,13 @@ func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionRespons
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetSubscriptionStatus gets detailed subscription status including limits
|
// GetSubscriptionStatus gets detailed subscription status including limits
|
||||||
func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionStatusResponse, error) {
|
func (s *SubscriptionService) GetSubscriptionStatus(ctx context.Context, userID uint) (*SubscriptionStatusResponse, error) {
|
||||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := s.subscriptionRepo.GetSettings()
|
settings, err := s.subscriptionRepo.WithContext(ctx).GetSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -122,18 +122,18 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
|||||||
if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled {
|
if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour)
|
trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour)
|
||||||
if err := s.subscriptionRepo.SetTrialDates(userID, now, trialEnd); err != nil {
|
if err := s.subscriptionRepo.WithContext(ctx).SetTrialDates(userID, now, trialEnd); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
// Re-fetch after starting trial so response reflects the new state
|
// Re-fetch after starting trial so response reflects the new state
|
||||||
sub, err = s.subscriptionRepo.FindByUserID(userID)
|
sub, err = s.subscriptionRepo.WithContext(ctx).FindByUserID(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all tier limits and build a map
|
// Get all tier limits and build a map
|
||||||
allLimits, err := s.subscriptionRepo.GetAllTierLimits()
|
allLimits, err := s.subscriptionRepo.WithContext(ctx).GetAllTierLimits()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -154,7 +154,7 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get current usage
|
// Get current usage
|
||||||
usage, err := s.getUserUsage(userID)
|
usage, err := s.getUserUsage(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -204,31 +204,31 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
|||||||
// getUserUsage calculates current usage for a user.
|
// getUserUsage calculates current usage for a user.
|
||||||
// P-10: Uses CountByOwner for properties count instead of loading all owned residences.
|
// P-10: Uses CountByOwner for properties count instead of loading all owned residences.
|
||||||
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
|
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
|
||||||
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
|
func (s *SubscriptionService) getUserUsage(ctx context.Context, userID uint) (*UsageResponse, error) {
|
||||||
// P-10: Use CountByOwner for an efficient COUNT query instead of loading all records
|
// P-10: Use CountByOwner for an efficient COUNT query instead of loading all records
|
||||||
propertiesCount, err := s.residenceRepo.CountByOwner(userID)
|
propertiesCount, err := s.residenceRepo.WithContext(ctx).CountByOwner(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Still need residence IDs for batch counting tasks/contractors/documents
|
// Still need residence IDs for batch counting tasks/contractors/documents
|
||||||
residenceIDs, err := s.residenceRepo.FindResidenceIDsByOwner(userID)
|
residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByOwner(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count tasks, contractors, and documents across all residences with single queries each
|
// Count tasks, contractors, and documents across all residences with single queries each
|
||||||
tasksCount, err := s.taskRepo.CountByResidenceIDs(residenceIDs)
|
tasksCount, err := s.taskRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
contractorsCount, err := s.contractorRepo.CountByResidenceIDs(residenceIDs)
|
contractorsCount, err := s.contractorRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
documentsCount, err := s.documentRepo.CountByResidenceIDs(residenceIDs)
|
documentsCount, err := s.documentRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -242,8 +242,8 @@ func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckLimit checks if a user has exceeded a specific limit
|
// CheckLimit checks if a user has exceeded a specific limit
|
||||||
func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
|
func (s *SubscriptionService) CheckLimit(ctx context.Context, userID uint, limitType string) error {
|
||||||
settings, err := s.subscriptionRepo.GetSettings()
|
settings, err := s.subscriptionRepo.WithContext(ctx).GetSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -253,7 +253,7 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -268,12 +268,12 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
limits, err := s.subscriptionRepo.GetTierLimits(sub.Tier)
|
limits, err := s.subscriptionRepo.WithContext(ctx).GetTierLimits(sub.Tier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, err := s.getUserUsage(userID)
|
usage, err := s.getUserUsage(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -301,8 +301,8 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUpgradeTrigger gets an upgrade trigger by key
|
// GetUpgradeTrigger gets an upgrade trigger by key
|
||||||
func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResponse, error) {
|
func (s *SubscriptionService) GetUpgradeTrigger(ctx context.Context, key string) (*UpgradeTriggerResponse, error) {
|
||||||
trigger, err := s.subscriptionRepo.GetUpgradeTrigger(key)
|
trigger, err := s.subscriptionRepo.WithContext(ctx).GetUpgradeTrigger(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, apperrors.NotFound("error.upgrade_trigger_not_found")
|
return nil, apperrors.NotFound("error.upgrade_trigger_not_found")
|
||||||
@@ -314,8 +314,8 @@ func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResp
|
|||||||
|
|
||||||
// GetAllUpgradeTriggers gets all active upgrade triggers as a map keyed by trigger_key
|
// GetAllUpgradeTriggers gets all active upgrade triggers as a map keyed by trigger_key
|
||||||
// KMM client expects Map<String, UpgradeTriggerData>
|
// KMM client expects Map<String, UpgradeTriggerData>
|
||||||
func (s *SubscriptionService) GetAllUpgradeTriggers() (map[string]*UpgradeTriggerDataResponse, error) {
|
func (s *SubscriptionService) GetAllUpgradeTriggers(ctx context.Context) (map[string]*UpgradeTriggerDataResponse, error) {
|
||||||
triggers, err := s.subscriptionRepo.GetAllUpgradeTriggers()
|
triggers, err := s.subscriptionRepo.WithContext(ctx).GetAllUpgradeTriggers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -328,8 +328,8 @@ func (s *SubscriptionService) GetAllUpgradeTriggers() (map[string]*UpgradeTrigge
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetFeatureBenefits gets all feature benefits
|
// GetFeatureBenefits gets all feature benefits
|
||||||
func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, error) {
|
func (s *SubscriptionService) GetFeatureBenefits(ctx context.Context) ([]FeatureBenefitResponse, error) {
|
||||||
benefits, err := s.subscriptionRepo.GetFeatureBenefits()
|
benefits, err := s.subscriptionRepo.WithContext(ctx).GetFeatureBenefits()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -342,13 +342,13 @@ func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetActivePromotions gets active promotions for a user
|
// GetActivePromotions gets active promotions for a user
|
||||||
func (s *SubscriptionService) GetActivePromotions(userID uint) ([]PromotionResponse, error) {
|
func (s *SubscriptionService) GetActivePromotions(ctx context.Context, userID uint) ([]PromotionResponse, error) {
|
||||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
promotions, err := s.subscriptionRepo.GetActivePromotions(sub.Tier)
|
promotions, err := s.subscriptionRepo.WithContext(ctx).GetActivePromotions(sub.Tier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -362,13 +362,13 @@ func (s *SubscriptionService) GetActivePromotions(userID uint) ([]PromotionRespo
|
|||||||
|
|
||||||
// ProcessApplePurchase processes an Apple IAP purchase
|
// ProcessApplePurchase processes an Apple IAP purchase
|
||||||
// Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID)
|
// Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID)
|
||||||
func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) {
|
func (s *SubscriptionService) ProcessApplePurchase(ctx context.Context, userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) {
|
||||||
// Store receipt/transaction data
|
// Store receipt/transaction data
|
||||||
dataToStore := receiptData
|
dataToStore := receiptData
|
||||||
if dataToStore == "" {
|
if dataToStore == "" {
|
||||||
dataToStore = transactionID
|
dataToStore = transactionID
|
||||||
}
|
}
|
||||||
if err := s.subscriptionRepo.UpdateReceiptData(userID, dataToStore); err != nil {
|
if err := s.subscriptionRepo.WithContext(ctx).UpdateReceiptData(userID, dataToStore); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -406,18 +406,18 @@ func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData stri
|
|||||||
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated")
|
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated")
|
||||||
|
|
||||||
// Upgrade to Pro with the validated expiration
|
// Upgrade to Pro with the validated expiration
|
||||||
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil {
|
if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "ios"); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.GetSubscription(userID)
|
return s.GetSubscription(ctx, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessGooglePurchase processes a Google Play purchase
|
// ProcessGooglePurchase processes a Google Play purchase
|
||||||
// productID is optional but helps validate the specific subscription
|
// productID is optional but helps validate the specific subscription
|
||||||
func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) {
|
func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) {
|
||||||
// Store purchase token first
|
// Store purchase token first
|
||||||
if err := s.subscriptionRepo.UpdatePurchaseToken(userID, purchaseToken); err != nil {
|
if err := s.subscriptionRepo.WithContext(ctx).UpdatePurchaseToken(userID, purchaseToken); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -463,25 +463,25 @@ func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Upgrade to Pro with the validated expiration
|
// Upgrade to Pro with the validated expiration
|
||||||
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "android"); err != nil {
|
if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "android"); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.GetSubscription(userID)
|
return s.GetSubscription(ctx, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CancelSubscription cancels a subscription (downgrades to free at end of period)
|
// CancelSubscription cancels a subscription (downgrades to free at end of period)
|
||||||
func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResponse, error) {
|
func (s *SubscriptionService) CancelSubscription(ctx context.Context, userID uint) (*SubscriptionResponse, error) {
|
||||||
if err := s.subscriptionRepo.SetAutoRenew(userID, false); err != nil {
|
if err := s.subscriptionRepo.WithContext(ctx).SetAutoRenew(userID, false); err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
return s.GetSubscription(userID)
|
return s.GetSubscription(ctx, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription
|
// IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription
|
||||||
// from a different platform than the one being requested. Returns (conflict, existingPlatform, error).
|
// from a different platform than the one being requested. Returns (conflict, existingPlatform, error).
|
||||||
func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(userID uint, requestedPlatform string) (bool, string, error) {
|
func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(ctx context.Context, userID uint, requestedPlatform string) (bool, string, error) {
|
||||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", apperrors.Internal(err)
|
return false, "", apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -70,7 +71,7 @@ func TestProcessApplePurchase_ClientNil_ReturnsError(t *testing.T) {
|
|||||||
googleClient: nil,
|
googleClient: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := svc.ProcessApplePurchase(user.ID, "fake-receipt", "")
|
_, err := svc.ProcessApplePurchase(context.Background(), user.ID, "fake-receipt", "")
|
||||||
assert.Error(t, err, "ProcessApplePurchase should return error when Apple IAP client is nil")
|
assert.Error(t, err, "ProcessApplePurchase should return error when Apple IAP client is nil")
|
||||||
|
|
||||||
// Verify user was NOT upgraded to Pro
|
// Verify user was NOT upgraded to Pro
|
||||||
@@ -109,7 +110,7 @@ func TestProcessApplePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Neither receipt data nor transaction ID - should still not grant Pro
|
// Neither receipt data nor transaction ID - should still not grant Pro
|
||||||
_, err := svc.ProcessApplePurchase(user.ID, "", "")
|
_, err := svc.ProcessApplePurchase(context.Background(), user.ID, "", "")
|
||||||
assert.Error(t, err, "ProcessApplePurchase should return error when client is nil, even with empty data")
|
assert.Error(t, err, "ProcessApplePurchase should return error when client is nil, even with empty data")
|
||||||
|
|
||||||
// Verify no upgrade happened
|
// Verify no upgrade happened
|
||||||
@@ -140,7 +141,7 @@ func TestProcessGooglePurchase_ClientNil_ReturnsError(t *testing.T) {
|
|||||||
googleClient: nil, // Not configured
|
googleClient: nil, // Not configured
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := svc.ProcessGooglePurchase(user.ID, "fake-token", "com.tt.honeyDue.pro.monthly")
|
_, err := svc.ProcessGooglePurchase(context.Background(), user.ID, "fake-token", "com.tt.honeyDue.pro.monthly")
|
||||||
assert.Error(t, err, "ProcessGooglePurchase should return error when Google IAP client is nil")
|
assert.Error(t, err, "ProcessGooglePurchase should return error when Google IAP client is nil")
|
||||||
|
|
||||||
// Verify user was NOT upgraded to Pro
|
// Verify user was NOT upgraded to Pro
|
||||||
@@ -172,7 +173,7 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// With empty token
|
// With empty token
|
||||||
_, err := svc.ProcessGooglePurchase(user.ID, "", "")
|
_, err := svc.ProcessGooglePurchase(context.Background(), user.ID, "", "")
|
||||||
assert.Error(t, err, "ProcessGooglePurchase should return error when client is nil")
|
assert.Error(t, err, "ProcessGooglePurchase should return error when client is nil")
|
||||||
|
|
||||||
// Verify no upgrade happened
|
// Verify no upgrade happened
|
||||||
@@ -202,7 +203,7 @@ func TestSubscriptionService_GetSubscription(t *testing.T) {
|
|||||||
|
|
||||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||||
|
|
||||||
resp, err := svc.GetSubscription(user.ID)
|
resp, err := svc.GetSubscription(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "free", resp.Tier)
|
assert.Equal(t, "free", resp.Tier)
|
||||||
assert.False(t, resp.IsPro)
|
assert.False(t, resp.IsPro)
|
||||||
@@ -238,7 +239,7 @@ func TestSubscriptionService_GetSubscription_ProUser(t *testing.T) {
|
|||||||
err := db.Create(sub).Error
|
err := db.Create(sub).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
resp, err := svc.GetSubscription(user.ID)
|
resp, err := svc.GetSubscription(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "pro", resp.Tier)
|
assert.Equal(t, "pro", resp.Tier)
|
||||||
assert.True(t, resp.IsPro)
|
assert.True(t, resp.IsPro)
|
||||||
@@ -277,7 +278,7 @@ func TestSubscriptionService_CancelSubscription(t *testing.T) {
|
|||||||
err := db.Create(sub).Error
|
err := db.Create(sub).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
resp, err := svc.CancelSubscription(user.ID)
|
resp, err := svc.CancelSubscription(context.Background(), user.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, resp.AutoRenew)
|
assert.False(t, resp.AutoRenew)
|
||||||
}
|
}
|
||||||
@@ -365,7 +366,7 @@ func TestIsAlreadyProFromOtherPlatform(t *testing.T) {
|
|||||||
err := db.Create(sub).Error
|
err := db.Create(sub).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(user.ID, tt.requestedPlatform)
|
conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(context.Background(), user.ID, tt.requestedPlatform)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, tt.wantConflict, conflict)
|
assert.Equal(t, tt.wantConflict, conflict)
|
||||||
assert.Equal(t, tt.wantPlatform, existingPlatform)
|
assert.Equal(t, tt.wantPlatform, existingPlatform)
|
||||||
|
|||||||
@@ -898,7 +898,7 @@ func (s *TaskService) sendTaskCompletedNotification(ctx context.Context, task *m
|
|||||||
// Send email notification (to everyone INCLUDING the person who completed it)
|
// Send email notification (to everyone INCLUDING the person who completed it)
|
||||||
// Check user's email notification preferences first
|
// Check user's email notification preferences first
|
||||||
if s.emailService != nil && user.Email != "" && s.notificationService != nil {
|
if s.emailService != nil && user.Email != "" && s.notificationService != nil {
|
||||||
prefs, prefsErr := s.notificationService.GetPreferences(user.ID)
|
prefs, prefsErr := s.notificationService.GetPreferences(ctx, user.ID)
|
||||||
// LE-06: Log fail-open behavior when preferences cannot be loaded
|
// LE-06: Log fail-open behavior when preferences cannot be loaded
|
||||||
if prefsErr != nil {
|
if prefsErr != nil {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
@@ -1264,8 +1264,8 @@ func (s *TaskService) GetFrequencies(ctx context.Context) ([]responses.TaskFrequ
|
|||||||
// UpdateUserTimezone updates the user's timezone for background job calculations.
|
// UpdateUserTimezone updates the user's timezone for background job calculations.
|
||||||
// This is called from handlers when the X-Timezone header is present.
|
// This is called from handlers when the X-Timezone header is present.
|
||||||
// Delegates to NotificationService since timezone is stored in notification preferences.
|
// Delegates to NotificationService since timezone is stored in notification preferences.
|
||||||
func (s *TaskService) UpdateUserTimezone(userID uint, timezone string) {
|
func (s *TaskService) UpdateUserTimezone(ctx context.Context, userID uint, timezone string) {
|
||||||
if s.notificationService != nil {
|
if s.notificationService != nil {
|
||||||
s.notificationService.UpdateUserTimezone(userID, timezone)
|
s.notificationService.UpdateUserTimezone(ctx, userID, timezone)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user