Migrate Auth/Contractor/Document/Notification/Subscription services to ctx
Backend CI / Test (push) Has been cancelled
Backend CI / Contract Tests (push) Has been cancelled
Backend CI / Build (push) Has been cancelled
Backend CI / Lint (push) Has been cancelled
Backend CI / Secret Scanning (push) Has been cancelled

Every public method on these five services now takes ctx context.Context as
the first arg and routes its repo calls through .WithContext(ctx). With
TaskService and ResidenceService already migrated, this means every
in-process service that touches Postgres now produces a flame graph in
Jaeger where the SQL spans nest under the parent HTTP request span.

Endpoints now fully traced (HTTP → service → SQL):
- /api/auth/login, /register, /logout, /me, /verify-email, /resend-verification
- /api/auth/forgot-password, /verify-reset, /reset-password, /update-profile
- /api/contractors/* (CRUD + favorite + by-residence + tasks)
- /api/documents/* (CRUD + activate/deactivate + image upload/delete)
- /api/notifications/* (list, count, mark-read, prefs, devices)
- /api/subscription/* (status, purchase, cancel, triggers, promotions)
- All previously-migrated /api/tasks/* and /api/residences/* paths

Internal helpers also threaded:
- TaskService.sendTaskCompletedNotification → forwards ctx
- TaskService.UpdateUserTimezone → forwards ctx to NotificationService
- ResidenceService.CreateResidence → forwards ctx to SubscriptionService.CheckLimit
- NotificationService.registerAPNSDevice / registerGCMDevice → both take ctx

~75 method signatures, ~120 handler/test call sites updated. Tests pass
green; the only failure is the pre-existing flaky TaskHandler_QuickComplete
SQLite race that fails ~60% of runs on master.

Step 3 of the observability plan is now genuinely complete: every API
endpoint backed by a Go service emits a per-request flame graph with
HTTP → service → SQL spans, plus B2/APNs/FCM/asynq spans where applicable.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-04-25 16:26:21 -05:00
parent 65a9aae4e5
commit e881d37de0
20 changed files with 529 additions and 522 deletions
+12 -12
View File
@@ -65,7 +65,7 @@ func (h *AuthHandler) Login(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
} }
response, err := h.authService.Login(&req) response, err := h.authService.Login(c.Request().Context(), &req)
if err != nil { if err != nil {
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed") log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
if h.auditService != nil { if h.auditService != nil {
@@ -94,7 +94,7 @@ func (h *AuthHandler) Register(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
} }
response, confirmationCode, err := h.authService.Register(&req) response, confirmationCode, err := h.authService.Register(c.Request().Context(), &req)
if err != nil { if err != nil {
log.Debug().Err(err).Msg("Registration failed") log.Debug().Err(err).Msg("Registration failed")
return err return err
@@ -141,7 +141,7 @@ func (h *AuthHandler) Logout(c echo.Context) error {
} }
// Invalidate token in database // Invalidate token in database
if err := h.authService.Logout(token); err != nil { if err := h.authService.Logout(c.Request().Context(), token); err != nil {
log.Warn().Err(err).Msg("Failed to delete token from database") log.Warn().Err(err).Msg("Failed to delete token from database")
} }
@@ -162,7 +162,7 @@ func (h *AuthHandler) CurrentUser(c echo.Context) error {
return err return err
} }
response, err := h.authService.GetCurrentUser(user.ID) response, err := h.authService.GetCurrentUser(c.Request().Context(), user.ID)
if err != nil { if err != nil {
log.Error().Err(err).Uint("user_id", user.ID).Msg("Failed to get current user") log.Error().Err(err).Uint("user_id", user.ID).Msg("Failed to get current user")
return err return err
@@ -186,7 +186,7 @@ func (h *AuthHandler) UpdateProfile(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
} }
response, err := h.authService.UpdateProfile(user.ID, &req) response, err := h.authService.UpdateProfile(c.Request().Context(), user.ID, &req)
if err != nil { if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to update profile") log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to update profile")
return err return err
@@ -210,7 +210,7 @@ func (h *AuthHandler) VerifyEmail(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
} }
err = h.authService.VerifyEmail(user.ID, req.Code) err = h.authService.VerifyEmail(c.Request().Context(), user.ID, req.Code)
if err != nil { if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Email verification failed") log.Debug().Err(err).Uint("user_id", user.ID).Msg("Email verification failed")
return err return err
@@ -243,7 +243,7 @@ func (h *AuthHandler) ResendVerification(c echo.Context) error {
return err return err
} }
code, err := h.authService.ResendVerificationCode(user.ID) code, err := h.authService.ResendVerificationCode(c.Request().Context(), user.ID)
if err != nil { if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to resend verification") log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to resend verification")
return err return err
@@ -276,7 +276,7 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
} }
code, user, err := h.authService.ForgotPassword(req.Email) code, user, err := h.authService.ForgotPassword(c.Request().Context(), req.Email)
if err != nil { if err != nil {
var appErr *apperrors.AppError var appErr *apperrors.AppError
if errors.As(err, &appErr) && appErr.Code == http.StatusTooManyRequests { if errors.As(err, &appErr) && appErr.Code == http.StatusTooManyRequests {
@@ -324,7 +324,7 @@ func (h *AuthHandler) VerifyResetCode(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
} }
resetToken, err := h.authService.VerifyResetCode(req.Email, req.Code) resetToken, err := h.authService.VerifyResetCode(c.Request().Context(), req.Email, req.Code)
if err != nil { if err != nil {
log.Debug().Err(err).Str("email", req.Email).Msg("Verify reset code failed") log.Debug().Err(err).Str("email", req.Email).Msg("Verify reset code failed")
return err return err
@@ -346,7 +346,7 @@ func (h *AuthHandler) ResetPassword(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
} }
err := h.authService.ResetPassword(req.ResetToken, req.NewPassword) err := h.authService.ResetPassword(c.Request().Context(), req.ResetToken, req.NewPassword)
if err != nil { if err != nil {
log.Debug().Err(err).Msg("Password reset failed") log.Debug().Err(err).Msg("Password reset failed")
return err return err
@@ -469,7 +469,7 @@ func (h *AuthHandler) RefreshToken(c echo.Context) error {
return apperrors.Unauthorized("error.not_authenticated") return apperrors.Unauthorized("error.not_authenticated")
} }
response, err := h.authService.RefreshToken(token, user.ID) response, err := h.authService.RefreshToken(c.Request().Context(), token, user.ID)
if err != nil { if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed") log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed")
return err return err
@@ -497,7 +497,7 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error {
return apperrors.BadRequest("error.invalid_request") return apperrors.BadRequest("error.invalid_request")
} }
fileURLs, err := h.authService.DeleteAccount(user.ID, req.Password, req.Confirmation) fileURLs, err := h.authService.DeleteAccount(c.Request().Context(), user.ID, req.Password, req.Confirmation)
if err != nil { if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Account deletion failed") log.Debug().Err(err).Uint("user_id", user.ID).Msg("Account deletion failed")
return err return err
+9 -9
View File
@@ -30,7 +30,7 @@ func (h *ContractorHandler) ListContractors(c echo.Context) error {
if err != nil { if err != nil {
return err return err
} }
response, err := h.contractorService.ListContractors(user.ID) response, err := h.contractorService.ListContractors(c.Request().Context(), user.ID)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -48,7 +48,7 @@ func (h *ContractorHandler) GetContractor(c echo.Context) error {
return apperrors.BadRequest("error.invalid_contractor_id") return apperrors.BadRequest("error.invalid_contractor_id")
} }
response, err := h.contractorService.GetContractor(uint(contractorID), user.ID) response, err := h.contractorService.GetContractor(c.Request().Context(), uint(contractorID), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -69,7 +69,7 @@ func (h *ContractorHandler) CreateContractor(c echo.Context) error {
return err return err
} }
response, err := h.contractorService.CreateContractor(&req, user.ID) response, err := h.contractorService.CreateContractor(c.Request().Context(), &req, user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -95,7 +95,7 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
return err return err
} }
response, err := h.contractorService.UpdateContractor(uint(contractorID), user.ID, &req) response, err := h.contractorService.UpdateContractor(c.Request().Context(), uint(contractorID), user.ID, &req)
if err != nil { if err != nil {
return err return err
} }
@@ -113,7 +113,7 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
return apperrors.BadRequest("error.invalid_contractor_id") return apperrors.BadRequest("error.invalid_contractor_id")
} }
err = h.contractorService.DeleteContractor(uint(contractorID), user.ID) err = h.contractorService.DeleteContractor(c.Request().Context(), uint(contractorID), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -131,7 +131,7 @@ func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
return apperrors.BadRequest("error.invalid_contractor_id") return apperrors.BadRequest("error.invalid_contractor_id")
} }
response, err := h.contractorService.ToggleFavorite(uint(contractorID), user.ID) response, err := h.contractorService.ToggleFavorite(c.Request().Context(), uint(contractorID), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -149,7 +149,7 @@ func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
return apperrors.BadRequest("error.invalid_contractor_id") return apperrors.BadRequest("error.invalid_contractor_id")
} }
response, err := h.contractorService.GetContractorTasks(uint(contractorID), user.ID) response, err := h.contractorService.GetContractorTasks(c.Request().Context(), uint(contractorID), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -167,7 +167,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
return apperrors.BadRequest("error.invalid_residence_id") return apperrors.BadRequest("error.invalid_residence_id")
} }
response, err := h.contractorService.ListContractorsByResidence(uint(residenceID), user.ID) response, err := h.contractorService.ListContractorsByResidence(c.Request().Context(), uint(residenceID), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -176,7 +176,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
// GetSpecialties handles GET /api/contractors/specialties/ // GetSpecialties handles GET /api/contractors/specialties/
func (h *ContractorHandler) GetSpecialties(c echo.Context) error { func (h *ContractorHandler) GetSpecialties(c echo.Context) error {
specialties, err := h.contractorService.GetSpecialties() specialties, err := h.contractorService.GetSpecialties(c.Request().Context())
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
+10 -10
View File
@@ -70,7 +70,7 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error {
} }
} }
response, err := h.documentService.ListDocuments(user.ID, filter) response, err := h.documentService.ListDocuments(c.Request().Context(), user.ID, filter)
if err != nil { if err != nil {
return err return err
} }
@@ -88,7 +88,7 @@ func (h *DocumentHandler) GetDocument(c echo.Context) error {
return apperrors.BadRequest("error.invalid_document_id") return apperrors.BadRequest("error.invalid_document_id")
} }
response, err := h.documentService.GetDocument(uint(documentID), user.ID) response, err := h.documentService.GetDocument(c.Request().Context(), uint(documentID), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -101,7 +101,7 @@ func (h *DocumentHandler) ListWarranties(c echo.Context) error {
if err != nil { if err != nil {
return err return err
} }
response, err := h.documentService.ListWarranties(user.ID) response, err := h.documentService.ListWarranties(c.Request().Context(), user.ID)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -222,7 +222,7 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
return err return err
} }
response, err := h.documentService.CreateDocument(&req, user.ID) response, err := h.documentService.CreateDocument(c.Request().Context(), &req, user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -248,7 +248,7 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
return err return err
} }
response, err := h.documentService.UpdateDocument(uint(documentID), user.ID, &req) response, err := h.documentService.UpdateDocument(c.Request().Context(), uint(documentID), user.ID, &req)
if err != nil { if err != nil {
return err return err
} }
@@ -266,7 +266,7 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
return apperrors.BadRequest("error.invalid_document_id") return apperrors.BadRequest("error.invalid_document_id")
} }
err = h.documentService.DeleteDocument(uint(documentID), user.ID) err = h.documentService.DeleteDocument(c.Request().Context(), uint(documentID), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -284,7 +284,7 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
return apperrors.BadRequest("error.invalid_document_id") return apperrors.BadRequest("error.invalid_document_id")
} }
response, err := h.documentService.ActivateDocument(uint(documentID), user.ID) response, err := h.documentService.ActivateDocument(c.Request().Context(), uint(documentID), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -302,7 +302,7 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
return apperrors.BadRequest("error.invalid_document_id") return apperrors.BadRequest("error.invalid_document_id")
} }
response, err := h.documentService.DeactivateDocument(uint(documentID), user.ID) response, err := h.documentService.DeactivateDocument(c.Request().Context(), uint(documentID), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -349,7 +349,7 @@ func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
caption := c.FormValue("caption") caption := c.FormValue("caption")
response, err := h.documentService.UploadDocumentImage(uint(documentID), user.ID, result.URL, caption) response, err := h.documentService.UploadDocumentImage(c.Request().Context(), uint(documentID), user.ID, result.URL, caption)
if err != nil { if err != nil {
return err return err
} }
@@ -372,7 +372,7 @@ func (h *DocumentHandler) DeleteDocumentImage(c echo.Context) error {
return apperrors.BadRequest("error.invalid_image_id") return apperrors.BadRequest("error.invalid_image_id")
} }
response, err := h.documentService.DeleteDocumentImage(uint(documentID), uint(imageID), user.ID) response, err := h.documentService.DeleteDocumentImage(c.Request().Context(), uint(documentID), uint(imageID), user.ID)
if err != nil { if err != nil {
return err return err
} }
+10 -10
View File
@@ -46,7 +46,7 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
} }
} }
notifications, err := h.notificationService.GetNotifications(user.ID, limit, offset) notifications, err := h.notificationService.GetNotifications(c.Request().Context(), user.ID, limit, offset)
if err != nil { if err != nil {
return err return err
} }
@@ -64,7 +64,7 @@ func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
return err return err
} }
count, err := h.notificationService.GetUnreadCount(user.ID) count, err := h.notificationService.GetUnreadCount(c.Request().Context(), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -84,7 +84,7 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
return apperrors.BadRequest("error.invalid_notification_id") return apperrors.BadRequest("error.invalid_notification_id")
} }
err = h.notificationService.MarkAsRead(uint(notificationID), user.ID) err = h.notificationService.MarkAsRead(c.Request().Context(), uint(notificationID), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -99,7 +99,7 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
return err return err
} }
err = h.notificationService.MarkAllAsRead(user.ID) err = h.notificationService.MarkAllAsRead(c.Request().Context(), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -114,7 +114,7 @@ func (h *NotificationHandler) GetPreferences(c echo.Context) error {
return err return err
} }
prefs, err := h.notificationService.GetPreferences(user.ID) prefs, err := h.notificationService.GetPreferences(c.Request().Context(), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -137,7 +137,7 @@ func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
return err return err
} }
prefs, err := h.notificationService.UpdatePreferences(user.ID, &req) prefs, err := h.notificationService.UpdatePreferences(c.Request().Context(), user.ID, &req)
if err != nil { if err != nil {
return err return err
} }
@@ -160,7 +160,7 @@ func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
return err return err
} }
device, err := h.notificationService.RegisterDevice(user.ID, &req) device, err := h.notificationService.RegisterDevice(c.Request().Context(), user.ID, &req)
if err != nil { if err != nil {
return err return err
} }
@@ -175,7 +175,7 @@ func (h *NotificationHandler) ListDevices(c echo.Context) error {
return err return err
} }
devices, err := h.notificationService.ListDevices(user.ID) devices, err := h.notificationService.ListDevices(c.Request().Context(), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -208,7 +208,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
return apperrors.BadRequest("error.invalid_platform") return apperrors.BadRequest("error.invalid_platform")
} }
err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID) err = h.notificationService.UnregisterDevice(c.Request().Context(), req.RegistrationID, req.Platform, user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -236,7 +236,7 @@ func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
return apperrors.BadRequest("error.invalid_platform") return apperrors.BadRequest("error.invalid_platform")
} }
err = h.notificationService.DeleteDevice(uint(deviceID), platform, user.ID) err = h.notificationService.DeleteDevice(c.Request().Context(), uint(deviceID), platform, user.ID)
if err != nil { if err != nil {
return err return err
} }
+1 -1
View File
@@ -106,7 +106,7 @@ func (h *StaticDataHandler) GetStaticData(c echo.Context) error {
return err return err
} }
contractorSpecialties, err := h.contractorService.GetSpecialties() contractorSpecialties, err := h.contractorService.GetSpecialties(c.Request().Context())
if err != nil { if err != nil {
return err return err
} }
+12 -12
View File
@@ -32,7 +32,7 @@ func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
return err return err
} }
subscription, err := h.subscriptionService.GetSubscription(user.ID) subscription, err := h.subscriptionService.GetSubscription(c.Request().Context(), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -47,7 +47,7 @@ func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error {
return err return err
} }
status, err := h.subscriptionService.GetSubscriptionStatus(user.ID) status, err := h.subscriptionService.GetSubscriptionStatus(c.Request().Context(), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -59,7 +59,7 @@ func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error {
func (h *SubscriptionHandler) GetUpgradeTrigger(c echo.Context) error { func (h *SubscriptionHandler) GetUpgradeTrigger(c echo.Context) error {
key := c.Param("key") key := c.Param("key")
trigger, err := h.subscriptionService.GetUpgradeTrigger(key) trigger, err := h.subscriptionService.GetUpgradeTrigger(c.Request().Context(), key)
if err != nil { if err != nil {
return err return err
} }
@@ -69,7 +69,7 @@ func (h *SubscriptionHandler) GetUpgradeTrigger(c echo.Context) error {
// GetAllUpgradeTriggers handles GET /api/subscription/upgrade-triggers/ // GetAllUpgradeTriggers handles GET /api/subscription/upgrade-triggers/
func (h *SubscriptionHandler) GetAllUpgradeTriggers(c echo.Context) error { func (h *SubscriptionHandler) GetAllUpgradeTriggers(c echo.Context) error {
triggers, err := h.subscriptionService.GetAllUpgradeTriggers() triggers, err := h.subscriptionService.GetAllUpgradeTriggers(c.Request().Context())
if err != nil { if err != nil {
return err return err
} }
@@ -79,7 +79,7 @@ func (h *SubscriptionHandler) GetAllUpgradeTriggers(c echo.Context) error {
// GetFeatureBenefits handles GET /api/subscription/features/ // GetFeatureBenefits handles GET /api/subscription/features/
func (h *SubscriptionHandler) GetFeatureBenefits(c echo.Context) error { func (h *SubscriptionHandler) GetFeatureBenefits(c echo.Context) error {
benefits, err := h.subscriptionService.GetFeatureBenefits() benefits, err := h.subscriptionService.GetFeatureBenefits(c.Request().Context())
if err != nil { if err != nil {
return err return err
} }
@@ -94,7 +94,7 @@ func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
return err return err
} }
promotions, err := h.subscriptionService.GetActivePromotions(user.ID) promotions, err := h.subscriptionService.GetActivePromotions(c.Request().Context(), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -125,12 +125,12 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
if req.TransactionID == "" && req.ReceiptData == "" { if req.TransactionID == "" && req.ReceiptData == "" {
return apperrors.BadRequest("error.receipt_data_required") return apperrors.BadRequest("error.receipt_data_required")
} }
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID) subscription, err = h.subscriptionService.ProcessApplePurchase(c.Request().Context(), user.ID, req.ReceiptData, req.TransactionID)
case "android": case "android":
if req.PurchaseToken == "" { if req.PurchaseToken == "" {
return apperrors.BadRequest("error.purchase_token_required") return apperrors.BadRequest("error.purchase_token_required")
} }
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID) subscription, err = h.subscriptionService.ProcessGooglePurchase(c.Request().Context(), user.ID, req.PurchaseToken, req.ProductID)
default: default:
return apperrors.BadRequest("error.invalid_platform") return apperrors.BadRequest("error.invalid_platform")
} }
@@ -152,7 +152,7 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
return err return err
} }
subscription, err := h.subscriptionService.CancelSubscription(user.ID) subscription, err := h.subscriptionService.CancelSubscription(c.Request().Context(), user.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -187,12 +187,12 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
if req.ReceiptData == "" && req.TransactionID == "" { if req.ReceiptData == "" && req.TransactionID == "" {
return apperrors.BadRequest("error.receipt_data_required") return apperrors.BadRequest("error.receipt_data_required")
} }
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID) subscription, err = h.subscriptionService.ProcessApplePurchase(c.Request().Context(), user.ID, req.ReceiptData, req.TransactionID)
case "android": case "android":
if req.PurchaseToken == "" { if req.PurchaseToken == "" {
return apperrors.BadRequest("error.purchase_token_required") return apperrors.BadRequest("error.purchase_token_required")
} }
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID) subscription, err = h.subscriptionService.ProcessGooglePurchase(c.Request().Context(), user.ID, req.PurchaseToken, req.ProductID)
default: default:
return apperrors.BadRequest("error.invalid_platform") return apperrors.BadRequest("error.invalid_platform")
} }
@@ -220,7 +220,7 @@ func (h *SubscriptionHandler) CreateCheckoutSession(c echo.Context) error {
} }
// Check if already Pro from another platform // Check if already Pro from another platform
alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(user.ID, "stripe") alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(c.Request().Context(), user.ID, "stripe")
if err != nil { if err != nil {
return err return err
} }
+1 -1
View File
@@ -42,7 +42,7 @@ func (h *TaskHandler) ListTasks(c echo.Context) error {
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" { if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
cachedTZ, _ := c.Get("user_timezone").(string) cachedTZ, _ := c.Get("user_timezone").(string)
if cachedTZ != tzHeader { if cachedTZ != tzHeader {
h.taskService.UpdateUserTimezone(user.ID, tzHeader) h.taskService.UpdateUserTimezone(c.Request().Context(), user.ID, tzHeader)
c.Set("user_timezone", tzHeader) c.Set("user_timezone", tzHeader)
} }
} }
+8 -7
View File
@@ -1,6 +1,7 @@
package services package services
import ( import (
"context"
"testing" "testing"
"time" "time"
@@ -74,7 +75,7 @@ func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
svc := newTestAuthService(db) svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID) resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token") assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token")
assert.Contains(t, resp.Message, "still valid") assert.Contains(t, resp.Message, "still valid")
@@ -87,7 +88,7 @@ func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
svc := newTestAuthService(db) svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID) resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.NotEqual(t, token.Key, resp.Token, "should return a new token") assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
assert.Contains(t, resp.Message, "refreshed") assert.Contains(t, resp.Message, "refreshed")
@@ -114,7 +115,7 @@ func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
svc := newTestAuthService(db) svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID) resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, resp) assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.token_expired") assert.Contains(t, err.Error(), "error.token_expired")
@@ -129,7 +130,7 @@ func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
svc := newTestAuthService(db) svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID) resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed") assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed")
} }
@@ -140,7 +141,7 @@ func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
svc := newTestAuthService(db) svc := newTestAuthService(db)
resp, err := svc.RefreshToken("nonexistent-token-key", user.ID) resp, err := svc.RefreshToken(context.Background(), "nonexistent-token-key", user.ID)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, resp) assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token") assert.Contains(t, err.Error(), "error.invalid_token")
@@ -154,7 +155,7 @@ func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
svc := newTestAuthService(db) svc := newTestAuthService(db)
// Try to refresh with a different user ID // Try to refresh with a different user ID
resp, err := svc.RefreshToken(token.Key, user.ID+999) resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID+999)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, resp) assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token") assert.Contains(t, err.Error(), "error.invalid_token")
@@ -167,7 +168,7 @@ func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
svc := newTestAuthService(db) svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID) resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed") assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed")
} }
+85 -85
View File
@@ -57,14 +57,14 @@ func (s *AuthService) SetNotificationRepository(notificationRepo *repositories.N
} }
// Login authenticates a user and returns a token // Login authenticates a user and returns a token
func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginResponse, error) { func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest) (*responses.LoginResponse, error) {
// Find user by username or email // Find user by username or email
identifier := req.Username identifier := req.Username
if identifier == "" { if identifier == "" {
identifier = req.Email identifier = req.Email
} }
user, err := s.userRepo.FindByUsernameOrEmail(identifier) user, err := s.userRepo.WithContext(ctx).FindByUsernameOrEmail(identifier)
if err != nil { if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) { if errors.Is(err, repositories.ErrUserNotFound) {
return nil, apperrors.Unauthorized("error.invalid_credentials") return nil, apperrors.Unauthorized("error.invalid_credentials")
@@ -83,13 +83,13 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
} }
// Get or create auth token // Get or create auth token
token, err := s.userRepo.GetOrCreateToken(user.ID) token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Update last login // Update last login
if err := s.userRepo.UpdateLastLogin(user.ID); err != nil { if err := s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID); err != nil {
// Log error but don't fail the login // Log error but don't fail the login
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login") log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login")
} }
@@ -103,9 +103,9 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
// Register creates a new user account. // Register creates a new user account.
// F-10: User creation, profile creation, notification preferences, and confirmation code // F-10: User creation, profile creation, notification preferences, and confirmation code
// are wrapped in a transaction for atomicity. // are wrapped in a transaction for atomicity.
func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) { func (s *AuthService) Register(ctx context.Context, req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
// Check if username exists // Check if username exists
exists, err := s.userRepo.ExistsByUsername(req.Username) exists, err := s.userRepo.WithContext(ctx).ExistsByUsername(req.Username)
if err != nil { if err != nil {
return nil, "", apperrors.Internal(err) return nil, "", apperrors.Internal(err)
} }
@@ -114,7 +114,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
} }
// Check if email exists // Check if email exists
exists, err = s.userRepo.ExistsByEmail(req.Email) exists, err = s.userRepo.WithContext(ctx).ExistsByEmail(req.Email)
if err != nil { if err != nil {
return nil, "", apperrors.Internal(err) return nil, "", apperrors.Internal(err)
} }
@@ -146,7 +146,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry) expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
// Wrap user creation + profile + notification preferences + confirmation code in a transaction // Wrap user creation + profile + notification preferences + confirmation code in a transaction
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error { txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error {
// Save user // Save user
if err := txRepo.Create(user); err != nil { if err := txRepo.Create(user); err != nil {
return err return err
@@ -159,7 +159,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
// Create notification preferences with all options enabled // Create notification preferences with all options enabled
if s.notificationRepo != nil { if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil { if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration") log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration")
} }
} }
@@ -176,7 +176,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
} }
// Create auth token (outside transaction since token generation is idempotent) // Create auth token (outside transaction since token generation is idempotent)
token, err := s.userRepo.GetOrCreateToken(user.ID) token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil { if err != nil {
return nil, "", apperrors.Internal(err) return nil, "", apperrors.Internal(err)
} }
@@ -192,7 +192,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
// - If token is expired (> expiryDays old), returns error (must re-login). // - If token is expired (> expiryDays old), returns error (must re-login).
// - If token is in the renewal window (> refreshDays old), generates a new token. // - If token is in the renewal window (> refreshDays old), generates a new token.
// - If token is still fresh (< refreshDays old), returns the existing token (no-op). // - If token is still fresh (< refreshDays old), returns the existing token (no-op).
func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) { func (s *AuthService) RefreshToken(ctx context.Context, tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) {
expiryDays := s.cfg.Security.TokenExpiryDays expiryDays := s.cfg.Security.TokenExpiryDays
if expiryDays <= 0 { if expiryDays <= 0 {
expiryDays = 90 expiryDays = 90
@@ -203,7 +203,7 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref
} }
// Look up the token // Look up the token
authToken, err := s.userRepo.FindTokenByKey(tokenKey) authToken, err := s.userRepo.WithContext(ctx).FindTokenByKey(tokenKey)
if err != nil { if err != nil {
return nil, apperrors.Unauthorized("error.invalid_token") return nil, apperrors.Unauthorized("error.invalid_token")
} }
@@ -232,12 +232,12 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref
// Token is in the renewal window — generate a new one // Token is in the renewal window — generate a new one
// Delete the old token // Delete the old token
if err := s.userRepo.DeleteToken(tokenKey); err != nil { if err := s.userRepo.WithContext(ctx).DeleteToken(tokenKey); err != nil {
log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh") log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh")
} }
// Create a new token // Create a new token
newToken, err := s.userRepo.CreateToken(userID) newToken, err := s.userRepo.WithContext(ctx).CreateToken(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -249,18 +249,18 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref
} }
// Logout invalidates a user's token // Logout invalidates a user's token
func (s *AuthService) Logout(token string) error { func (s *AuthService) Logout(ctx context.Context, token string) error {
return s.userRepo.DeleteToken(token) return s.userRepo.WithContext(ctx).DeleteToken(token)
} }
// GetCurrentUser returns the current authenticated user with profile // GetCurrentUser returns the current authenticated user with profile
func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserResponse, error) { func (s *AuthService) GetCurrentUser(ctx context.Context, userID uint) (*responses.CurrentUserResponse, error) {
user, err := s.userRepo.FindByIDWithProfile(userID) user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
authProvider, err := s.userRepo.FindAuthProvider(userID) authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID)
if err != nil { if err != nil {
// Log but don't fail - default to "email" // Log but don't fail - default to "email"
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider") log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider")
@@ -275,9 +275,9 @@ func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserRespons
// For email auth users, password verification is required. // For email auth users, password verification is required.
// For social auth users, confirmation string "DELETE" is required. // For social auth users, confirmation string "DELETE" is required.
// Returns a list of file URLs that need to be deleted from disk. // Returns a list of file URLs that need to be deleted from disk.
func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string) ([]string, error) { func (s *AuthService) DeleteAccount(ctx context.Context, userID uint, password, confirmation *string) ([]string, error) {
// Fetch user // Fetch user
user, err := s.userRepo.FindByID(userID) user, err := s.userRepo.WithContext(ctx).FindByID(userID)
if err != nil { if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) { if errors.Is(err, repositories.ErrUserNotFound) {
return nil, apperrors.NotFound("error.user_not_found") return nil, apperrors.NotFound("error.user_not_found")
@@ -286,7 +286,7 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string)
} }
// Determine auth provider // Determine auth provider
authProvider, err := s.userRepo.FindAuthProvider(userID) authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -308,7 +308,7 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string)
// Start transaction and cascade delete // Start transaction and cascade delete
var fileURLs []string var fileURLs []string
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error { txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error {
urls, err := txRepo.DeleteUserCascade(userID) urls, err := txRepo.DeleteUserCascade(userID)
if err != nil { if err != nil {
return err return err
@@ -324,15 +324,15 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string)
} }
// UpdateProfile updates a user's profile // UpdateProfile updates a user's profile
func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) { func (s *AuthService) UpdateProfile(ctx context.Context, userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) {
user, err := s.userRepo.FindByID(userID) user, err := s.userRepo.WithContext(ctx).FindByID(userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Check if new email is taken (if email is being changed) // Check if new email is taken (if email is being changed)
if req.Email != nil && *req.Email != user.Email { if req.Email != nil && *req.Email != user.Email {
exists, err := s.userRepo.ExistsByEmail(*req.Email) exists, err := s.userRepo.WithContext(ctx).ExistsByEmail(*req.Email)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -349,17 +349,17 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ
user.LastName = *req.LastName user.LastName = *req.LastName
} }
if err := s.userRepo.Update(user); err != nil { if err := s.userRepo.WithContext(ctx).Update(user); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Reload with profile // Reload with profile
user, err = s.userRepo.FindByIDWithProfile(userID) user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
authProvider, err := s.userRepo.FindAuthProvider(userID) authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID)
if err != nil { if err != nil {
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider") log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider")
authProvider = "email" authProvider = "email"
@@ -370,9 +370,9 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ
} }
// VerifyEmail verifies a user's email with a confirmation code // VerifyEmail verifies a user's email with a confirmation code
func (s *AuthService) VerifyEmail(userID uint, code string) error { func (s *AuthService) VerifyEmail(ctx context.Context, userID uint, code string) error {
// Get user profile // Get user profile
profile, err := s.userRepo.GetOrCreateProfile(userID) profile, err := s.userRepo.WithContext(ctx).GetOrCreateProfile(userID)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -384,14 +384,14 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
// Check for test code when DEBUG_FIXED_CODES is enabled // Check for test code when DEBUG_FIXED_CODES is enabled
if s.cfg.Server.DebugFixedCodes && code == "123456" { if s.cfg.Server.DebugFixedCodes && code == "123456" {
if err := s.userRepo.SetProfileVerified(userID, true); err != nil { if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
return nil return nil
} }
// Find and validate confirmation code // Find and validate confirmation code
confirmCode, err := s.userRepo.FindConfirmationCode(userID, code) confirmCode, err := s.userRepo.WithContext(ctx).FindConfirmationCode(userID, code)
if err != nil { if err != nil {
if errors.Is(err, repositories.ErrCodeNotFound) { if errors.Is(err, repositories.ErrCodeNotFound) {
return apperrors.BadRequest("error.invalid_verification_code") return apperrors.BadRequest("error.invalid_verification_code")
@@ -403,12 +403,12 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
} }
// Mark code as used // Mark code as used
if err := s.userRepo.MarkConfirmationCodeUsed(confirmCode.ID); err != nil { if err := s.userRepo.WithContext(ctx).MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
// Set profile as verified // Set profile as verified
if err := s.userRepo.SetProfileVerified(userID, true); err != nil { if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -416,9 +416,9 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
} }
// ResendVerificationCode creates and returns a new verification code // ResendVerificationCode creates and returns a new verification code
func (s *AuthService) ResendVerificationCode(userID uint) (string, error) { func (s *AuthService) ResendVerificationCode(ctx context.Context, userID uint) (string, error) {
// Get user profile // Get user profile
profile, err := s.userRepo.GetOrCreateProfile(userID) profile, err := s.userRepo.WithContext(ctx).GetOrCreateProfile(userID)
if err != nil { if err != nil {
return "", apperrors.Internal(err) return "", apperrors.Internal(err)
} }
@@ -437,7 +437,7 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
} }
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry) expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
if _, err := s.userRepo.CreateConfirmationCode(userID, code, expiresAt); err != nil { if _, err := s.userRepo.WithContext(ctx).CreateConfirmationCode(userID, code, expiresAt); err != nil {
return "", apperrors.Internal(err) return "", apperrors.Internal(err)
} }
@@ -445,9 +445,9 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
} }
// ForgotPassword initiates the password reset process // ForgotPassword initiates the password reset process
func (s *AuthService) ForgotPassword(email string) (string, *models.User, error) { func (s *AuthService) ForgotPassword(ctx context.Context, email string) (string, *models.User, error) {
// Find user by email // Find user by email
user, err := s.userRepo.FindByEmail(email) user, err := s.userRepo.WithContext(ctx).FindByEmail(email)
if err != nil { if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) { if errors.Is(err, repositories.ErrUserNotFound) {
// Don't reveal that the email doesn't exist // Don't reveal that the email doesn't exist
@@ -457,7 +457,7 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
} }
// Check rate limit // Check rate limit
count, err := s.userRepo.CountRecentPasswordResetRequests(user.ID) count, err := s.userRepo.WithContext(ctx).CountRecentPasswordResetRequests(user.ID)
if err != nil { if err != nil {
return "", nil, apperrors.Internal(err) return "", nil, apperrors.Internal(err)
} }
@@ -481,7 +481,7 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
return "", nil, apperrors.Internal(err) return "", nil, apperrors.Internal(err)
} }
if _, err := s.userRepo.CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil { if _, err := s.userRepo.WithContext(ctx).CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil {
return "", nil, apperrors.Internal(err) return "", nil, apperrors.Internal(err)
} }
@@ -489,9 +489,9 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
} }
// VerifyResetCode verifies a password reset code and returns a reset token // VerifyResetCode verifies a password reset code and returns a reset token
func (s *AuthService) VerifyResetCode(email, code string) (string, error) { func (s *AuthService) VerifyResetCode(ctx context.Context, email, code string) (string, error) {
// Find the reset code // Find the reset code
resetCode, user, err := s.userRepo.FindPasswordResetCodeByEmail(email) resetCode, user, err := s.userRepo.WithContext(ctx).FindPasswordResetCodeByEmail(email)
if err != nil { if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) || errors.Is(err, repositories.ErrCodeNotFound) { if errors.Is(err, repositories.ErrUserNotFound) || errors.Is(err, repositories.ErrCodeNotFound) {
return "", apperrors.BadRequest("error.invalid_verification_code") return "", apperrors.BadRequest("error.invalid_verification_code")
@@ -507,7 +507,7 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
// Verify the code // Verify the code
if !resetCode.CheckCode(code) { if !resetCode.CheckCode(code) {
// Increment attempts // Increment attempts
s.userRepo.IncrementResetCodeAttempts(resetCode.ID) s.userRepo.WithContext(ctx).IncrementResetCodeAttempts(resetCode.ID)
return "", apperrors.BadRequest("error.invalid_verification_code") return "", apperrors.BadRequest("error.invalid_verification_code")
} }
@@ -528,9 +528,9 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
} }
// ResetPassword resets the user's password using a reset token // ResetPassword resets the user's password using a reset token
func (s *AuthService) ResetPassword(resetToken, newPassword string) error { func (s *AuthService) ResetPassword(ctx context.Context, resetToken, newPassword string) error {
// Find the reset code by token // Find the reset code by token
resetCode, err := s.userRepo.FindPasswordResetCodeByToken(resetToken) resetCode, err := s.userRepo.WithContext(ctx).FindPasswordResetCodeByToken(resetToken)
if err != nil { if err != nil {
if errors.Is(err, repositories.ErrCodeNotFound) || errors.Is(err, repositories.ErrCodeExpired) { if errors.Is(err, repositories.ErrCodeNotFound) || errors.Is(err, repositories.ErrCodeExpired) {
return apperrors.BadRequest("error.invalid_reset_token") return apperrors.BadRequest("error.invalid_reset_token")
@@ -539,7 +539,7 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
} }
// Get the user // Get the user
user, err := s.userRepo.FindByID(resetCode.UserID) user, err := s.userRepo.WithContext(ctx).FindByID(resetCode.UserID)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -549,18 +549,18 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
if err := s.userRepo.Update(user); err != nil { if err := s.userRepo.WithContext(ctx).Update(user); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
// Mark reset code as used // Mark reset code as used
if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil { if err := s.userRepo.WithContext(ctx).MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
// Log error but don't fail // Log error but don't fail
log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used") log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used")
} }
// Invalidate all existing tokens for this user (security measure) // Invalidate all existing tokens for this user (security measure)
if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil { if err := s.userRepo.WithContext(ctx).DeleteTokenByUserID(user.ID); err != nil {
// Log error but don't fail // Log error but don't fail
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset") log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset")
} }
@@ -583,10 +583,10 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
} }
// 2. Check if this Apple ID is already linked to an account // 2. Check if this Apple ID is already linked to an account
existingAuth, err := s.userRepo.FindByAppleID(appleID) existingAuth, err := s.userRepo.WithContext(ctx).FindByAppleID(appleID)
if err == nil && existingAuth != nil { if err == nil && existingAuth != nil {
// User already linked with this Apple ID - log them in // User already linked with this Apple ID - log them in
user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID) user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(existingAuth.UserID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -596,13 +596,13 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
} }
// Get or create token // Get or create token
token, err := s.userRepo.GetOrCreateToken(user.ID) token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Update last login // Update last login
_ = s.userRepo.UpdateLastLogin(user.ID) _ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
return &responses.AppleSignInResponse{ return &responses.AppleSignInResponse{
Token: token.Key, Token: token.Key,
@@ -614,7 +614,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
// 3. Check if email matches an existing user (for account linking) // 3. Check if email matches an existing user (for account linking)
email := getEmailFromRequest(req.Email, claims.Email) email := getEmailFromRequest(req.Email, claims.Email)
if email != "" { if email != "" {
existingUser, err := s.userRepo.FindByEmail(email) existingUser, err := s.userRepo.WithContext(ctx).FindByEmail(email)
if err == nil && existingUser != nil { if err == nil && existingUser != nil {
// S-06: Log auto-linking of social account to existing user // S-06: Log auto-linking of social account to existing user
log.Warn(). log.Warn().
@@ -630,24 +630,24 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
Email: email, Email: email,
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(), IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
} }
if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil { if err := s.userRepo.WithContext(ctx).CreateAppleSocialAuth(appleAuthRecord); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Mark as verified since Apple verified the email // Mark as verified since Apple verified the email
_ = s.userRepo.SetProfileVerified(existingUser.ID, true) _ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true)
// Get or create token // Get or create token
token, err := s.userRepo.GetOrCreateToken(existingUser.ID) token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Update last login // Update last login
_ = s.userRepo.UpdateLastLogin(existingUser.ID) _ = s.userRepo.WithContext(ctx).UpdateLastLogin(existingUser.ID)
// B-08: Check error from FindByIDWithProfile // B-08: Check error from FindByIDWithProfile
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID) existingUser, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(existingUser.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -675,19 +675,19 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
randomPassword := generateResetToken() randomPassword := generateResetToken()
_ = user.SetPassword(randomPassword) _ = user.SetPassword(randomPassword)
if err := s.userRepo.Create(user); err != nil { if err := s.userRepo.WithContext(ctx).Create(user); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Create profile (already verified since Apple verified) // Create profile (already verified since Apple verified)
profile, _ := s.userRepo.GetOrCreateProfile(user.ID) profile, _ := s.userRepo.WithContext(ctx).GetOrCreateProfile(user.ID)
if profile != nil { if profile != nil {
_ = s.userRepo.SetProfileVerified(user.ID, true) _ = s.userRepo.WithContext(ctx).SetProfileVerified(user.ID, true)
} }
// Create notification preferences with all options enabled // Create notification preferences with all options enabled
if s.notificationRepo != nil { if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil { if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user") log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user")
} }
} }
@@ -699,18 +699,18 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
Email: getEmailOrDefault(email), Email: getEmailOrDefault(email),
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(), IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
} }
if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil { if err := s.userRepo.WithContext(ctx).CreateAppleSocialAuth(appleAuthRecord); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Create token // Create token
token, err := s.userRepo.GetOrCreateToken(user.ID) token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// B-08: Check error from FindByIDWithProfile // B-08: Check error from FindByIDWithProfile
user, err = s.userRepo.FindByIDWithProfile(user.ID) user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(user.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -736,10 +736,10 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
} }
// 2. Check if this Google ID is already linked to an account // 2. Check if this Google ID is already linked to an account
existingAuth, err := s.userRepo.FindByGoogleID(googleID) existingAuth, err := s.userRepo.WithContext(ctx).FindByGoogleID(googleID)
if err == nil && existingAuth != nil { if err == nil && existingAuth != nil {
// User already linked with this Google ID - log them in // User already linked with this Google ID - log them in
user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID) user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(existingAuth.UserID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -749,13 +749,13 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
} }
// Get or create token // Get or create token
token, err := s.userRepo.GetOrCreateToken(user.ID) token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Update last login // Update last login
_ = s.userRepo.UpdateLastLogin(user.ID) _ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
return &responses.GoogleSignInResponse{ return &responses.GoogleSignInResponse{
Token: token.Key, Token: token.Key,
@@ -767,7 +767,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
// 3. Check if email matches an existing user (for account linking) // 3. Check if email matches an existing user (for account linking)
email := tokenInfo.Email email := tokenInfo.Email
if email != "" { if email != "" {
existingUser, err := s.userRepo.FindByEmail(email) existingUser, err := s.userRepo.WithContext(ctx).FindByEmail(email)
if err == nil && existingUser != nil { if err == nil && existingUser != nil {
// S-06: Log auto-linking of social account to existing user // S-06: Log auto-linking of social account to existing user
log.Warn(). log.Warn().
@@ -784,26 +784,26 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
Name: tokenInfo.Name, Name: tokenInfo.Name,
Picture: tokenInfo.Picture, Picture: tokenInfo.Picture,
} }
if err := s.userRepo.CreateGoogleSocialAuth(googleAuthRecord); err != nil { if err := s.userRepo.WithContext(ctx).CreateGoogleSocialAuth(googleAuthRecord); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Mark as verified since Google verified the email // Mark as verified since Google verified the email
if tokenInfo.IsEmailVerified() { if tokenInfo.IsEmailVerified() {
_ = s.userRepo.SetProfileVerified(existingUser.ID, true) _ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true)
} }
// Get or create token // Get or create token
token, err := s.userRepo.GetOrCreateToken(existingUser.ID) token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Update last login // Update last login
_ = s.userRepo.UpdateLastLogin(existingUser.ID) _ = s.userRepo.WithContext(ctx).UpdateLastLogin(existingUser.ID)
// B-08: Check error from FindByIDWithProfile // B-08: Check error from FindByIDWithProfile
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID) existingUser, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(existingUser.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -831,19 +831,19 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
randomPassword := generateResetToken() randomPassword := generateResetToken()
_ = user.SetPassword(randomPassword) _ = user.SetPassword(randomPassword)
if err := s.userRepo.Create(user); err != nil { if err := s.userRepo.WithContext(ctx).Create(user); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Create profile (already verified if Google verified email) // Create profile (already verified if Google verified email)
profile, _ := s.userRepo.GetOrCreateProfile(user.ID) profile, _ := s.userRepo.WithContext(ctx).GetOrCreateProfile(user.ID)
if profile != nil && tokenInfo.IsEmailVerified() { if profile != nil && tokenInfo.IsEmailVerified() {
_ = s.userRepo.SetProfileVerified(user.ID, true) _ = s.userRepo.WithContext(ctx).SetProfileVerified(user.ID, true)
} }
// Create notification preferences with all options enabled // Create notification preferences with all options enabled
if s.notificationRepo != nil { if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil { if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user") log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user")
} }
} }
@@ -856,18 +856,18 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
Name: tokenInfo.Name, Name: tokenInfo.Name,
Picture: tokenInfo.Picture, Picture: tokenInfo.Picture,
} }
if err := s.userRepo.CreateGoogleSocialAuth(googleAuthRecord); err != nil { if err := s.userRepo.WithContext(ctx).CreateGoogleSocialAuth(googleAuthRecord); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Create token // Create token
token, err := s.userRepo.GetOrCreateToken(user.ID) token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// B-08: Check error from FindByIDWithProfile // B-08: Check error from FindByIDWithProfile
user, err = s.userRepo.FindByIDWithProfile(user.ID) user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(user.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
+54 -53
View File
@@ -1,6 +1,7 @@
package services package services
import ( import (
"context"
"net/http" "net/http"
"testing" "testing"
"time" "time"
@@ -53,7 +54,7 @@ func TestAuthService_Login(t *testing.T) {
Password: "Password123", Password: "Password123",
} }
resp, err := service.Login(req) resp, err := service.Login(context.Background(), req)
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, resp.Token) assert.NotEmpty(t, resp.Token)
assert.Equal(t, "testuser", resp.User.Username) assert.Equal(t, "testuser", resp.User.Username)
@@ -74,7 +75,7 @@ func TestAuthService_Login_ByEmail(t *testing.T) {
Password: "Password123", Password: "Password123",
} }
resp, err := service.Login(req) resp, err := service.Login(context.Background(), req)
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, resp.Token) assert.NotEmpty(t, resp.Token)
} }
@@ -94,7 +95,7 @@ func TestAuthService_Login_InvalidCredentials(t *testing.T) {
Password: "WrongPassword1", Password: "WrongPassword1",
} }
_, err := service.Login(req) _, err := service.Login(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
} }
@@ -111,7 +112,7 @@ func TestAuthService_Login_UserNotFound(t *testing.T) {
Password: "Password123", Password: "Password123",
} }
_, err := service.Login(req) _, err := service.Login(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
} }
@@ -133,7 +134,7 @@ func TestAuthService_Login_InactiveUser(t *testing.T) {
Password: "Password123", Password: "Password123",
} }
_, err := service.Login(req) _, err := service.Login(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive") testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive")
} }
@@ -148,7 +149,7 @@ func TestAuthService_Register(t *testing.T) {
Password: "Password123", Password: "Password123",
} }
resp, code, err := service.Register(req) resp, code, err := service.Register(context.Background(), req)
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, resp.Token) assert.NotEmpty(t, resp.Token)
assert.Equal(t, "newuser", resp.User.Username) assert.Equal(t, "newuser", resp.User.Username)
@@ -172,7 +173,7 @@ func TestAuthService_Register_DuplicateUsername(t *testing.T) {
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(req) _, _, err := service.Register(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken") testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken")
} }
@@ -193,7 +194,7 @@ func TestAuthService_Register_DuplicateEmail(t *testing.T) {
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(req) _, _, err := service.Register(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken") testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken")
} }
@@ -211,7 +212,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
// Create profile // Create profile
userRepo.GetOrCreateProfile(user.ID) userRepo.GetOrCreateProfile(user.ID)
resp, err := service.GetCurrentUser(user.ID) resp, err := service.GetCurrentUser(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "testuser", resp.Username) assert.Equal(t, "testuser", resp.Username)
assert.Equal(t, "test@test.com", resp.Email) assert.Equal(t, "test@test.com", resp.Email)
@@ -238,7 +239,7 @@ func TestAuthService_UpdateProfile(t *testing.T) {
LastName: &newLast, LastName: &newLast,
} }
resp, err := service.UpdateProfile(user.ID, req) resp, err := service.UpdateProfile(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "John", resp.FirstName) assert.Equal(t, "John", resp.FirstName)
assert.Equal(t, "Doe", resp.LastName) assert.Equal(t, "Doe", resp.LastName)
@@ -261,7 +262,7 @@ func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
Email: &takenEmail, Email: &takenEmail,
} }
_, err := service.UpdateProfile(user2.ID, req) _, err := service.UpdateProfile(context.Background(), user2.ID, req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_already_taken") testutil.AssertAppError(t, err, http.StatusConflict, "error.email_already_taken")
} }
@@ -282,7 +283,7 @@ func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
} }
// Same email should not trigger duplicate error // Same email should not trigger duplicate error
resp, err := service.UpdateProfile(user.ID, req) resp, err := service.UpdateProfile(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "test@test.com", resp.Email) assert.Equal(t, "test@test.com", resp.Email)
} }
@@ -298,7 +299,7 @@ func TestAuthService_VerifyEmail(t *testing.T) {
Email: "new@test.com", Email: "new@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(req) _, _, err := service.Register(context.Background(), req)
require.NoError(t, err) require.NoError(t, err)
// Get the user ID // Get the user ID
@@ -306,11 +307,11 @@ func TestAuthService_VerifyEmail(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify with the debug code // Verify with the debug code
err = service.VerifyEmail(user.ID, "123456") err = service.VerifyEmail(context.Background(), user.ID, "123456")
require.NoError(t, err) require.NoError(t, err)
// Verify again — should get already verified error // Verify again — should get already verified error
err = service.VerifyEmail(user.ID, "123456") err = service.VerifyEmail(context.Background(), user.ID, "123456")
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified") testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
} }
@@ -323,7 +324,7 @@ func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
Email: "new@test.com", Email: "new@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(req) _, _, err := service.Register(context.Background(), req)
require.NoError(t, err) require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com") user, err := service.userRepo.FindByEmail("new@test.com")
@@ -331,7 +332,7 @@ func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
// Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup, // Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup,
// but a wrong code should use the normal path // but a wrong code should use the normal path
err = service.VerifyEmail(user.ID, "000000") err = service.VerifyEmail(context.Background(), user.ID, "000000")
assert.Error(t, err) assert.Error(t, err)
} }
@@ -346,13 +347,13 @@ func TestAuthService_ResendVerificationCode(t *testing.T) {
Email: "new@test.com", Email: "new@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(req) _, _, err := service.Register(context.Background(), req)
require.NoError(t, err) require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com") user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err) require.NoError(t, err)
code, err := service.ResendVerificationCode(user.ID) code, err := service.ResendVerificationCode(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "123456", code) // DebugFixedCodes assert.Equal(t, "123456", code) // DebugFixedCodes
} }
@@ -366,16 +367,16 @@ func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) {
Email: "new@test.com", Email: "new@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(req) _, _, err := service.Register(context.Background(), req)
require.NoError(t, err) require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com") user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err) require.NoError(t, err)
err = service.VerifyEmail(user.ID, "123456") err = service.VerifyEmail(context.Background(), user.ID, "123456")
require.NoError(t, err) require.NoError(t, err)
_, err = service.ResendVerificationCode(user.ID) _, err = service.ResendVerificationCode(context.Background(), user.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified") testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
} }
@@ -390,10 +391,10 @@ func TestAuthService_ForgotPassword(t *testing.T) {
Email: "test@test.com", Email: "test@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(registerReq) _, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err) require.NoError(t, err)
code, user, err := service.ForgotPassword("test@test.com") code, user, err := service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "123456", code) // DebugFixedCodes assert.Equal(t, "123456", code) // DebugFixedCodes
assert.NotNil(t, user) assert.NotNil(t, user)
@@ -404,7 +405,7 @@ func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) {
service, _ := setupAuthService(t) service, _ := setupAuthService(t)
// Should not reveal that email doesn't exist // Should not reveal that email doesn't exist
code, user, err := service.ForgotPassword("nonexistent@test.com") code, user, err := service.ForgotPassword(context.Background(), "nonexistent@test.com")
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, code) assert.Empty(t, code)
assert.Nil(t, user) assert.Nil(t, user)
@@ -421,20 +422,20 @@ func TestAuthService_ResetPassword(t *testing.T) {
Email: "test@test.com", Email: "test@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(registerReq) _, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err) require.NoError(t, err)
// Forgot password // Forgot password
_, _, err = service.ForgotPassword("test@test.com") _, _, err = service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err) require.NoError(t, err)
// Verify reset code to get the token // Verify reset code to get the token
resetToken, err := service.VerifyResetCode("test@test.com", "123456") resetToken, err := service.VerifyResetCode(context.Background(), "test@test.com", "123456")
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, resetToken) assert.NotEmpty(t, resetToken)
// Reset password // Reset password
err = service.ResetPassword(resetToken, "NewPassword123") err = service.ResetPassword(context.Background(), resetToken, "NewPassword123")
require.NoError(t, err) require.NoError(t, err)
// Login with new password // Login with new password
@@ -442,7 +443,7 @@ func TestAuthService_ResetPassword(t *testing.T) {
Username: "testuser", Username: "testuser",
Password: "NewPassword123", Password: "NewPassword123",
} }
loginResp, err := service.Login(loginReq) loginResp, err := service.Login(context.Background(), loginReq)
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, loginResp.Token) assert.NotEmpty(t, loginResp.Token)
} }
@@ -450,7 +451,7 @@ func TestAuthService_ResetPassword(t *testing.T) {
func TestAuthService_ResetPassword_InvalidToken(t *testing.T) { func TestAuthService_ResetPassword_InvalidToken(t *testing.T) {
service, _ := setupAuthService(t) service, _ := setupAuthService(t)
err := service.ResetPassword("invalid-token", "NewPassword123") err := service.ResetPassword(context.Background(), "invalid-token", "NewPassword123")
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token") testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token")
} }
@@ -471,15 +472,15 @@ func TestAuthService_Logout(t *testing.T) {
Username: "testuser", Username: "testuser",
Password: "Password123", Password: "Password123",
} }
loginResp, err := service.Login(loginReq) loginResp, err := service.Login(context.Background(), loginReq)
require.NoError(t, err) require.NoError(t, err)
// Logout // Logout
err = service.Logout(loginResp.Token) err = service.Logout(context.Background(), loginResp.Token)
require.NoError(t, err) require.NoError(t, err)
// Token should be deleted — refreshing should fail // Token should be deleted — refreshing should fail
_, err = service.RefreshToken(loginResp.Token, user.ID) _, err = service.RefreshToken(context.Background(), loginResp.Token, user.ID)
assert.Error(t, err) assert.Error(t, err)
} }
@@ -494,14 +495,14 @@ func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) {
Email: "test@test.com", Email: "test@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(registerReq) _, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err) require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com") user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err) require.NoError(t, err)
password := "Password123" password := "Password123"
_, err = service.DeleteAccount(user.ID, &password, nil) _, err = service.DeleteAccount(context.Background(), user.ID, &password, nil)
require.NoError(t, err) require.NoError(t, err)
} }
@@ -513,14 +514,14 @@ func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) {
Email: "test@test.com", Email: "test@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(registerReq) _, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err) require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com") user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err) require.NoError(t, err)
wrongPassword := "WrongPassword1" wrongPassword := "WrongPassword1"
_, err = service.DeleteAccount(user.ID, &wrongPassword, nil) _, err = service.DeleteAccount(context.Background(), user.ID, &wrongPassword, nil)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
} }
@@ -532,13 +533,13 @@ func TestAuthService_DeleteAccount_NoPassword(t *testing.T) {
Email: "test@test.com", Email: "test@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(registerReq) _, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err) require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com") user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err) require.NoError(t, err)
_, err = service.DeleteAccount(user.ID, nil, nil) _, err = service.DeleteAccount(context.Background(), user.ID, nil, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required") testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
} }
@@ -546,7 +547,7 @@ func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
service, _ := setupAuthService(t) service, _ := setupAuthService(t)
password := "Password123" password := "Password123"
_, err := service.DeleteAccount(99999, &password, nil) _, err := service.DeleteAccount(context.Background(), 99999, &password, nil)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
} }
@@ -658,7 +659,7 @@ func TestAuthService_Login_EmptyPassword(t *testing.T) {
Password: "", Password: "",
} }
_, err := service.Login(req) _, err := service.Login(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
} }
@@ -672,17 +673,17 @@ func TestAuthService_ForgotPassword_RateLimit(t *testing.T) {
Email: "test@test.com", Email: "test@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(registerReq) _, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err) require.NoError(t, err)
// Make max allowed reset requests (3 based on setup) // Make max allowed reset requests (3 based on setup)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
_, _, err := service.ForgotPassword("test@test.com") _, _, err := service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err) require.NoError(t, err)
} }
// The 4th should be rate limited // The 4th should be rate limited
_, _, err = service.ForgotPassword("test@test.com") _, _, err = service.ForgotPassword(context.Background(), "test@test.com")
assert.Error(t, err) assert.Error(t, err)
} }
@@ -696,14 +697,14 @@ func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
Email: "test@test.com", Email: "test@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(registerReq) _, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err) require.NoError(t, err)
_, _, err = service.ForgotPassword("test@test.com") _, _, err = service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err) require.NoError(t, err)
// Wrong code but with debug mode, "123456" works, "000000" should fail // Wrong code but with debug mode, "123456" works, "000000" should fail
_, err = service.VerifyResetCode("test@test.com", "000000") _, err = service.VerifyResetCode(context.Background(), "test@test.com", "000000")
assert.Error(t, err) assert.Error(t, err)
} }
@@ -712,7 +713,7 @@ func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) { func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) {
service, _ := setupAuthService(t) service, _ := setupAuthService(t)
_, err := service.VerifyResetCode("nonexistent@test.com", "123456") _, err := service.VerifyResetCode(context.Background(), "nonexistent@test.com", "123456")
assert.Error(t, err) assert.Error(t, err)
} }
@@ -734,7 +735,7 @@ func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
Email: &newEmail, Email: &newEmail,
} }
resp, err := service.UpdateProfile(user.ID, req) resp, err := service.UpdateProfile(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "newemail@test.com", resp.Email) assert.Equal(t, "newemail@test.com", resp.Email)
} }
@@ -749,14 +750,14 @@ func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) {
Email: "test@test.com", Email: "test@test.com",
Password: "Password123", Password: "Password123",
} }
_, _, err := service.Register(registerReq) _, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err) require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com") user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err) require.NoError(t, err)
emptyPw := "" emptyPw := ""
_, err = service.DeleteAccount(user.ID, &emptyPw, nil) _, err = service.DeleteAccount(context.Background(), user.ID, &emptyPw, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required") testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
} }
@@ -789,7 +790,7 @@ func TestAuthService_Register_CreatesProfile(t *testing.T) {
LastName: "Doe", LastName: "Doe",
} }
resp, _, err := service.Register(req) resp, _, err := service.Register(context.Background(), req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "profileuser", resp.User.Username) assert.Equal(t, "profileuser", resp.User.Username)
+39 -38
View File
@@ -1,6 +1,7 @@
package services package services
import ( import (
"context"
"errors" "errors"
"gorm.io/gorm" "gorm.io/gorm"
@@ -33,8 +34,8 @@ func NewContractorService(contractorRepo *repositories.ContractorRepository, res
} }
// GetContractor gets a contractor by ID with access check // GetContractor gets a contractor by ID with access check
func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses.ContractorResponse, error) { func (s *ContractorService) GetContractor(ctx context.Context, contractorID, userID uint) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID) contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.contractor_not_found") return nil, apperrors.NotFound("error.contractor_not_found")
@@ -43,7 +44,7 @@ func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses
} }
// Check access // Check access
if !s.hasContractorAccess(contractor, userID) { if !s.hasContractorAccess(ctx, contractor, userID) {
return nil, apperrors.Forbidden("error.contractor_access_denied") return nil, apperrors.Forbidden("error.contractor_access_denied")
} }
@@ -55,14 +56,14 @@ func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses
// Access rules: // Access rules:
// - If contractor has no residence: only the creator has access // - If contractor has no residence: only the creator has access
// - If contractor has a residence: all users with access to that residence // - If contractor has a residence: all users with access to that residence
func (s *ContractorService) hasContractorAccess(contractor *models.Contractor, userID uint) bool { func (s *ContractorService) hasContractorAccess(ctx context.Context, contractor *models.Contractor, userID uint) bool {
if contractor.ResidenceID == nil { if contractor.ResidenceID == nil {
// Personal contractor - only creator has access // Personal contractor - only creator has access
return contractor.CreatedByID == userID return contractor.CreatedByID == userID
} }
// Residence contractor - check residence access // Residence contractor - check residence access
hasAccess, err := s.residenceRepo.HasAccess(*contractor.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*contractor.ResidenceID, userID)
if err != nil { if err != nil {
return false return false
} }
@@ -70,15 +71,15 @@ func (s *ContractorService) hasContractorAccess(contractor *models.Contractor, u
} }
// ListContractors lists all contractors accessible to a user // ListContractors lists all contractors accessible to a user
func (s *ContractorService) ListContractors(userID uint) ([]responses.ContractorResponse, error) { func (s *ContractorService) ListContractors(ctx context.Context, userID uint) ([]responses.ContractorResponse, error) {
// Get residence IDs (lightweight - no preloads) // Get residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID) residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// FindByUser now handles both personal and residence contractors // FindByUser now handles both personal and residence contractors
contractors, err := s.contractorRepo.FindByUser(userID, residenceIDs) contractors, err := s.contractorRepo.WithContext(ctx).FindByUser(userID, residenceIDs)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -87,10 +88,10 @@ func (s *ContractorService) ListContractors(userID uint) ([]responses.Contractor
} }
// CreateContractor creates a new contractor // CreateContractor creates a new contractor
func (s *ContractorService) CreateContractor(req *requests.CreateContractorRequest, userID uint) (*responses.ContractorResponse, error) { func (s *ContractorService) CreateContractor(ctx context.Context, req *requests.CreateContractorRequest, userID uint) (*responses.ContractorResponse, error) {
// If residence is provided, check access // If residence is provided, check access
if req.ResidenceID != nil { if req.ResidenceID != nil {
hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*req.ResidenceID, userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -122,19 +123,19 @@ func (s *ContractorService) CreateContractor(req *requests.CreateContractorReque
IsActive: true, IsActive: true,
} }
if err := s.contractorRepo.Create(contractor); err != nil { if err := s.contractorRepo.WithContext(ctx).Create(contractor); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Set specialties if provided // Set specialties if provided
if len(req.SpecialtyIDs) > 0 { if len(req.SpecialtyIDs) > 0 {
if err := s.contractorRepo.SetSpecialties(contractor.ID, req.SpecialtyIDs); err != nil { if err := s.contractorRepo.WithContext(ctx).SetSpecialties(contractor.ID, req.SpecialtyIDs); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
} }
// Reload with relations // Reload with relations
contractor, reloadErr := s.contractorRepo.FindByID(contractor.ID) contractor, reloadErr := s.contractorRepo.WithContext(ctx).FindByID(contractor.ID)
if reloadErr != nil { if reloadErr != nil {
return nil, apperrors.Internal(reloadErr) return nil, apperrors.Internal(reloadErr)
} }
@@ -144,8 +145,8 @@ func (s *ContractorService) CreateContractor(req *requests.CreateContractorReque
} }
// UpdateContractor updates a contractor // UpdateContractor updates a contractor
func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *requests.UpdateContractorRequest) (*responses.ContractorResponse, error) { func (s *ContractorService) UpdateContractor(ctx context.Context, contractorID, userID uint, req *requests.UpdateContractorRequest) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID) contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.contractor_not_found") return nil, apperrors.NotFound("error.contractor_not_found")
@@ -154,7 +155,7 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
} }
// Check access // Check access
if !s.hasContractorAccess(contractor, userID) { if !s.hasContractorAccess(ctx, contractor, userID) {
return nil, apperrors.Forbidden("error.contractor_access_denied") return nil, apperrors.Forbidden("error.contractor_access_denied")
} }
@@ -198,7 +199,7 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
// If residence_id is provided, verify the user has access to the NEW residence. // If residence_id is provided, verify the user has access to the NEW residence.
// This prevents an attacker from reassigning a contractor to someone else's residence. // This prevents an attacker from reassigning a contractor to someone else's residence.
if req.ResidenceID != nil { if req.ResidenceID != nil {
hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*req.ResidenceID, userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -211,19 +212,19 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
// removed the residence association - contractor becomes personal // removed the residence association - contractor becomes personal
contractor.ResidenceID = req.ResidenceID contractor.ResidenceID = req.ResidenceID
if err := s.contractorRepo.Update(contractor); err != nil { if err := s.contractorRepo.WithContext(ctx).Update(contractor); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Update specialties if provided // Update specialties if provided
if req.SpecialtyIDs != nil { if req.SpecialtyIDs != nil {
if err := s.contractorRepo.SetSpecialties(contractorID, req.SpecialtyIDs); err != nil { if err := s.contractorRepo.WithContext(ctx).SetSpecialties(contractorID, req.SpecialtyIDs); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
} }
// Reload // Reload
contractor, err = s.contractorRepo.FindByID(contractorID) contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -233,8 +234,8 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
} }
// DeleteContractor soft-deletes a contractor // DeleteContractor soft-deletes a contractor
func (s *ContractorService) DeleteContractor(contractorID, userID uint) error { func (s *ContractorService) DeleteContractor(ctx context.Context, contractorID, userID uint) error {
contractor, err := s.contractorRepo.FindByID(contractorID) contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.contractor_not_found") return apperrors.NotFound("error.contractor_not_found")
@@ -243,11 +244,11 @@ func (s *ContractorService) DeleteContractor(contractorID, userID uint) error {
} }
// Check access // Check access
if !s.hasContractorAccess(contractor, userID) { if !s.hasContractorAccess(ctx, contractor, userID) {
return apperrors.Forbidden("error.contractor_access_denied") return apperrors.Forbidden("error.contractor_access_denied")
} }
if err := s.contractorRepo.Delete(contractorID); err != nil { if err := s.contractorRepo.WithContext(ctx).Delete(contractorID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -255,8 +256,8 @@ func (s *ContractorService) DeleteContractor(contractorID, userID uint) error {
} }
// ToggleFavorite toggles the favorite status of a contractor and returns the updated contractor // ToggleFavorite toggles the favorite status of a contractor and returns the updated contractor
func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*responses.ContractorResponse, error) { func (s *ContractorService) ToggleFavorite(ctx context.Context, contractorID, userID uint) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID) contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.contractor_not_found") return nil, apperrors.NotFound("error.contractor_not_found")
@@ -265,17 +266,17 @@ func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*response
} }
// Check access // Check access
if !s.hasContractorAccess(contractor, userID) { if !s.hasContractorAccess(ctx, contractor, userID) {
return nil, apperrors.Forbidden("error.contractor_access_denied") return nil, apperrors.Forbidden("error.contractor_access_denied")
} }
_, err = s.contractorRepo.ToggleFavorite(contractorID) _, err = s.contractorRepo.WithContext(ctx).ToggleFavorite(contractorID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Re-fetch the contractor to get the updated state with all relations // Re-fetch the contractor to get the updated state with all relations
contractor, err = s.contractorRepo.FindByID(contractorID) contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -285,8 +286,8 @@ func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*response
} }
// GetContractorTasks gets all tasks for a contractor // GetContractorTasks gets all tasks for a contractor
func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]responses.TaskResponse, error) { func (s *ContractorService) GetContractorTasks(ctx context.Context, contractorID, userID uint) ([]responses.TaskResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID) contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.contractor_not_found") return nil, apperrors.NotFound("error.contractor_not_found")
@@ -295,11 +296,11 @@ func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]res
} }
// Check access // Check access
if !s.hasContractorAccess(contractor, userID) { if !s.hasContractorAccess(ctx, contractor, userID) {
return nil, apperrors.Forbidden("error.contractor_access_denied") return nil, apperrors.Forbidden("error.contractor_access_denied")
} }
tasks, err := s.contractorRepo.GetTasksForContractor(contractorID) tasks, err := s.contractorRepo.WithContext(ctx).GetTasksForContractor(contractorID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -308,9 +309,9 @@ func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]res
} }
// ListContractorsByResidence lists all contractors for a specific residence // ListContractorsByResidence lists all contractors for a specific residence
func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint) ([]responses.ContractorResponse, error) { func (s *ContractorService) ListContractorsByResidence(ctx context.Context, residenceID, userID uint) ([]responses.ContractorResponse, error) {
// Check user has access to the residence // Check user has access to the residence
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(residenceID, userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -318,7 +319,7 @@ func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint)
return nil, apperrors.Forbidden("error.residence_access_denied") return nil, apperrors.Forbidden("error.residence_access_denied")
} }
contractors, err := s.contractorRepo.FindByResidence(residenceID) contractors, err := s.contractorRepo.WithContext(ctx).FindByResidence(residenceID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -327,8 +328,8 @@ func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint)
} }
// GetSpecialties returns all contractor specialties // GetSpecialties returns all contractor specialties
func (s *ContractorService) GetSpecialties() ([]responses.ContractorSpecialtyResponse, error) { func (s *ContractorService) GetSpecialties(ctx context.Context) ([]responses.ContractorSpecialtyResponse, error) {
specialties, err := s.contractorRepo.GetAllSpecialties() specialties, err := s.contractorRepo.WithContext(ctx).GetAllSpecialties()
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
+38 -37
View File
@@ -1,6 +1,7 @@
package services package services
import ( import (
"context"
"net/http" "net/http"
"testing" "testing"
@@ -41,7 +42,7 @@ func TestContractorService_CreateContractor(t *testing.T) {
Email: "bob@plumbing.com", Email: "bob@plumbing.com",
} }
resp, err := service.CreateContractor(req, user.ID) resp, err := service.CreateContractor(context.Background(), req, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.Equal(t, "Bob's Plumbing", resp.Name) assert.Equal(t, "Bob's Plumbing", resp.Name)
@@ -63,7 +64,7 @@ func TestContractorService_CreateContractor_Personal(t *testing.T) {
Name: "Personal Handyman", Name: "Personal Handyman",
} }
resp, err := service.CreateContractor(req, user.ID) resp, err := service.CreateContractor(context.Background(), req, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Personal Handyman", resp.Name) assert.Equal(t, "Personal Handyman", resp.Name)
} }
@@ -84,7 +85,7 @@ func TestContractorService_CreateContractor_AccessDenied(t *testing.T) {
Name: "Unauthorized Contractor", Name: "Unauthorized Contractor",
} }
_, err := service.CreateContractor(req, other.ID) _, err := service.CreateContractor(context.Background(), req, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
} }
@@ -105,7 +106,7 @@ func TestContractorService_CreateContractor_WithFavorite(t *testing.T) {
IsFavorite: &isFav, IsFavorite: &isFav,
} }
resp, err := service.CreateContractor(req, user.ID) resp, err := service.CreateContractor(context.Background(), req, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, resp.IsFavorite) assert.True(t, resp.IsFavorite)
} }
@@ -123,7 +124,7 @@ func TestContractorService_GetContractor(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor") contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
resp, err := service.GetContractor(contractor.ID, user.ID) resp, err := service.GetContractor(context.Background(), contractor.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, contractor.ID, resp.ID) assert.Equal(t, contractor.ID, resp.ID)
assert.Equal(t, "Test Contractor", resp.Name) assert.Equal(t, "Test Contractor", resp.Name)
@@ -138,7 +139,7 @@ func TestContractorService_GetContractor_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.GetContractor(9999, user.ID) _, err := service.GetContractor(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
} }
@@ -154,7 +155,7 @@ func TestContractorService_GetContractor_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor") contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
_, err := service.GetContractor(contractor.ID, other.ID) _, err := service.GetContractor(context.Background(), contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
} }
@@ -171,7 +172,7 @@ func TestContractorService_GetContractor_SharedUserHasAccess(t *testing.T) {
residenceRepo.AddUser(residence.ID, shared.ID) residenceRepo.AddUser(residence.ID, shared.ID)
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Shared Contractor") contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Shared Contractor")
resp, err := service.GetContractor(contractor.ID, shared.ID) resp, err := service.GetContractor(context.Background(), contractor.ID, shared.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Shared Contractor", resp.Name) assert.Equal(t, "Shared Contractor", resp.Name)
} }
@@ -190,7 +191,7 @@ func TestContractorService_ListContractors(t *testing.T) {
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 1") testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 1")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 2") testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 2")
resp, err := service.ListContractors(user.ID) resp, err := service.ListContractors(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, resp, 2) assert.Len(t, resp, 2)
} }
@@ -208,11 +209,11 @@ func TestContractorService_DeleteContractor(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "To Delete") contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "To Delete")
err := service.DeleteContractor(contractor.ID, user.ID) err := service.DeleteContractor(context.Background(), contractor.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
// Should not be found after deletion // Should not be found after deletion
_, err = service.GetContractor(contractor.ID, user.ID) _, err = service.GetContractor(context.Background(), contractor.ID, user.ID)
assert.Error(t, err) assert.Error(t, err)
} }
@@ -225,7 +226,7 @@ func TestContractorService_DeleteContractor_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.DeleteContractor(9999, user.ID) err := service.DeleteContractor(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
} }
@@ -241,7 +242,7 @@ func TestContractorService_DeleteContractor_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor") contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
err := service.DeleteContractor(contractor.ID, other.ID) err := service.DeleteContractor(context.Background(), contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
} }
@@ -259,17 +260,17 @@ func TestContractorService_ToggleFavorite(t *testing.T) {
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor") contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
// Initially not favorite // Initially not favorite
resp, err := service.GetContractor(contractor.ID, user.ID) resp, err := service.GetContractor(context.Background(), contractor.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, resp.IsFavorite) assert.False(t, resp.IsFavorite)
// Toggle to favorite // Toggle to favorite
resp, err = service.ToggleFavorite(contractor.ID, user.ID) resp, err = service.ToggleFavorite(context.Background(), contractor.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, resp.IsFavorite) assert.True(t, resp.IsFavorite)
// Toggle back // Toggle back
resp, err = service.ToggleFavorite(contractor.ID, user.ID) resp, err = service.ToggleFavorite(context.Background(), contractor.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, resp.IsFavorite) assert.False(t, resp.IsFavorite)
} }
@@ -283,7 +284,7 @@ func TestContractorService_ToggleFavorite_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.ToggleFavorite(9999, user.ID) _, err := service.ToggleFavorite(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
} }
@@ -299,7 +300,7 @@ func TestContractorService_ToggleFavorite_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor") contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
_, err := service.ToggleFavorite(contractor.ID, other.ID) _, err := service.ToggleFavorite(context.Background(), contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
} }
@@ -317,7 +318,7 @@ func TestContractorService_ListContractorsByResidence(t *testing.T) {
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor A") testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor A")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor B") testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor B")
resp, err := service.ListContractorsByResidence(residence.ID, user.ID) resp, err := service.ListContractorsByResidence(context.Background(), residence.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, resp, 2) assert.Len(t, resp, 2)
} }
@@ -333,7 +334,7 @@ func TestContractorService_ListContractorsByResidence_AccessDenied(t *testing.T)
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123") other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
_, err := service.ListContractorsByResidence(residence.ID, other.ID) _, err := service.ListContractorsByResidence(context.Background(), residence.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
} }
@@ -348,7 +349,7 @@ func TestContractorService_GetContractorTasks_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.GetContractorTasks(9999, user.ID) _, err := service.GetContractorTasks(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
} }
@@ -364,7 +365,7 @@ func TestContractorService_GetContractorTasks_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor") contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
_, err := service.GetContractorTasks(contractor.ID, other.ID) _, err := service.GetContractorTasks(context.Background(), contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
} }
@@ -379,7 +380,7 @@ func TestContractorService_GetContractorTasks_Empty(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor") contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
resp, err := service.GetContractorTasks(contractor.ID, user.ID) resp, err := service.GetContractorTasks(context.Background(), contractor.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, resp) assert.Empty(t, resp)
} }
@@ -393,7 +394,7 @@ func TestContractorService_GetSpecialties(t *testing.T) {
residenceRepo := repositories.NewResidenceRepository(db) residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo) service := NewContractorService(contractorRepo, residenceRepo)
resp, err := service.GetSpecialties() resp, err := service.GetSpecialties(context.Background())
require.NoError(t, err) require.NoError(t, err)
// SeedLookupData creates 4 specialties // SeedLookupData creates 4 specialties
assert.Len(t, resp, 4) assert.Len(t, resp, 4)
@@ -413,7 +414,7 @@ func TestContractorService_UpdateContractor_NotFound(t *testing.T) {
newName := "Won't Work" newName := "Won't Work"
req := &requests.UpdateContractorRequest{Name: &newName} req := &requests.UpdateContractorRequest{Name: &newName}
_, err := service.UpdateContractor(9999, user.ID, req) _, err := service.UpdateContractor(context.Background(), 9999, user.ID, req)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
} }
@@ -432,7 +433,7 @@ func TestContractorService_UpdateContractor_AccessDenied(t *testing.T) {
newName := "Hacked" newName := "Hacked"
req := &requests.UpdateContractorRequest{Name: &newName} req := &requests.UpdateContractorRequest{Name: &newName}
_, err := service.UpdateContractor(contractor.ID, other.ID, req) _, err := service.UpdateContractor(context.Background(), contractor.ID, other.ID, req)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
} }
@@ -461,7 +462,7 @@ func TestUpdateContractor_CrossResidence_Returns403(t *testing.T) {
ResidenceID: &newResidenceID, ResidenceID: &newResidenceID,
} }
_, err := service.UpdateContractor(contractor.ID, attacker.ID, req) _, err := service.UpdateContractor(context.Background(), contractor.ID, attacker.ID, req)
require.Error(t, err, "should not allow reassigning contractor to a residence the user has no access to") require.Error(t, err, "should not allow reassigning contractor to a residence the user has no access to")
testutil.AssertAppErrorCode(t, err, http.StatusForbidden) testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
} }
@@ -486,7 +487,7 @@ func TestUpdateContractor_SameResidence_Succeeds(t *testing.T) {
ResidenceID: &newResidenceID, ResidenceID: &newResidenceID,
} }
resp, err := service.UpdateContractor(contractor.ID, owner.ID, req) resp, err := service.UpdateContractor(context.Background(), contractor.ID, owner.ID, req)
require.NoError(t, err, "should allow reassigning contractor to a residence the user owns") require.NoError(t, err, "should allow reassigning contractor to a residence the user owns")
require.NotNil(t, resp) require.NotNil(t, resp)
require.Equal(t, "Updated Contractor", resp.Name) require.Equal(t, "Updated Contractor", resp.Name)
@@ -508,7 +509,7 @@ func TestUpdateContractor_RemoveResidence_Succeeds(t *testing.T) {
ResidenceID: nil, ResidenceID: nil,
} }
resp, err := service.UpdateContractor(contractor.ID, owner.ID, req) resp, err := service.UpdateContractor(context.Background(), contractor.ID, owner.ID, req)
require.NoError(t, err, "should allow removing residence association") require.NoError(t, err, "should allow removing residence association")
require.NotNil(t, resp) require.NotNil(t, resp)
} }
@@ -555,7 +556,7 @@ func TestContractorService_UpdateContractor_PartialUpdate(t *testing.T) {
ResidenceID: &residence.ID, ResidenceID: &residence.ID,
} }
resp, err := service.UpdateContractor(contractor.ID, user.ID, req) resp, err := service.UpdateContractor(context.Background(), contractor.ID, user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Updated Plumber", resp.Name) assert.Equal(t, "Updated Plumber", resp.Name)
assert.Equal(t, "555-9999", resp.Phone) assert.Equal(t, "555-9999", resp.Phone)
@@ -588,7 +589,7 @@ func TestContractorService_UpdateContractor_WithSpecialties(t *testing.T) {
ResidenceID: &residence.ID, ResidenceID: &residence.ID,
} }
resp, err := service.UpdateContractor(contractor.ID, user.ID, req) resp, err := service.UpdateContractor(context.Background(), contractor.ID, user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
} }
@@ -615,7 +616,7 @@ func TestContractorService_CreateContractor_WithSpecialties(t *testing.T) {
SpecialtyIDs: []uint{specialties[0].ID}, SpecialtyIDs: []uint{specialties[0].ID},
} }
resp, err := service.CreateContractor(req, user.ID) resp, err := service.CreateContractor(context.Background(), req, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Specialized Plumber", resp.Name) assert.Equal(t, "Specialized Plumber", resp.Name)
} }
@@ -631,7 +632,7 @@ func TestContractorService_ListContractors_Empty(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// No residence, no contractors // No residence, no contractors
resp, err := service.ListContractors(user.ID) resp, err := service.ListContractors(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, resp) assert.Empty(t, resp)
} }
@@ -648,7 +649,7 @@ func TestContractorService_ListContractorsByResidence_Empty(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Empty House") residence := testutil.CreateTestResidence(t, db, user.ID, "Empty House")
resp, err := service.ListContractorsByResidence(residence.ID, user.ID) resp, err := service.ListContractorsByResidence(context.Background(), residence.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, resp) assert.Empty(t, resp)
} }
@@ -669,14 +670,14 @@ func TestContractorService_PersonalContractor_OnlyCreatorAccess(t *testing.T) {
req := &requests.CreateContractorRequest{ req := &requests.CreateContractorRequest{
Name: "Personal Plumber", Name: "Personal Plumber",
} }
resp, err := service.CreateContractor(req, creator.ID) resp, err := service.CreateContractor(context.Background(), req, creator.ID)
require.NoError(t, err) require.NoError(t, err)
// Creator can access // Creator can access
_, err = service.GetContractor(resp.ID, creator.ID) _, err = service.GetContractor(context.Background(), resp.ID, creator.ID)
require.NoError(t, err) require.NoError(t, err)
// Other user cannot // Other user cannot
_, err = service.GetContractor(resp.ID, other.ID) _, err = service.GetContractor(context.Background(), resp.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
} }
+44 -43
View File
@@ -1,6 +1,7 @@
package services package services
import ( import (
"context"
"errors" "errors"
"gorm.io/gorm" "gorm.io/gorm"
@@ -34,8 +35,8 @@ func NewDocumentService(documentRepo *repositories.DocumentRepository, residence
} }
// GetDocument gets a document by ID with access check // GetDocument gets a document by ID with access check
func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.DocumentResponse, error) { func (s *DocumentService) GetDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID) document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_not_found") return nil, apperrors.NotFound("error.document_not_found")
@@ -44,7 +45,7 @@ func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.Docum
} }
// Check access via residence // Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -57,9 +58,9 @@ func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.Docum
} }
// ListDocuments lists all documents accessible to a user, with optional filters. // ListDocuments lists all documents accessible to a user, with optional filters.
func (s *DocumentService) ListDocuments(userID uint, filter *repositories.DocumentFilter) ([]responses.DocumentResponse, error) { func (s *DocumentService) ListDocuments(ctx context.Context, userID uint, filter *repositories.DocumentFilter) ([]responses.DocumentResponse, error) {
// Get residence IDs (lightweight - no preloads) // Get residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID) residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -83,7 +84,7 @@ func (s *DocumentService) ListDocuments(userID uint, filter *repositories.Docume
residenceIDs = []uint{*filter.ResidenceID} residenceIDs = []uint{*filter.ResidenceID}
} }
documents, err := s.documentRepo.FindByUserFiltered(residenceIDs, filter) documents, err := s.documentRepo.WithContext(ctx).FindByUserFiltered(residenceIDs, filter)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -92,9 +93,9 @@ func (s *DocumentService) ListDocuments(userID uint, filter *repositories.Docume
} }
// ListWarranties lists all warranty documents // ListWarranties lists all warranty documents
func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentResponse, error) { func (s *DocumentService) ListWarranties(ctx context.Context, userID uint) ([]responses.DocumentResponse, error) {
// Get residence IDs (lightweight - no preloads) // Get residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID) residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -103,7 +104,7 @@ func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentRespo
return []responses.DocumentResponse{}, nil return []responses.DocumentResponse{}, nil
} }
documents, err := s.documentRepo.FindWarranties(residenceIDs) documents, err := s.documentRepo.WithContext(ctx).FindWarranties(residenceIDs)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -112,9 +113,9 @@ func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentRespo
} }
// CreateDocument creates a new document // CreateDocument creates a new document
func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, userID uint) (*responses.DocumentResponse, error) { func (s *DocumentService) CreateDocument(ctx context.Context, req *requests.CreateDocumentRequest, userID uint) (*responses.DocumentResponse, error) {
// Check residence access // Check residence access
hasAccess, err := s.residenceRepo.HasAccess(req.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(req.ResidenceID, userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -147,7 +148,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
IsActive: true, IsActive: true,
} }
if err := s.documentRepo.Create(document); err != nil { if err := s.documentRepo.WithContext(ctx).Create(document); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -158,7 +159,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
DocumentID: document.ID, DocumentID: document.ID,
ImageURL: imageURL, ImageURL: imageURL,
} }
if err := s.documentRepo.CreateDocumentImage(img); err != nil { if err := s.documentRepo.WithContext(ctx).CreateDocumentImage(img); err != nil {
// Log but don't fail the whole operation // Log but don't fail the whole operation
continue continue
} }
@@ -166,7 +167,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
} }
// Reload with relations // Reload with relations
document, err = s.documentRepo.FindByID(document.ID) document, err = s.documentRepo.WithContext(ctx).FindByID(document.ID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -176,8 +177,8 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
} }
// UpdateDocument updates a document // UpdateDocument updates a document
func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.UpdateDocumentRequest) (*responses.DocumentResponse, error) { func (s *DocumentService) UpdateDocument(ctx context.Context, documentID, userID uint, req *requests.UpdateDocumentRequest) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID) document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_not_found") return nil, apperrors.NotFound("error.document_not_found")
@@ -186,7 +187,7 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
} }
// Check access // Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -238,12 +239,12 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
document.TaskID = req.TaskID document.TaskID = req.TaskID
} }
if err := s.documentRepo.Update(document); err != nil { if err := s.documentRepo.WithContext(ctx).Update(document); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Reload // Reload
document, err = s.documentRepo.FindByID(documentID) document, err = s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -253,8 +254,8 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
} }
// DeleteDocument soft-deletes a document // DeleteDocument soft-deletes a document
func (s *DocumentService) DeleteDocument(documentID, userID uint) error { func (s *DocumentService) DeleteDocument(ctx context.Context, documentID, userID uint) error {
document, err := s.documentRepo.FindByID(documentID) document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.document_not_found") return apperrors.NotFound("error.document_not_found")
@@ -263,7 +264,7 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
} }
// Check access // Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -271,7 +272,7 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
return apperrors.Forbidden("error.document_access_denied") return apperrors.Forbidden("error.document_access_denied")
} }
if err := s.documentRepo.Delete(documentID); err != nil { if err := s.documentRepo.WithContext(ctx).Delete(documentID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -279,15 +280,15 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
} }
// ActivateDocument activates a document // ActivateDocument activates a document
func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.DocumentResponse, error) { func (s *DocumentService) ActivateDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) {
// First check if document exists (even if inactive) // First check if document exists (even if inactive)
var document models.Document var document models.Document
if err := s.documentRepo.FindByIDIncludingInactive(documentID, &document); err != nil { if err := s.documentRepo.WithContext(ctx).FindByIDIncludingInactive(documentID, &document); err != nil {
return nil, apperrors.NotFound("error.document_not_found") return nil, apperrors.NotFound("error.document_not_found")
} }
// Check access // Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -295,12 +296,12 @@ func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.
return nil, apperrors.Forbidden("error.document_access_denied") return nil, apperrors.Forbidden("error.document_access_denied")
} }
if err := s.documentRepo.Activate(documentID); err != nil { if err := s.documentRepo.WithContext(ctx).Activate(documentID); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Reload // Reload
doc, err := s.documentRepo.FindByID(documentID) doc, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -310,8 +311,8 @@ func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.
} }
// DeactivateDocument deactivates a document // DeactivateDocument deactivates a document
func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*responses.DocumentResponse, error) { func (s *DocumentService) DeactivateDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID) document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_not_found") return nil, apperrors.NotFound("error.document_not_found")
@@ -320,7 +321,7 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response
} }
// Check access // Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -328,7 +329,7 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response
return nil, apperrors.Forbidden("error.document_access_denied") return nil, apperrors.Forbidden("error.document_access_denied")
} }
if err := s.documentRepo.Deactivate(documentID); err != nil { if err := s.documentRepo.WithContext(ctx).Deactivate(documentID); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -338,8 +339,8 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response
} }
// UploadDocumentImage adds an image to an existing document // UploadDocumentImage adds an image to an existing document
func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL, caption string) (*responses.DocumentResponse, error) { func (s *DocumentService) UploadDocumentImage(ctx context.Context, documentID, userID uint, imageURL, caption string) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID) document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_not_found") return nil, apperrors.NotFound("error.document_not_found")
@@ -348,7 +349,7 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL,
} }
// Check access via residence // Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -361,12 +362,12 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL,
ImageURL: imageURL, ImageURL: imageURL,
Caption: caption, Caption: caption,
} }
if err := s.documentRepo.CreateDocumentImage(img); err != nil { if err := s.documentRepo.WithContext(ctx).CreateDocumentImage(img); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Reload with relations // Reload with relations
document, err = s.documentRepo.FindByID(documentID) document, err = s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -376,9 +377,9 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL,
} }
// DeleteDocumentImage removes an image from a document // DeleteDocumentImage removes an image from a document
func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint) (*responses.DocumentResponse, error) { func (s *DocumentService) DeleteDocumentImage(ctx context.Context, documentID, imageID, userID uint) (*responses.DocumentResponse, error) {
// Find the image first // Find the image first
image, err := s.documentRepo.FindImageByID(imageID) image, err := s.documentRepo.WithContext(ctx).FindImageByID(imageID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_image_not_found") return nil, apperrors.NotFound("error.document_image_not_found")
@@ -392,7 +393,7 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint)
} }
// Find parent document to check access // Find parent document to check access
document, err := s.documentRepo.FindByID(documentID) document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_not_found") return nil, apperrors.NotFound("error.document_not_found")
@@ -401,7 +402,7 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint)
} }
// Check access via residence // Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -409,12 +410,12 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint)
return nil, apperrors.Forbidden("error.document_access_denied") return nil, apperrors.Forbidden("error.document_access_denied")
} }
if err := s.documentRepo.DeleteDocumentImage(imageID); err != nil { if err := s.documentRepo.WithContext(ctx).DeleteDocumentImage(imageID); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Reload with relations // Reload with relations
document, err = s.documentRepo.FindByID(documentID) document, err = s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
+42 -41
View File
@@ -1,6 +1,7 @@
package services package services
import ( import (
"context"
"net/http" "net/http"
"testing" "testing"
@@ -42,7 +43,7 @@ func TestDocumentService_CreateDocument(t *testing.T) {
FileName: "manual.pdf", FileName: "manual.pdf",
} }
resp, err := service.CreateDocument(req, user.ID) resp, err := service.CreateDocument(context.Background(), req, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.Equal(t, "Furnace Manual", resp.Title) assert.Equal(t, "Furnace Manual", resp.Title)
@@ -64,7 +65,7 @@ func TestDocumentService_CreateDocument_DefaultType(t *testing.T) {
// DocumentType not set — should default to "general" // DocumentType not set — should default to "general"
} }
resp, err := service.CreateDocument(req, user.ID) resp, err := service.CreateDocument(context.Background(), req, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, models.DocumentTypeGeneral, resp.DocumentType) assert.Equal(t, models.DocumentTypeGeneral, resp.DocumentType)
} }
@@ -84,7 +85,7 @@ func TestDocumentService_CreateDocument_WithImages(t *testing.T) {
ImageURLs: []string{"https://example.com/img1.jpg", "https://example.com/img2.jpg"}, ImageURLs: []string{"https://example.com/img1.jpg", "https://example.com/img2.jpg"},
} }
resp, err := service.CreateDocument(req, user.ID) resp, err := service.CreateDocument(context.Background(), req, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
assert.Equal(t, "Receipt with photos", resp.Title) assert.Equal(t, "Receipt with photos", resp.Title)
@@ -105,7 +106,7 @@ func TestDocumentService_CreateDocument_AccessDenied(t *testing.T) {
Title: "Unauthorized Doc", Title: "Unauthorized Doc",
} }
_, err := service.CreateDocument(req, other.ID) _, err := service.CreateDocument(context.Background(), req, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
} }
@@ -121,7 +122,7 @@ func TestDocumentService_GetDocument(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
resp, err := service.GetDocument(doc.ID, user.ID) resp, err := service.GetDocument(context.Background(), doc.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, doc.ID, resp.ID) assert.Equal(t, doc.ID, resp.ID)
assert.Equal(t, "Test Doc", resp.Title) assert.Equal(t, "Test Doc", resp.Title)
@@ -135,7 +136,7 @@ func TestDocumentService_GetDocument_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.GetDocument(9999, user.ID) _, err := service.GetDocument(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
} }
@@ -150,7 +151,7 @@ func TestDocumentService_GetDocument_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.GetDocument(doc.ID, other.ID) _, err := service.GetDocument(context.Background(), doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
} }
@@ -173,7 +174,7 @@ func TestDocumentService_UpdateDocument(t *testing.T) {
Description: &newDesc, Description: &newDesc,
} }
resp, err := service.UpdateDocument(doc.ID, user.ID, req) resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Updated Title", resp.Title) assert.Equal(t, "Updated Title", resp.Title)
assert.Equal(t, "Updated description", resp.Description) assert.Equal(t, "Updated description", resp.Description)
@@ -190,7 +191,7 @@ func TestDocumentService_UpdateDocument_NotFound(t *testing.T) {
newTitle := "Won't Work" newTitle := "Won't Work"
req := &requests.UpdateDocumentRequest{Title: &newTitle} req := &requests.UpdateDocumentRequest{Title: &newTitle}
_, err := service.UpdateDocument(9999, user.ID, req) _, err := service.UpdateDocument(context.Background(), 9999, user.ID, req)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
} }
@@ -208,7 +209,7 @@ func TestDocumentService_UpdateDocument_AccessDenied(t *testing.T) {
newTitle := "Hacked" newTitle := "Hacked"
req := &requests.UpdateDocumentRequest{Title: &newTitle} req := &requests.UpdateDocumentRequest{Title: &newTitle}
_, err := service.UpdateDocument(doc.ID, other.ID, req) _, err := service.UpdateDocument(context.Background(), doc.ID, other.ID, req)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
} }
@@ -225,7 +226,7 @@ func TestDocumentService_UpdateDocument_ChangeType(t *testing.T) {
newType := models.DocumentTypeWarranty newType := models.DocumentTypeWarranty
req := &requests.UpdateDocumentRequest{DocumentType: &newType} req := &requests.UpdateDocumentRequest{DocumentType: &newType}
resp, err := service.UpdateDocument(doc.ID, user.ID, req) resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, models.DocumentTypeWarranty, resp.DocumentType) assert.Equal(t, models.DocumentTypeWarranty, resp.DocumentType)
} }
@@ -242,11 +243,11 @@ func TestDocumentService_DeleteDocument(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Delete") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Delete")
err := service.DeleteDocument(doc.ID, user.ID) err := service.DeleteDocument(context.Background(), doc.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
// Should not be found after deletion // Should not be found after deletion
_, err = service.GetDocument(doc.ID, user.ID) _, err = service.GetDocument(context.Background(), doc.ID, user.ID)
assert.Error(t, err) assert.Error(t, err)
} }
@@ -258,7 +259,7 @@ func TestDocumentService_DeleteDocument_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.DeleteDocument(9999, user.ID) err := service.DeleteDocument(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
} }
@@ -273,7 +274,7 @@ func TestDocumentService_DeleteDocument_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
err := service.DeleteDocument(doc.ID, other.ID) err := service.DeleteDocument(context.Background(), doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
} }
@@ -290,7 +291,7 @@ func TestDocumentService_ListDocuments(t *testing.T) {
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1") testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2") testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
resp, err := service.ListDocuments(user.ID, nil) resp, err := service.ListDocuments(context.Background(), user.ID, nil)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, resp, 2) assert.Len(t, resp, 2)
} }
@@ -303,7 +304,7 @@ func TestDocumentService_ListDocuments_NoResidences(t *testing.T) {
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
resp, err := service.ListDocuments(user.ID, nil) resp, err := service.ListDocuments(context.Background(), user.ID, nil)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, resp) assert.Empty(t, resp)
} }
@@ -321,7 +322,7 @@ func TestDocumentService_ListDocuments_FilterByResidence(t *testing.T) {
testutil.CreateTestDocument(t, db, residence2.ID, user.ID, "Doc B") testutil.CreateTestDocument(t, db, residence2.ID, user.ID, "Doc B")
filter := &repositories.DocumentFilter{ResidenceID: &residence1.ID} filter := &repositories.DocumentFilter{ResidenceID: &residence1.ID}
resp, err := service.ListDocuments(user.ID, filter) resp, err := service.ListDocuments(context.Background(), user.ID, filter)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, resp, 1) assert.Len(t, resp, 1)
assert.Equal(t, "Doc A", resp[0].Title) assert.Equal(t, "Doc A", resp[0].Title)
@@ -340,7 +341,7 @@ func TestDocumentService_ListDocuments_FilterByResidence_AccessDenied(t *testing
testutil.CreateTestResidence(t, db, other.ID, "Other House") testutil.CreateTestResidence(t, db, other.ID, "Other House")
filter := &repositories.DocumentFilter{ResidenceID: &residence.ID} filter := &repositories.DocumentFilter{ResidenceID: &residence.ID}
_, err := service.ListDocuments(other.ID, filter) _, err := service.ListDocuments(context.Background(), other.ID, filter)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
} }
@@ -369,7 +370,7 @@ func TestDocumentService_ListWarranties(t *testing.T) {
// Create a general doc // Create a general doc
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc") testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
resp, err := service.ListWarranties(user.ID) resp, err := service.ListWarranties(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, resp, 1) assert.Len(t, resp, 1)
assert.Equal(t, "HVAC Warranty", resp[0].Title) assert.Equal(t, "HVAC Warranty", resp[0].Title)
@@ -383,7 +384,7 @@ func TestDocumentService_ListWarranties_NoResidences(t *testing.T) {
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
resp, err := service.ListWarranties(user.ID) resp, err := service.ListWarranties(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, resp) assert.Empty(t, resp)
} }
@@ -400,7 +401,7 @@ func TestDocumentService_DeactivateDocument(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Deactivate") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Deactivate")
resp, err := service.DeactivateDocument(doc.ID, user.ID) resp, err := service.DeactivateDocument(context.Background(), doc.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, resp.IsActive) assert.False(t, resp.IsActive)
} }
@@ -413,7 +414,7 @@ func TestDocumentService_DeactivateDocument_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.DeactivateDocument(9999, user.ID) _, err := service.DeactivateDocument(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
} }
@@ -428,7 +429,7 @@ func TestDocumentService_DeactivateDocument_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.DeactivateDocument(doc.ID, other.ID) _, err := service.DeactivateDocument(context.Background(), doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
} }
@@ -444,7 +445,7 @@ func TestDocumentService_UploadDocumentImage(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
resp, err := service.UploadDocumentImage(doc.ID, user.ID, "https://example.com/photo.jpg", "Front view") resp, err := service.UploadDocumentImage(context.Background(), doc.ID, user.ID, "https://example.com/photo.jpg", "Front view")
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
} }
@@ -457,7 +458,7 @@ func TestDocumentService_UploadDocumentImage_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.UploadDocumentImage(9999, user.ID, "https://example.com/photo.jpg", "") _, err := service.UploadDocumentImage(context.Background(), 9999, user.ID, "https://example.com/photo.jpg", "")
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
} }
@@ -472,7 +473,7 @@ func TestDocumentService_UploadDocumentImage_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.UploadDocumentImage(doc.ID, other.ID, "https://example.com/photo.jpg", "") _, err := service.UploadDocumentImage(context.Background(), doc.ID, other.ID, "https://example.com/photo.jpg", "")
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
} }
@@ -496,7 +497,7 @@ func TestDocumentService_DeleteDocumentImage(t *testing.T) {
err := db.Create(img).Error err := db.Create(img).Error
require.NoError(t, err) require.NoError(t, err)
resp, err := service.DeleteDocumentImage(doc.ID, img.ID, user.ID) resp, err := service.DeleteDocumentImage(context.Background(), doc.ID, img.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
} }
@@ -511,7 +512,7 @@ func TestDocumentService_DeleteDocumentImage_NotFound(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
_, err := service.DeleteDocumentImage(doc.ID, 9999, user.ID) _, err := service.DeleteDocumentImage(context.Background(), doc.ID, 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
} }
@@ -535,7 +536,7 @@ func TestDocumentService_DeleteDocumentImage_WrongDocument(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Try to delete the image specifying doc2 // Try to delete the image specifying doc2
_, err = service.DeleteDocumentImage(doc2.ID, img.ID, user.ID) _, err = service.DeleteDocumentImage(context.Background(), doc2.ID, img.ID, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
} }
@@ -557,7 +558,7 @@ func TestDocumentService_DeleteDocumentImage_AccessDenied(t *testing.T) {
err := db.Create(img).Error err := db.Create(img).Error
require.NoError(t, err) require.NoError(t, err)
_, err = service.DeleteDocumentImage(doc.ID, img.ID, other.ID) _, err = service.DeleteDocumentImage(context.Background(), doc.ID, img.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
} }
@@ -575,7 +576,7 @@ func TestDocumentService_GetDocument_SharedUserHasAccess(t *testing.T) {
residenceRepo.AddUser(residence.ID, shared.ID) residenceRepo.AddUser(residence.ID, shared.ID)
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
resp, err := service.GetDocument(doc.ID, shared.ID) resp, err := service.GetDocument(context.Background(), doc.ID, shared.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Shared Doc", resp.Title) assert.Equal(t, "Shared Doc", resp.Title)
} }
@@ -593,11 +594,11 @@ func TestDocumentService_ActivateDocument(t *testing.T) {
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Activate") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Activate")
// Deactivate first // Deactivate first
_, err := service.DeactivateDocument(doc.ID, user.ID) _, err := service.DeactivateDocument(context.Background(), doc.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
// Now activate // Now activate
resp, err := service.ActivateDocument(doc.ID, user.ID) resp, err := service.ActivateDocument(context.Background(), doc.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, resp.IsActive) assert.True(t, resp.IsActive)
} }
@@ -610,7 +611,7 @@ func TestDocumentService_ActivateDocument_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.ActivateDocument(9999, user.ID) _, err := service.ActivateDocument(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
} }
@@ -625,7 +626,7 @@ func TestDocumentService_ActivateDocument_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.ActivateDocument(doc.ID, other.ID) _, err := service.ActivateDocument(context.Background(), doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
} }
@@ -646,7 +647,7 @@ func TestDocumentService_CreateDocument_WithEmptyImageURL(t *testing.T) {
ImageURLs: []string{"", "https://example.com/img.jpg", ""}, ImageURLs: []string{"", "https://example.com/img.jpg", ""},
} }
resp, err := service.CreateDocument(req, user.ID) resp, err := service.CreateDocument(context.Background(), req, user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
} }
@@ -687,7 +688,7 @@ func TestDocumentService_UpdateDocument_AllFields(t *testing.T) {
ModelNumber: &newModel, ModelNumber: &newModel,
} }
resp, err := service.UpdateDocument(doc.ID, user.ID, req) resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Updated", resp.Title) assert.Equal(t, "Updated", resp.Title)
assert.Equal(t, "New description", resp.Description) assert.Equal(t, "New description", resp.Description)
@@ -720,7 +721,7 @@ func TestDocumentService_ListDocuments_FilterByType(t *testing.T) {
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc") testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
filter := &repositories.DocumentFilter{DocumentType: string(models.DocumentTypeWarranty)} filter := &repositories.DocumentFilter{DocumentType: string(models.DocumentTypeWarranty)}
resp, err := service.ListDocuments(user.ID, filter) resp, err := service.ListDocuments(context.Background(), user.ID, filter)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, resp, 1) assert.Len(t, resp, 1)
assert.Equal(t, "Warranty Doc", resp[0].Title) assert.Equal(t, "Warranty Doc", resp[0].Title)
@@ -742,7 +743,7 @@ func TestDocumentService_SharedUser_CanUpdate(t *testing.T) {
newTitle := "Updated by shared user" newTitle := "Updated by shared user"
req := &requests.UpdateDocumentRequest{Title: &newTitle} req := &requests.UpdateDocumentRequest{Title: &newTitle}
resp, err := service.UpdateDocument(doc.ID, shared.ID, req) resp, err := service.UpdateDocument(context.Background(), doc.ID, shared.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Updated by shared user", resp.Title) assert.Equal(t, "Updated by shared user", resp.Title)
} }
@@ -759,6 +760,6 @@ func TestDocumentService_SharedUser_CanDelete(t *testing.T) {
residenceRepo.AddUser(residence.ID, shared.ID) residenceRepo.AddUser(residence.ID, shared.ID)
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
err := service.DeleteDocument(doc.ID, shared.ID) err := service.DeleteDocument(context.Background(), doc.ID, shared.ID)
require.NoError(t, err) require.NoError(t, err)
} }
+51 -51
View File
@@ -44,8 +44,8 @@ func NewNotificationService(notificationRepo *repositories.NotificationRepositor
// === Notifications === // === Notifications ===
// GetNotifications gets notifications for a user // GetNotifications gets notifications for a user
func (s *NotificationService) GetNotifications(userID uint, limit, offset int) ([]NotificationResponse, error) { func (s *NotificationService) GetNotifications(ctx context.Context, userID uint, limit, offset int) ([]NotificationResponse, error) {
notifications, err := s.notificationRepo.FindByUser(userID, limit, offset) notifications, err := s.notificationRepo.WithContext(ctx).FindByUser(userID, limit, offset)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -58,8 +58,8 @@ func (s *NotificationService) GetNotifications(userID uint, limit, offset int) (
} }
// GetUnreadCount gets the count of unread notifications // GetUnreadCount gets the count of unread notifications
func (s *NotificationService) GetUnreadCount(userID uint) (int64, error) { func (s *NotificationService) GetUnreadCount(ctx context.Context, userID uint) (int64, error) {
count, err := s.notificationRepo.CountUnread(userID) count, err := s.notificationRepo.WithContext(ctx).CountUnread(userID)
if err != nil { if err != nil {
return 0, apperrors.Internal(err) return 0, apperrors.Internal(err)
} }
@@ -67,8 +67,8 @@ func (s *NotificationService) GetUnreadCount(userID uint) (int64, error) {
} }
// MarkAsRead marks a notification as read // MarkAsRead marks a notification as read
func (s *NotificationService) MarkAsRead(notificationID, userID uint) error { func (s *NotificationService) MarkAsRead(ctx context.Context, notificationID, userID uint) error {
notification, err := s.notificationRepo.FindByID(notificationID) notification, err := s.notificationRepo.WithContext(ctx).FindByID(notificationID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.notification_not_found") return apperrors.NotFound("error.notification_not_found")
@@ -80,15 +80,15 @@ func (s *NotificationService) MarkAsRead(notificationID, userID uint) error {
return apperrors.NotFound("error.notification_not_found") return apperrors.NotFound("error.notification_not_found")
} }
if err := s.notificationRepo.MarkAsRead(notificationID); err != nil { if err := s.notificationRepo.WithContext(ctx).MarkAsRead(notificationID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
return nil return nil
} }
// MarkAllAsRead marks all notifications as read // MarkAllAsRead marks all notifications as read
func (s *NotificationService) MarkAllAsRead(userID uint) error { func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID uint) error {
if err := s.notificationRepo.MarkAllAsRead(userID); err != nil { if err := s.notificationRepo.WithContext(ctx).MarkAllAsRead(userID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
return nil return nil
@@ -97,7 +97,7 @@ func (s *NotificationService) MarkAllAsRead(userID uint) error {
// CreateAndSendNotification creates a notification and sends it via push // CreateAndSendNotification creates a notification and sends it via push
func (s *NotificationService) CreateAndSendNotification(ctx context.Context, userID uint, notificationType models.NotificationType, title, body string, data map[string]interface{}) error { func (s *NotificationService) CreateAndSendNotification(ctx context.Context, userID uint, notificationType models.NotificationType, title, body string, data map[string]interface{}) error {
// Check user preferences // Check user preferences
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -117,12 +117,12 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use
Data: string(dataJSON), Data: string(dataJSON),
} }
if err := s.notificationRepo.Create(notification); err != nil { if err := s.notificationRepo.WithContext(ctx).Create(notification); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
// Get device tokens // Get device tokens
iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID) iosTokens, androidTokens, err := s.notificationRepo.WithContext(ctx).GetActiveTokensForUser(userID)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -144,12 +144,12 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use
if s.pushClient != nil { if s.pushClient != nil {
err = s.pushClient.SendToAll(ctx, iosTokens, androidTokens, title, body, pushData) err = s.pushClient.SendToAll(ctx, iosTokens, androidTokens, title, body, pushData)
if err != nil { if err != nil {
s.notificationRepo.SetError(notification.ID, err.Error()) s.notificationRepo.WithContext(ctx).SetError(notification.ID, err.Error())
return apperrors.Internal(err) return apperrors.Internal(err)
} }
} }
if err := s.notificationRepo.MarkAsSent(notification.ID); err != nil { if err := s.notificationRepo.WithContext(ctx).MarkAsSent(notification.ID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
return nil return nil
@@ -178,8 +178,8 @@ func (s *NotificationService) isNotificationEnabled(prefs *models.NotificationPr
// === Notification Preferences === // === Notification Preferences ===
// GetPreferences gets notification preferences for a user // GetPreferences gets notification preferences for a user
func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferencesResponse, error) { func (s *NotificationService) GetPreferences(ctx context.Context, userID uint) (*NotificationPreferencesResponse, error) {
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -196,7 +196,7 @@ func validateHourField(val *int, fieldName string) error {
} }
// UpdatePreferences updates notification preferences // UpdatePreferences updates notification preferences
func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) { func (s *NotificationService) UpdatePreferences(ctx context.Context, userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
// B-12: Validate hour fields are in range 0-23 // B-12: Validate hour fields are in range 0-23
hourFields := []struct { hourFields := []struct {
value *int value *int
@@ -213,7 +213,7 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen
} }
} }
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -258,7 +258,7 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen
prefs.DailyDigestHour = req.DailyDigestHour prefs.DailyDigestHour = req.DailyDigestHour
} }
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil { if err := s.notificationRepo.WithContext(ctx).UpdatePreferences(prefs); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -268,14 +268,14 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen
// UpdateUserTimezone updates the user's timezone for background job calculations. // UpdateUserTimezone updates the user's timezone for background job calculations.
// This is called automatically when the user makes API calls (e.g., fetching tasks). // This is called automatically when the user makes API calls (e.g., fetching tasks).
// The timezone should be an IANA timezone name (e.g., "America/Los_Angeles"). // The timezone should be an IANA timezone name (e.g., "America/Los_Angeles").
func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) { func (s *NotificationService) UpdateUserTimezone(ctx context.Context, userID uint, timezone string) {
// Validate timezone is a valid IANA name // Validate timezone is a valid IANA name
if _, err := time.LoadLocation(timezone); err != nil { if _, err := time.LoadLocation(timezone); err != nil {
return // Invalid timezone, skip silently return // Invalid timezone, skip silently
} }
// Get or create preferences and update timezone // Get or create preferences and update timezone
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
if err != nil { if err != nil {
return // Skip silently on error return // Skip silently on error
} }
@@ -283,7 +283,7 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
// Only update if timezone changed (avoid unnecessary DB writes) // Only update if timezone changed (avoid unnecessary DB writes)
if prefs.Timezone == nil || *prefs.Timezone != timezone { if prefs.Timezone == nil || *prefs.Timezone != timezone {
prefs.Timezone = &timezone prefs.Timezone = &timezone
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil { if err := s.notificationRepo.WithContext(ctx).UpdatePreferences(prefs); err != nil {
log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone). log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone).
Msg("Failed to update user timezone in notification preferences") Msg("Failed to update user timezone in notification preferences")
} }
@@ -293,27 +293,27 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
// === Device Registration === // === Device Registration ===
// RegisterDevice registers a device for push notifications // RegisterDevice registers a device for push notifications
func (s *NotificationService) RegisterDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) { func (s *NotificationService) RegisterDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
switch req.Platform { switch req.Platform {
case push.PlatformIOS: case push.PlatformIOS:
return s.registerAPNSDevice(userID, req) return s.registerAPNSDevice(ctx, userID, req)
case push.PlatformAndroid: case push.PlatformAndroid:
return s.registerGCMDevice(userID, req) return s.registerGCMDevice(ctx, userID, req)
default: default:
return nil, apperrors.BadRequest("error.invalid_platform") return nil, apperrors.BadRequest("error.invalid_platform")
} }
} }
func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) { func (s *NotificationService) registerAPNSDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
// Check if device exists // Check if device exists
existing, err := s.notificationRepo.FindAPNSDeviceByToken(req.RegistrationID) existing, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(req.RegistrationID)
if err == nil { if err == nil {
// Update existing device // Update existing device
existing.UserID = &userID existing.UserID = &userID
existing.Active = true existing.Active = true
existing.Name = req.Name existing.Name = req.Name
existing.DeviceID = req.DeviceID existing.DeviceID = req.DeviceID
if err := s.notificationRepo.UpdateAPNSDevice(existing); err != nil { if err := s.notificationRepo.WithContext(ctx).UpdateAPNSDevice(existing); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
return NewAPNSDeviceResponse(existing), nil return NewAPNSDeviceResponse(existing), nil
@@ -327,22 +327,22 @@ func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDevic
RegistrationID: req.RegistrationID, RegistrationID: req.RegistrationID,
Active: true, Active: true,
} }
if err := s.notificationRepo.CreateAPNSDevice(device); err != nil { if err := s.notificationRepo.WithContext(ctx).CreateAPNSDevice(device); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
return NewAPNSDeviceResponse(device), nil return NewAPNSDeviceResponse(device), nil
} }
func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) { func (s *NotificationService) registerGCMDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
// Check if device exists // Check if device exists
existing, err := s.notificationRepo.FindGCMDeviceByToken(req.RegistrationID) existing, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(req.RegistrationID)
if err == nil { if err == nil {
// Update existing device // Update existing device
existing.UserID = &userID existing.UserID = &userID
existing.Active = true existing.Active = true
existing.Name = req.Name existing.Name = req.Name
existing.DeviceID = req.DeviceID existing.DeviceID = req.DeviceID
if err := s.notificationRepo.UpdateGCMDevice(existing); err != nil { if err := s.notificationRepo.WithContext(ctx).UpdateGCMDevice(existing); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
return NewGCMDeviceResponse(existing), nil return NewGCMDeviceResponse(existing), nil
@@ -357,20 +357,20 @@ func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDevice
CloudMessageType: "FCM", CloudMessageType: "FCM",
Active: true, Active: true,
} }
if err := s.notificationRepo.CreateGCMDevice(device); err != nil { if err := s.notificationRepo.WithContext(ctx).CreateGCMDevice(device); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
return NewGCMDeviceResponse(device), nil return NewGCMDeviceResponse(device), nil
} }
// ListDevices lists all devices for a user // ListDevices lists all devices for a user
func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error) { func (s *NotificationService) ListDevices(ctx context.Context, userID uint) ([]DeviceResponse, error) {
iosDevices, err := s.notificationRepo.FindAPNSDevicesByUser(userID) iosDevices, err := s.notificationRepo.WithContext(ctx).FindAPNSDevicesByUser(userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
androidDevices, err := s.notificationRepo.FindGCMDevicesByUser(userID) androidDevices, err := s.notificationRepo.WithContext(ctx).FindGCMDevicesByUser(userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -387,10 +387,10 @@ func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error)
// DeleteDevice deactivates a device after verifying it belongs to the requesting user. // DeleteDevice deactivates a device after verifying it belongs to the requesting user.
// Without ownership verification, an attacker could deactivate push notifications for other users. // Without ownership verification, an attacker could deactivate push notifications for other users.
func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userID uint) error { func (s *NotificationService) DeleteDevice(ctx context.Context, deviceID uint, platform string, userID uint) error {
switch platform { switch platform {
case push.PlatformIOS: case push.PlatformIOS:
device, err := s.notificationRepo.FindAPNSDeviceByID(deviceID) device, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByID(deviceID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.device_not_found") return apperrors.NotFound("error.device_not_found")
@@ -401,11 +401,11 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI
if device.UserID == nil || *device.UserID != userID { if device.UserID == nil || *device.UserID != userID {
return apperrors.Forbidden("error.device_access_denied") return apperrors.Forbidden("error.device_access_denied")
} }
if err := s.notificationRepo.DeactivateAPNSDevice(deviceID); err != nil { if err := s.notificationRepo.WithContext(ctx).DeactivateAPNSDevice(deviceID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
case push.PlatformAndroid: case push.PlatformAndroid:
device, err := s.notificationRepo.FindGCMDeviceByID(deviceID) device, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByID(deviceID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.device_not_found") return apperrors.NotFound("error.device_not_found")
@@ -416,7 +416,7 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI
if device.UserID == nil || *device.UserID != userID { if device.UserID == nil || *device.UserID != userID {
return apperrors.Forbidden("error.device_access_denied") return apperrors.Forbidden("error.device_access_denied")
} }
if err := s.notificationRepo.DeactivateGCMDevice(deviceID); err != nil { if err := s.notificationRepo.WithContext(ctx).DeactivateGCMDevice(deviceID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
default: default:
@@ -426,10 +426,10 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI
} }
// UnregisterDevice deactivates a device by its registration token // UnregisterDevice deactivates a device by its registration token
func (s *NotificationService) UnregisterDevice(registrationID, platform string, userID uint) error { func (s *NotificationService) UnregisterDevice(ctx context.Context, registrationID, platform string, userID uint) error {
switch platform { switch platform {
case push.PlatformIOS: case push.PlatformIOS:
device, err := s.notificationRepo.FindAPNSDeviceByToken(registrationID) device, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(registrationID)
if err != nil { if err != nil {
return apperrors.NotFound("error.device_not_found") return apperrors.NotFound("error.device_not_found")
} }
@@ -437,11 +437,11 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string,
if device.UserID == nil || *device.UserID != userID { if device.UserID == nil || *device.UserID != userID {
return apperrors.NotFound("error.device_not_found") return apperrors.NotFound("error.device_not_found")
} }
if err := s.notificationRepo.DeactivateAPNSDevice(device.ID); err != nil { if err := s.notificationRepo.WithContext(ctx).DeactivateAPNSDevice(device.ID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
case push.PlatformAndroid: case push.PlatformAndroid:
device, err := s.notificationRepo.FindGCMDeviceByToken(registrationID) device, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(registrationID)
if err != nil { if err != nil {
return apperrors.NotFound("error.device_not_found") return apperrors.NotFound("error.device_not_found")
} }
@@ -449,7 +449,7 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string,
if device.UserID == nil || *device.UserID != userID { if device.UserID == nil || *device.UserID != userID {
return apperrors.NotFound("error.device_not_found") return apperrors.NotFound("error.device_not_found")
} }
if err := s.notificationRepo.DeactivateGCMDevice(device.ID); err != nil { if err := s.notificationRepo.WithContext(ctx).DeactivateGCMDevice(device.ID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
default: default:
@@ -624,7 +624,7 @@ func (s *NotificationService) CreateAndSendTaskNotification(
task *models.Task, task *models.Task,
) error { ) error {
// Check user notification preferences // Check user notification preferences
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -662,12 +662,12 @@ func (s *NotificationService) CreateAndSendTaskNotification(
TaskID: &task.ID, TaskID: &task.ID,
} }
if err := s.notificationRepo.Create(notification); err != nil { if err := s.notificationRepo.WithContext(ctx).Create(notification); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
// Get device tokens // Get device tokens
iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID) iosTokens, androidTokens, err := s.notificationRepo.WithContext(ctx).GetActiveTokensForUser(userID)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -691,12 +691,12 @@ func (s *NotificationService) CreateAndSendTaskNotification(
if s.pushClient != nil { if s.pushClient != nil {
err = s.pushClient.SendActionableNotification(ctx, iosTokens, androidTokens, title, body, pushData, iosCategoryID) err = s.pushClient.SendActionableNotification(ctx, iosTokens, androidTokens, title, body, pushData, iosCategoryID)
if err != nil { if err != nil {
s.notificationRepo.SetError(notification.ID, err.Error()) s.notificationRepo.WithContext(ctx).SetError(notification.ID, err.Error())
return apperrors.Internal(err) return apperrors.Internal(err)
} }
} }
if err := s.notificationRepo.MarkAsSent(notification.ID); err != nil { if err := s.notificationRepo.WithContext(ctx).MarkAsSent(notification.ID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
return nil return nil
+58 -58
View File
@@ -44,7 +44,7 @@ func TestNotificationService_GetNotifications(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
resp, err := service.GetNotifications(user.ID, 10, 0) resp, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, resp, 3) assert.Len(t, resp, 3)
} }
@@ -56,7 +56,7 @@ func TestNotificationService_GetNotifications_Empty(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
resp, err := service.GetNotifications(user.ID, 10, 0) resp, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, resp) assert.Empty(t, resp)
} }
@@ -80,7 +80,7 @@ func TestNotificationService_GetNotifications_Pagination(t *testing.T) {
} }
// Get first 2 // Get first 2
resp, err := service.GetNotifications(user.ID, 2, 0) resp, err := service.GetNotifications(context.Background(), user.ID, 2, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, resp, 2) assert.Len(t, resp, 2)
} }
@@ -107,7 +107,7 @@ func TestNotificationService_GetUnreadCount(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
count, err := service.GetUnreadCount(user.ID) count, err := service.GetUnreadCount(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, int64(3), count) assert.Equal(t, int64(3), count)
} }
@@ -119,7 +119,7 @@ func TestNotificationService_GetUnreadCount_Zero(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
count, err := service.GetUnreadCount(user.ID) count, err := service.GetUnreadCount(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)
} }
@@ -142,11 +142,11 @@ func TestNotificationService_MarkAsRead(t *testing.T) {
err := db.Create(notif).Error err := db.Create(notif).Error
require.NoError(t, err) require.NoError(t, err)
err = service.MarkAsRead(notif.ID, user.ID) err = service.MarkAsRead(context.Background(), notif.ID, user.ID)
require.NoError(t, err) require.NoError(t, err)
// Verify unread count is 0 // Verify unread count is 0
count, err := service.GetUnreadCount(user.ID) count, err := service.GetUnreadCount(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)
} }
@@ -168,7 +168,7 @@ func TestNotificationService_MarkAsRead_WrongUser(t *testing.T) {
err := db.Create(notif).Error err := db.Create(notif).Error
require.NoError(t, err) require.NoError(t, err)
err = service.MarkAsRead(notif.ID, other.ID) err = service.MarkAsRead(context.Background(), notif.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found")
} }
@@ -179,7 +179,7 @@ func TestNotificationService_MarkAsRead_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.MarkAsRead(9999, user.ID) err := service.MarkAsRead(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found")
} }
@@ -204,10 +204,10 @@ func TestNotificationService_MarkAllAsRead(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
err := service.MarkAllAsRead(user.ID) err := service.MarkAllAsRead(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
count, err := service.GetUnreadCount(user.ID) count, err := service.GetUnreadCount(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)
} }
@@ -229,7 +229,7 @@ func TestNotificationService_CreateAndSendNotification(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify notification was created // Verify notification was created
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, notifs, 1) assert.Len(t, notifs, 1)
assert.Equal(t, "Due Soon", notifs[0].Title) assert.Equal(t, "Due Soon", notifs[0].Title)
@@ -254,7 +254,7 @@ func TestNotificationService_CreateAndSendNotification_DisabledPreference(t *tes
require.NoError(t, err) require.NoError(t, err)
// Verify no notification was created (silently skipped) // Verify no notification was created (silently skipped)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, notifs) assert.Empty(t, notifs)
} }
@@ -268,7 +268,7 @@ func TestNotificationService_GetPreferences(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
resp, err := service.GetPreferences(user.ID) resp, err := service.GetPreferences(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, resp) assert.NotNil(t, resp)
// Defaults should all be true // Defaults should all be true
@@ -289,7 +289,7 @@ func TestNotificationService_UpdatePreferences(t *testing.T) {
TaskDueSoon: &falseVal, TaskDueSoon: &falseVal,
} }
resp, err := service.UpdatePreferences(user.ID, req) resp, err := service.UpdatePreferences(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, resp.TaskDueSoon) assert.False(t, resp.TaskDueSoon)
assert.True(t, resp.TaskOverdue) // unchanged assert.True(t, resp.TaskOverdue) // unchanged
@@ -307,7 +307,7 @@ func TestNotificationService_UpdatePreferences_InvalidHour(t *testing.T) {
TaskDueSoonHour: &invalidHour, TaskDueSoonHour: &invalidHour,
} }
_, err := service.UpdatePreferences(user.ID, req) _, err := service.UpdatePreferences(context.Background(), user.ID, req)
testutil.AssertAppErrorCode(t, err, http.StatusBadRequest) testutil.AssertAppErrorCode(t, err, http.StatusBadRequest)
} }
@@ -323,7 +323,7 @@ func TestNotificationService_UpdatePreferences_ValidHour(t *testing.T) {
TaskDueSoonHour: &hour, TaskDueSoonHour: &hour,
} }
resp, err := service.UpdatePreferences(user.ID, req) resp, err := service.UpdatePreferences(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 9, *resp.TaskDueSoonHour) assert.Equal(t, 9, *resp.TaskDueSoonHour)
} }
@@ -344,7 +344,7 @@ func TestNotificationService_RegisterDevice_iOS(t *testing.T) {
Platform: push.PlatformIOS, Platform: push.PlatformIOS,
} }
resp, err := service.RegisterDevice(user.ID, req) resp, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "iPhone 15", resp.Name) assert.Equal(t, "iPhone 15", resp.Name)
assert.Equal(t, push.PlatformIOS, resp.Platform) assert.Equal(t, push.PlatformIOS, resp.Platform)
@@ -365,7 +365,7 @@ func TestNotificationService_RegisterDevice_Android(t *testing.T) {
Platform: push.PlatformAndroid, Platform: push.PlatformAndroid,
} }
resp, err := service.RegisterDevice(user.ID, req) resp, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Pixel 8", resp.Name) assert.Equal(t, "Pixel 8", resp.Name)
assert.Equal(t, push.PlatformAndroid, resp.Platform) assert.Equal(t, push.PlatformAndroid, resp.Platform)
@@ -386,7 +386,7 @@ func TestNotificationService_RegisterDevice_InvalidPlatform(t *testing.T) {
Platform: "windows", Platform: "windows",
} }
_, err := service.RegisterDevice(user.ID, req) _, err := service.RegisterDevice(context.Background(), user.ID, req)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform") testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
} }
@@ -404,12 +404,12 @@ func TestNotificationService_RegisterDevice_UpdateExisting(t *testing.T) {
RegistrationID: "token-xyz", RegistrationID: "token-xyz",
Platform: push.PlatformIOS, Platform: push.PlatformIOS,
} }
_, err := service.RegisterDevice(user.ID, req) _, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
// Re-register with same token (should update, not duplicate) // Re-register with same token (should update, not duplicate)
req.Name = "iPhone 15 Pro" req.Name = "iPhone 15 Pro"
resp, err := service.RegisterDevice(user.ID, req) resp, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "iPhone 15 Pro", resp.Name) assert.Equal(t, "iPhone 15 Pro", resp.Name)
} }
@@ -445,7 +445,7 @@ func TestNotificationService_ListDevices(t *testing.T) {
err = db.Create(androidDevice).Error err = db.Create(androidDevice).Error
require.NoError(t, err) require.NoError(t, err)
resp, err := service.ListDevices(user.ID) resp, err := service.ListDevices(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, resp, 2) assert.Len(t, resp, 2)
} }
@@ -457,7 +457,7 @@ func TestNotificationService_ListDevices_Empty(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
resp, err := service.ListDevices(user.ID) resp, err := service.ListDevices(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, resp) assert.Empty(t, resp)
} }
@@ -471,7 +471,7 @@ func TestDeleteDevice_InvalidPlatform_Returns400(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.DeleteDevice(1, "windows", user.ID) err := service.DeleteDevice(context.Background(), 1, "windows", user.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform") testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
} }
@@ -494,7 +494,7 @@ func TestNotificationService_UnregisterDevice_iOS(t *testing.T) {
err := db.Create(device).Error err := db.Create(device).Error
require.NoError(t, err) require.NoError(t, err)
err = service.UnregisterDevice("reg-token-ios", push.PlatformIOS, user.ID) err = service.UnregisterDevice(context.Background(), "reg-token-ios", push.PlatformIOS, user.ID)
require.NoError(t, err) require.NoError(t, err)
// Verify device is deactivated // Verify device is deactivated
@@ -522,7 +522,7 @@ func TestNotificationService_UnregisterDevice_Android(t *testing.T) {
err := db.Create(device).Error err := db.Create(device).Error
require.NoError(t, err) require.NoError(t, err)
err = service.UnregisterDevice("reg-token-android", push.PlatformAndroid, user.ID) err = service.UnregisterDevice(context.Background(), "reg-token-android", push.PlatformAndroid, user.ID)
require.NoError(t, err) require.NoError(t, err)
var found models.GCMDevice var found models.GCMDevice
@@ -538,7 +538,7 @@ func TestNotificationService_UnregisterDevice_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.UnregisterDevice("nonexistent-token", push.PlatformIOS, user.ID) err := service.UnregisterDevice(context.Background(), "nonexistent-token", push.PlatformIOS, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
} }
@@ -560,7 +560,7 @@ func TestNotificationService_UnregisterDevice_WrongUser(t *testing.T) {
err := db.Create(device).Error err := db.Create(device).Error
require.NoError(t, err) require.NoError(t, err)
err = service.UnregisterDevice("owner-token", push.PlatformIOS, attacker.ID) err = service.UnregisterDevice(context.Background(), "owner-token", push.PlatformIOS, attacker.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
} }
@@ -571,7 +571,7 @@ func TestNotificationService_UnregisterDevice_InvalidPlatform(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.UnregisterDevice("some-token", "windows", user.ID) err := service.UnregisterDevice(context.Background(), "some-token", "windows", user.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform") testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
} }
@@ -585,7 +585,7 @@ func TestNotificationService_UpdateUserTimezone(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Should not panic, just silently update // Should not panic, just silently update
service.UpdateUserTimezone(user.ID, "America/Los_Angeles") service.UpdateUserTimezone(context.Background(), user.ID, "America/Los_Angeles")
// Verify timezone was stored // Verify timezone was stored
prefs, err := notifRepo.GetOrCreatePreferences(user.ID) prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
@@ -602,7 +602,7 @@ func TestNotificationService_UpdateUserTimezone_Invalid(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Invalid timezone should be silently ignored // Invalid timezone should be silently ignored
service.UpdateUserTimezone(user.ID, "Invalid/Timezone") service.UpdateUserTimezone(context.Background(), user.ID, "Invalid/Timezone")
prefs, err := notifRepo.GetOrCreatePreferences(user.ID) prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
require.NoError(t, err) require.NoError(t, err)
@@ -617,10 +617,10 @@ func TestNotificationService_UpdateUserTimezone_NoChangeSkipsWrite(t *testing.T)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Set timezone // Set timezone
service.UpdateUserTimezone(user.ID, "America/New_York") service.UpdateUserTimezone(context.Background(), user.ID, "America/New_York")
// Set same timezone again — should be a no-op // Set same timezone again — should be a no-op
service.UpdateUserTimezone(user.ID, "America/New_York") service.UpdateUserTimezone(context.Background(), user.ID, "America/New_York")
prefs, err := notifRepo.GetOrCreatePreferences(user.ID) prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
require.NoError(t, err) require.NoError(t, err)
@@ -648,7 +648,7 @@ func TestDeleteDevice_WrongUser_Returns403(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Attacker tries to deactivate the owner's device // Attacker tries to deactivate the owner's device
err = service.DeleteDevice(device.ID, push.PlatformIOS, attacker.ID) err = service.DeleteDevice(context.Background(), device.ID, push.PlatformIOS, attacker.ID)
require.Error(t, err, "should not allow deleting another user's device") require.Error(t, err, "should not allow deleting another user's device")
testutil.AssertAppErrorCode(t, err, http.StatusForbidden) testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
@@ -678,7 +678,7 @@ func TestDeleteDevice_CorrectUser_Succeeds(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Owner deactivates their own device // Owner deactivates their own device
err = service.DeleteDevice(device.ID, push.PlatformIOS, owner.ID) err = service.DeleteDevice(context.Background(), device.ID, push.PlatformIOS, owner.ID)
require.NoError(t, err, "owner should be able to deactivate their own device") require.NoError(t, err, "owner should be able to deactivate their own device")
// Verify the device is now inactive // Verify the device is now inactive
@@ -709,7 +709,7 @@ func TestDeleteDevice_WrongUser_Android_Returns403(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Attacker tries to deactivate the owner's Android device // Attacker tries to deactivate the owner's Android device
err = service.DeleteDevice(device.ID, push.PlatformAndroid, attacker.ID) err = service.DeleteDevice(context.Background(), device.ID, push.PlatformAndroid, attacker.ID)
require.Error(t, err, "should not allow deleting another user's Android device") require.Error(t, err, "should not allow deleting another user's Android device")
testutil.AssertAppErrorCode(t, err, http.StatusForbidden) testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
@@ -727,7 +727,7 @@ func TestDeleteDevice_NonExistent_Returns404(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
err := service.DeleteDevice(99999, push.PlatformIOS, user.ID) err := service.DeleteDevice(context.Background(), 99999, push.PlatformIOS, user.ID)
require.Error(t, err, "should return error for non-existent device") require.Error(t, err, "should return error for non-existent device")
testutil.AssertAppErrorCode(t, err, http.StatusNotFound) testutil.AssertAppErrorCode(t, err, http.StatusNotFound)
} }
@@ -744,7 +744,7 @@ func TestNotificationService_CreateAndSend_TaskOverdue(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Task is overdue", nil) err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Task is overdue", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, notifs, 1) assert.Len(t, notifs, 1)
assert.Equal(t, "Overdue", notifs[0].Title) assert.Equal(t, "Overdue", notifs[0].Title)
@@ -760,7 +760,7 @@ func TestNotificationService_CreateAndSend_TaskCompleted(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil) err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, notifs, 1) assert.Len(t, notifs, 1)
} }
@@ -775,7 +775,7 @@ func TestNotificationService_CreateAndSend_TaskAssigned(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned to you", nil) err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned to you", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, notifs, 1) assert.Len(t, notifs, 1)
} }
@@ -790,7 +790,7 @@ func TestNotificationService_CreateAndSend_ResidenceShared(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Someone shared a home", nil) err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Someone shared a home", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, notifs, 1) assert.Len(t, notifs, 1)
} }
@@ -805,7 +805,7 @@ func TestNotificationService_CreateAndSend_WarrantyExpiring(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Expiring", "Warranty expiring soon", nil) err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Expiring", "Warranty expiring soon", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, notifs, 1) assert.Len(t, notifs, 1)
} }
@@ -828,7 +828,7 @@ func TestNotificationService_DisabledPrefs_TaskOverdue(t *testing.T) {
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Overdue task", nil) err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Overdue task", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, notifs) assert.Empty(t, notifs)
} }
@@ -849,7 +849,7 @@ func TestNotificationService_DisabledPrefs_TaskCompleted(t *testing.T) {
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil) err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, notifs) assert.Empty(t, notifs)
} }
@@ -870,7 +870,7 @@ func TestNotificationService_DisabledPrefs_TaskAssigned(t *testing.T) {
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned", nil) err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, notifs) assert.Empty(t, notifs)
} }
@@ -891,7 +891,7 @@ func TestNotificationService_DisabledPrefs_ResidenceShared(t *testing.T) {
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Home shared", nil) err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Home shared", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, notifs) assert.Empty(t, notifs)
} }
@@ -912,7 +912,7 @@ func TestNotificationService_DisabledPrefs_WarrantyExpiring(t *testing.T) {
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Warranty", "Expiring", nil) err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Warranty", "Expiring", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, notifs) assert.Empty(t, notifs)
} }
@@ -929,7 +929,7 @@ func TestNotificationService_CreateAndSend_UnknownTypeDefaultsEnabled(t *testing
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationType("unknown_type"), "Unknown", "Unknown notification", nil) err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationType("unknown_type"), "Unknown", "Unknown notification", nil)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, notifs, 1) assert.Len(t, notifs, 1)
} }
@@ -953,7 +953,7 @@ func TestNotificationService_CreateAndSend_WithMixedDataTypes(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskDueSoon, "Due Soon", "Fix faucet", data) err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskDueSoon, "Due Soon", "Fix faucet", data)
require.NoError(t, err) require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0) notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, notifs, 1) assert.Len(t, notifs, 1)
assert.NotNil(t, notifs[0].Data) assert.NotNil(t, notifs[0].Data)
@@ -985,7 +985,7 @@ func TestNotificationService_UpdatePreferences_MultipleFields(t *testing.T) {
TaskOverdueHour: &hour14, TaskOverdueHour: &hour14,
} }
resp, err := service.UpdatePreferences(user.ID, req) resp, err := service.UpdatePreferences(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, resp.TaskDueSoon) assert.False(t, resp.TaskDueSoon)
assert.False(t, resp.TaskOverdue) assert.False(t, resp.TaskOverdue)
@@ -1013,7 +1013,7 @@ func TestNotificationService_UpdatePreferences_NegativeHour(t *testing.T) {
TaskOverdueHour: &negHour, TaskOverdueHour: &negHour,
} }
_, err := service.UpdatePreferences(user.ID, req) _, err := service.UpdatePreferences(context.Background(), user.ID, req)
testutil.AssertAppErrorCode(t, err, http.StatusBadRequest) testutil.AssertAppErrorCode(t, err, http.StatusBadRequest)
} }
@@ -1032,12 +1032,12 @@ func TestNotificationService_RegisterDevice_UpdateExistingAndroid(t *testing.T)
RegistrationID: "token-android-1", RegistrationID: "token-android-1",
Platform: push.PlatformAndroid, Platform: push.PlatformAndroid,
} }
_, err := service.RegisterDevice(user.ID, req) _, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
// Re-register with same token but new name // Re-register with same token but new name
req.Name = "Pixel 8 Pro" req.Name = "Pixel 8 Pro"
resp, err := service.RegisterDevice(user.ID, req) resp, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Pixel 8 Pro", resp.Name) assert.Equal(t, "Pixel 8 Pro", resp.Name)
assert.Equal(t, push.PlatformAndroid, resp.Platform) assert.Equal(t, push.PlatformAndroid, resp.Platform)
@@ -1052,7 +1052,7 @@ func TestDeleteDevice_AndroidNotFound_Returns404(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.DeleteDevice(99999, push.PlatformAndroid, user.ID) err := service.DeleteDevice(context.Background(), 99999, push.PlatformAndroid, user.ID)
testutil.AssertAppErrorCode(t, err, http.StatusNotFound) testutil.AssertAppErrorCode(t, err, http.StatusNotFound)
} }
@@ -1076,7 +1076,7 @@ func TestDeleteDevice_CorrectUser_Android_Succeeds(t *testing.T) {
err := db.Create(device).Error err := db.Create(device).Error
require.NoError(t, err) require.NoError(t, err)
err = service.DeleteDevice(device.ID, push.PlatformAndroid, owner.ID) err = service.DeleteDevice(context.Background(), device.ID, push.PlatformAndroid, owner.ID)
require.NoError(t, err) require.NoError(t, err)
var found models.GCMDevice var found models.GCMDevice
@@ -1106,7 +1106,7 @@ func TestNotificationService_UnregisterDevice_WrongUser_Android(t *testing.T) {
err := db.Create(device).Error err := db.Create(device).Error
require.NoError(t, err) require.NoError(t, err)
err = service.UnregisterDevice("owner-android-token", push.PlatformAndroid, attacker.ID) err = service.UnregisterDevice(context.Background(), "owner-android-token", push.PlatformAndroid, attacker.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
} }
@@ -1119,7 +1119,7 @@ func TestNotificationService_UnregisterDevice_AndroidNotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.UnregisterDevice("nonexistent-android", push.PlatformAndroid, user.ID) err := service.UnregisterDevice(context.Background(), "nonexistent-android", push.PlatformAndroid, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found") testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
} }
+1 -1
View File
@@ -186,7 +186,7 @@ func (s *ResidenceService) getSummaryForUser(_ uint) responses.TotalSummary {
func (s *ResidenceService) CreateResidence(ctx context.Context, req *requests.CreateResidenceRequest, ownerID uint) (*responses.ResidenceWithSummaryResponse, error) { func (s *ResidenceService) CreateResidence(ctx context.Context, req *requests.CreateResidenceRequest, ownerID uint) (*responses.ResidenceWithSummaryResponse, error) {
// Check subscription tier limits (if subscription service is wired up) // Check subscription tier limits (if subscription service is wired up)
if s.subscriptionService != nil { if s.subscriptionService != nil {
if err := s.subscriptionService.CheckLimit(ownerID, "properties"); err != nil { if err := s.subscriptionService.CheckLimit(ctx, ownerID, "properties"); err != nil {
return nil, err return nil, err
} }
} }
+42 -42
View File
@@ -98,8 +98,8 @@ func NewSubscriptionService(
} }
// GetSubscription gets the subscription for a user // GetSubscription gets the subscription for a user
func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionResponse, error) { func (s *SubscriptionService) GetSubscription(ctx context.Context, userID uint) (*SubscriptionResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID) sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -107,13 +107,13 @@ func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionRespons
} }
// GetSubscriptionStatus gets detailed subscription status including limits // GetSubscriptionStatus gets detailed subscription status including limits
func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionStatusResponse, error) { func (s *SubscriptionService) GetSubscriptionStatus(ctx context.Context, userID uint) (*SubscriptionStatusResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID) sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
settings, err := s.subscriptionRepo.GetSettings() settings, err := s.subscriptionRepo.WithContext(ctx).GetSettings()
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -122,18 +122,18 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled { if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled {
now := time.Now().UTC() now := time.Now().UTC()
trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour) trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour)
if err := s.subscriptionRepo.SetTrialDates(userID, now, trialEnd); err != nil { if err := s.subscriptionRepo.WithContext(ctx).SetTrialDates(userID, now, trialEnd); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Re-fetch after starting trial so response reflects the new state // Re-fetch after starting trial so response reflects the new state
sub, err = s.subscriptionRepo.FindByUserID(userID) sub, err = s.subscriptionRepo.WithContext(ctx).FindByUserID(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
} }
// Get all tier limits and build a map // Get all tier limits and build a map
allLimits, err := s.subscriptionRepo.GetAllTierLimits() allLimits, err := s.subscriptionRepo.WithContext(ctx).GetAllTierLimits()
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -154,7 +154,7 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
} }
// Get current usage // Get current usage
usage, err := s.getUserUsage(userID) usage, err := s.getUserUsage(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -204,31 +204,31 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
// getUserUsage calculates current usage for a user. // getUserUsage calculates current usage for a user.
// P-10: Uses CountByOwner for properties count instead of loading all owned residences. // P-10: Uses CountByOwner for properties count instead of loading all owned residences.
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)). // Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) { func (s *SubscriptionService) getUserUsage(ctx context.Context, userID uint) (*UsageResponse, error) {
// P-10: Use CountByOwner for an efficient COUNT query instead of loading all records // P-10: Use CountByOwner for an efficient COUNT query instead of loading all records
propertiesCount, err := s.residenceRepo.CountByOwner(userID) propertiesCount, err := s.residenceRepo.WithContext(ctx).CountByOwner(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Still need residence IDs for batch counting tasks/contractors/documents // Still need residence IDs for batch counting tasks/contractors/documents
residenceIDs, err := s.residenceRepo.FindResidenceIDsByOwner(userID) residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByOwner(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Count tasks, contractors, and documents across all residences with single queries each // Count tasks, contractors, and documents across all residences with single queries each
tasksCount, err := s.taskRepo.CountByResidenceIDs(residenceIDs) tasksCount, err := s.taskRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
contractorsCount, err := s.contractorRepo.CountByResidenceIDs(residenceIDs) contractorsCount, err := s.contractorRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
documentsCount, err := s.documentRepo.CountByResidenceIDs(residenceIDs) documentsCount, err := s.documentRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -242,8 +242,8 @@ func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error)
} }
// CheckLimit checks if a user has exceeded a specific limit // CheckLimit checks if a user has exceeded a specific limit
func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error { func (s *SubscriptionService) CheckLimit(ctx context.Context, userID uint, limitType string) error {
settings, err := s.subscriptionRepo.GetSettings() settings, err := s.subscriptionRepo.WithContext(ctx).GetSettings()
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -253,7 +253,7 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
return nil return nil
} }
sub, err := s.subscriptionRepo.GetOrCreate(userID) sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
@@ -268,12 +268,12 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
return nil return nil
} }
limits, err := s.subscriptionRepo.GetTierLimits(sub.Tier) limits, err := s.subscriptionRepo.WithContext(ctx).GetTierLimits(sub.Tier)
if err != nil { if err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
usage, err := s.getUserUsage(userID) usage, err := s.getUserUsage(ctx, userID)
if err != nil { if err != nil {
return err return err
} }
@@ -301,8 +301,8 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
} }
// GetUpgradeTrigger gets an upgrade trigger by key // GetUpgradeTrigger gets an upgrade trigger by key
func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResponse, error) { func (s *SubscriptionService) GetUpgradeTrigger(ctx context.Context, key string) (*UpgradeTriggerResponse, error) {
trigger, err := s.subscriptionRepo.GetUpgradeTrigger(key) trigger, err := s.subscriptionRepo.WithContext(ctx).GetUpgradeTrigger(key)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.upgrade_trigger_not_found") return nil, apperrors.NotFound("error.upgrade_trigger_not_found")
@@ -314,8 +314,8 @@ func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResp
// GetAllUpgradeTriggers gets all active upgrade triggers as a map keyed by trigger_key // GetAllUpgradeTriggers gets all active upgrade triggers as a map keyed by trigger_key
// KMM client expects Map<String, UpgradeTriggerData> // KMM client expects Map<String, UpgradeTriggerData>
func (s *SubscriptionService) GetAllUpgradeTriggers() (map[string]*UpgradeTriggerDataResponse, error) { func (s *SubscriptionService) GetAllUpgradeTriggers(ctx context.Context) (map[string]*UpgradeTriggerDataResponse, error) {
triggers, err := s.subscriptionRepo.GetAllUpgradeTriggers() triggers, err := s.subscriptionRepo.WithContext(ctx).GetAllUpgradeTriggers()
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -328,8 +328,8 @@ func (s *SubscriptionService) GetAllUpgradeTriggers() (map[string]*UpgradeTrigge
} }
// GetFeatureBenefits gets all feature benefits // GetFeatureBenefits gets all feature benefits
func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, error) { func (s *SubscriptionService) GetFeatureBenefits(ctx context.Context) ([]FeatureBenefitResponse, error) {
benefits, err := s.subscriptionRepo.GetFeatureBenefits() benefits, err := s.subscriptionRepo.WithContext(ctx).GetFeatureBenefits()
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -342,13 +342,13 @@ func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, er
} }
// GetActivePromotions gets active promotions for a user // GetActivePromotions gets active promotions for a user
func (s *SubscriptionService) GetActivePromotions(userID uint) ([]PromotionResponse, error) { func (s *SubscriptionService) GetActivePromotions(ctx context.Context, userID uint) ([]PromotionResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID) sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
promotions, err := s.subscriptionRepo.GetActivePromotions(sub.Tier) promotions, err := s.subscriptionRepo.WithContext(ctx).GetActivePromotions(sub.Tier)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -362,13 +362,13 @@ func (s *SubscriptionService) GetActivePromotions(userID uint) ([]PromotionRespo
// ProcessApplePurchase processes an Apple IAP purchase // ProcessApplePurchase processes an Apple IAP purchase
// Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID) // Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID)
func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) { func (s *SubscriptionService) ProcessApplePurchase(ctx context.Context, userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) {
// Store receipt/transaction data // Store receipt/transaction data
dataToStore := receiptData dataToStore := receiptData
if dataToStore == "" { if dataToStore == "" {
dataToStore = transactionID dataToStore = transactionID
} }
if err := s.subscriptionRepo.UpdateReceiptData(userID, dataToStore); err != nil { if err := s.subscriptionRepo.WithContext(ctx).UpdateReceiptData(userID, dataToStore); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -406,18 +406,18 @@ func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData stri
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated") log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated")
// Upgrade to Pro with the validated expiration // Upgrade to Pro with the validated expiration
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil { if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "ios"); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
return s.GetSubscription(userID) return s.GetSubscription(ctx, userID)
} }
// ProcessGooglePurchase processes a Google Play purchase // ProcessGooglePurchase processes a Google Play purchase
// productID is optional but helps validate the specific subscription // productID is optional but helps validate the specific subscription
func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) { func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) {
// Store purchase token first // Store purchase token first
if err := s.subscriptionRepo.UpdatePurchaseToken(userID, purchaseToken); err != nil { if err := s.subscriptionRepo.WithContext(ctx).UpdatePurchaseToken(userID, purchaseToken); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
@@ -463,25 +463,25 @@ func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken s
} }
// Upgrade to Pro with the validated expiration // Upgrade to Pro with the validated expiration
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "android"); err != nil { if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "android"); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
return s.GetSubscription(userID) return s.GetSubscription(ctx, userID)
} }
// CancelSubscription cancels a subscription (downgrades to free at end of period) // CancelSubscription cancels a subscription (downgrades to free at end of period)
func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResponse, error) { func (s *SubscriptionService) CancelSubscription(ctx context.Context, userID uint) (*SubscriptionResponse, error) {
if err := s.subscriptionRepo.SetAutoRenew(userID, false); err != nil { if err := s.subscriptionRepo.WithContext(ctx).SetAutoRenew(userID, false); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
return s.GetSubscription(userID) return s.GetSubscription(ctx, userID)
} }
// IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription // IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription
// from a different platform than the one being requested. Returns (conflict, existingPlatform, error). // from a different platform than the one being requested. Returns (conflict, existingPlatform, error).
func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(userID uint, requestedPlatform string) (bool, string, error) { func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(ctx context.Context, userID uint, requestedPlatform string) (bool, string, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID) sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil { if err != nil {
return false, "", apperrors.Internal(err) return false, "", apperrors.Internal(err)
} }
@@ -1,6 +1,7 @@
package services package services
import ( import (
"context"
"testing" "testing"
"time" "time"
@@ -70,7 +71,7 @@ func TestProcessApplePurchase_ClientNil_ReturnsError(t *testing.T) {
googleClient: nil, googleClient: nil,
} }
_, err := svc.ProcessApplePurchase(user.ID, "fake-receipt", "") _, err := svc.ProcessApplePurchase(context.Background(), user.ID, "fake-receipt", "")
assert.Error(t, err, "ProcessApplePurchase should return error when Apple IAP client is nil") assert.Error(t, err, "ProcessApplePurchase should return error when Apple IAP client is nil")
// Verify user was NOT upgraded to Pro // Verify user was NOT upgraded to Pro
@@ -109,7 +110,7 @@ func TestProcessApplePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
} }
// Neither receipt data nor transaction ID - should still not grant Pro // Neither receipt data nor transaction ID - should still not grant Pro
_, err := svc.ProcessApplePurchase(user.ID, "", "") _, err := svc.ProcessApplePurchase(context.Background(), user.ID, "", "")
assert.Error(t, err, "ProcessApplePurchase should return error when client is nil, even with empty data") assert.Error(t, err, "ProcessApplePurchase should return error when client is nil, even with empty data")
// Verify no upgrade happened // Verify no upgrade happened
@@ -140,7 +141,7 @@ func TestProcessGooglePurchase_ClientNil_ReturnsError(t *testing.T) {
googleClient: nil, // Not configured googleClient: nil, // Not configured
} }
_, err := svc.ProcessGooglePurchase(user.ID, "fake-token", "com.tt.honeyDue.pro.monthly") _, err := svc.ProcessGooglePurchase(context.Background(), user.ID, "fake-token", "com.tt.honeyDue.pro.monthly")
assert.Error(t, err, "ProcessGooglePurchase should return error when Google IAP client is nil") assert.Error(t, err, "ProcessGooglePurchase should return error when Google IAP client is nil")
// Verify user was NOT upgraded to Pro // Verify user was NOT upgraded to Pro
@@ -172,7 +173,7 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
} }
// With empty token // With empty token
_, err := svc.ProcessGooglePurchase(user.ID, "", "") _, err := svc.ProcessGooglePurchase(context.Background(), user.ID, "", "")
assert.Error(t, err, "ProcessGooglePurchase should return error when client is nil") assert.Error(t, err, "ProcessGooglePurchase should return error when client is nil")
// Verify no upgrade happened // Verify no upgrade happened
@@ -202,7 +203,7 @@ func TestSubscriptionService_GetSubscription(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
resp, err := svc.GetSubscription(user.ID) resp, err := svc.GetSubscription(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "free", resp.Tier) assert.Equal(t, "free", resp.Tier)
assert.False(t, resp.IsPro) assert.False(t, resp.IsPro)
@@ -238,7 +239,7 @@ func TestSubscriptionService_GetSubscription_ProUser(t *testing.T) {
err := db.Create(sub).Error err := db.Create(sub).Error
require.NoError(t, err) require.NoError(t, err)
resp, err := svc.GetSubscription(user.ID) resp, err := svc.GetSubscription(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "pro", resp.Tier) assert.Equal(t, "pro", resp.Tier)
assert.True(t, resp.IsPro) assert.True(t, resp.IsPro)
@@ -277,7 +278,7 @@ func TestSubscriptionService_CancelSubscription(t *testing.T) {
err := db.Create(sub).Error err := db.Create(sub).Error
require.NoError(t, err) require.NoError(t, err)
resp, err := svc.CancelSubscription(user.ID) resp, err := svc.CancelSubscription(context.Background(), user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, resp.AutoRenew) assert.False(t, resp.AutoRenew)
} }
@@ -365,7 +366,7 @@ func TestIsAlreadyProFromOtherPlatform(t *testing.T) {
err := db.Create(sub).Error err := db.Create(sub).Error
require.NoError(t, err) require.NoError(t, err)
conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(user.ID, tt.requestedPlatform) conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(context.Background(), user.ID, tt.requestedPlatform)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.wantConflict, conflict) assert.Equal(t, tt.wantConflict, conflict)
assert.Equal(t, tt.wantPlatform, existingPlatform) assert.Equal(t, tt.wantPlatform, existingPlatform)
+3 -3
View File
@@ -898,7 +898,7 @@ func (s *TaskService) sendTaskCompletedNotification(ctx context.Context, task *m
// Send email notification (to everyone INCLUDING the person who completed it) // Send email notification (to everyone INCLUDING the person who completed it)
// Check user's email notification preferences first // Check user's email notification preferences first
if s.emailService != nil && user.Email != "" && s.notificationService != nil { if s.emailService != nil && user.Email != "" && s.notificationService != nil {
prefs, prefsErr := s.notificationService.GetPreferences(user.ID) prefs, prefsErr := s.notificationService.GetPreferences(ctx, user.ID)
// LE-06: Log fail-open behavior when preferences cannot be loaded // LE-06: Log fail-open behavior when preferences cannot be loaded
if prefsErr != nil { if prefsErr != nil {
log.Warn(). log.Warn().
@@ -1264,8 +1264,8 @@ func (s *TaskService) GetFrequencies(ctx context.Context) ([]responses.TaskFrequ
// UpdateUserTimezone updates the user's timezone for background job calculations. // UpdateUserTimezone updates the user's timezone for background job calculations.
// This is called from handlers when the X-Timezone header is present. // This is called from handlers when the X-Timezone header is present.
// Delegates to NotificationService since timezone is stored in notification preferences. // Delegates to NotificationService since timezone is stored in notification preferences.
func (s *TaskService) UpdateUserTimezone(userID uint, timezone string) { func (s *TaskService) UpdateUserTimezone(ctx context.Context, userID uint, timezone string) {
if s.notificationService != nil { if s.notificationService != nil {
s.notificationService.UpdateUserTimezone(userID, timezone) s.notificationService.UpdateUserTimezone(ctx, userID, timezone)
} }
} }