diff --git a/internal/handlers/auth_handler.go b/internal/handlers/auth_handler.go index 7ae2bcd..9649274 100644 --- a/internal/handlers/auth_handler.go +++ b/internal/handlers/auth_handler.go @@ -65,7 +65,7 @@ func (h *AuthHandler) Login(c echo.Context) error { return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) } - response, err := h.authService.Login(&req) + response, err := h.authService.Login(c.Request().Context(), &req) if err != nil { log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed") if h.auditService != nil { @@ -94,7 +94,7 @@ func (h *AuthHandler) Register(c echo.Context) error { return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) } - response, confirmationCode, err := h.authService.Register(&req) + response, confirmationCode, err := h.authService.Register(c.Request().Context(), &req) if err != nil { log.Debug().Err(err).Msg("Registration failed") return err @@ -141,7 +141,7 @@ func (h *AuthHandler) Logout(c echo.Context) error { } // Invalidate token in database - if err := h.authService.Logout(token); err != nil { + if err := h.authService.Logout(c.Request().Context(), token); err != nil { log.Warn().Err(err).Msg("Failed to delete token from database") } @@ -162,7 +162,7 @@ func (h *AuthHandler) CurrentUser(c echo.Context) error { return err } - response, err := h.authService.GetCurrentUser(user.ID) + response, err := h.authService.GetCurrentUser(c.Request().Context(), user.ID) if err != nil { log.Error().Err(err).Uint("user_id", user.ID).Msg("Failed to get current user") return err @@ -186,7 +186,7 @@ func (h *AuthHandler) UpdateProfile(c echo.Context) error { return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) } - response, err := h.authService.UpdateProfile(user.ID, &req) + response, err := h.authService.UpdateProfile(c.Request().Context(), user.ID, &req) if err != nil { log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to update profile") return err @@ -210,7 +210,7 @@ func (h *AuthHandler) VerifyEmail(c echo.Context) error { return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) } - err = h.authService.VerifyEmail(user.ID, req.Code) + err = h.authService.VerifyEmail(c.Request().Context(), user.ID, req.Code) if err != nil { log.Debug().Err(err).Uint("user_id", user.ID).Msg("Email verification failed") return err @@ -243,7 +243,7 @@ func (h *AuthHandler) ResendVerification(c echo.Context) error { return err } - code, err := h.authService.ResendVerificationCode(user.ID) + code, err := h.authService.ResendVerificationCode(c.Request().Context(), user.ID) if err != nil { log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to resend verification") return err @@ -276,7 +276,7 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error { return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) } - code, user, err := h.authService.ForgotPassword(req.Email) + code, user, err := h.authService.ForgotPassword(c.Request().Context(), req.Email) if err != nil { var appErr *apperrors.AppError if errors.As(err, &appErr) && appErr.Code == http.StatusTooManyRequests { @@ -324,7 +324,7 @@ func (h *AuthHandler) VerifyResetCode(c echo.Context) error { return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) } - resetToken, err := h.authService.VerifyResetCode(req.Email, req.Code) + resetToken, err := h.authService.VerifyResetCode(c.Request().Context(), req.Email, req.Code) if err != nil { log.Debug().Err(err).Str("email", req.Email).Msg("Verify reset code failed") return err @@ -346,7 +346,7 @@ func (h *AuthHandler) ResetPassword(c echo.Context) error { return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) } - err := h.authService.ResetPassword(req.ResetToken, req.NewPassword) + err := h.authService.ResetPassword(c.Request().Context(), req.ResetToken, req.NewPassword) if err != nil { log.Debug().Err(err).Msg("Password reset failed") return err @@ -469,7 +469,7 @@ func (h *AuthHandler) RefreshToken(c echo.Context) error { return apperrors.Unauthorized("error.not_authenticated") } - response, err := h.authService.RefreshToken(token, user.ID) + response, err := h.authService.RefreshToken(c.Request().Context(), token, user.ID) if err != nil { log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed") return err @@ -497,7 +497,7 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error { return apperrors.BadRequest("error.invalid_request") } - fileURLs, err := h.authService.DeleteAccount(user.ID, req.Password, req.Confirmation) + fileURLs, err := h.authService.DeleteAccount(c.Request().Context(), user.ID, req.Password, req.Confirmation) if err != nil { log.Debug().Err(err).Uint("user_id", user.ID).Msg("Account deletion failed") return err diff --git a/internal/handlers/contractor_handler.go b/internal/handlers/contractor_handler.go index d8def5c..a6ce095 100644 --- a/internal/handlers/contractor_handler.go +++ b/internal/handlers/contractor_handler.go @@ -30,7 +30,7 @@ func (h *ContractorHandler) ListContractors(c echo.Context) error { if err != nil { return err } - response, err := h.contractorService.ListContractors(user.ID) + response, err := h.contractorService.ListContractors(c.Request().Context(), user.ID) if err != nil { return apperrors.Internal(err) } @@ -48,7 +48,7 @@ func (h *ContractorHandler) GetContractor(c echo.Context) error { return apperrors.BadRequest("error.invalid_contractor_id") } - response, err := h.contractorService.GetContractor(uint(contractorID), user.ID) + response, err := h.contractorService.GetContractor(c.Request().Context(), uint(contractorID), user.ID) if err != nil { return err } @@ -69,7 +69,7 @@ func (h *ContractorHandler) CreateContractor(c echo.Context) error { return err } - response, err := h.contractorService.CreateContractor(&req, user.ID) + response, err := h.contractorService.CreateContractor(c.Request().Context(), &req, user.ID) if err != nil { return err } @@ -95,7 +95,7 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error { return err } - response, err := h.contractorService.UpdateContractor(uint(contractorID), user.ID, &req) + response, err := h.contractorService.UpdateContractor(c.Request().Context(), uint(contractorID), user.ID, &req) if err != nil { return err } @@ -113,7 +113,7 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error { return apperrors.BadRequest("error.invalid_contractor_id") } - err = h.contractorService.DeleteContractor(uint(contractorID), user.ID) + err = h.contractorService.DeleteContractor(c.Request().Context(), uint(contractorID), user.ID) if err != nil { return err } @@ -131,7 +131,7 @@ func (h *ContractorHandler) ToggleFavorite(c echo.Context) error { return apperrors.BadRequest("error.invalid_contractor_id") } - response, err := h.contractorService.ToggleFavorite(uint(contractorID), user.ID) + response, err := h.contractorService.ToggleFavorite(c.Request().Context(), uint(contractorID), user.ID) if err != nil { return err } @@ -149,7 +149,7 @@ func (h *ContractorHandler) GetContractorTasks(c echo.Context) error { return apperrors.BadRequest("error.invalid_contractor_id") } - response, err := h.contractorService.GetContractorTasks(uint(contractorID), user.ID) + response, err := h.contractorService.GetContractorTasks(c.Request().Context(), uint(contractorID), user.ID) if err != nil { return err } @@ -167,7 +167,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error { return apperrors.BadRequest("error.invalid_residence_id") } - response, err := h.contractorService.ListContractorsByResidence(uint(residenceID), user.ID) + response, err := h.contractorService.ListContractorsByResidence(c.Request().Context(), uint(residenceID), user.ID) if err != nil { return err } @@ -176,7 +176,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error { // GetSpecialties handles GET /api/contractors/specialties/ func (h *ContractorHandler) GetSpecialties(c echo.Context) error { - specialties, err := h.contractorService.GetSpecialties() + specialties, err := h.contractorService.GetSpecialties(c.Request().Context()) if err != nil { return apperrors.Internal(err) } diff --git a/internal/handlers/document_handler.go b/internal/handlers/document_handler.go index 97de94c..7ef91d3 100644 --- a/internal/handlers/document_handler.go +++ b/internal/handlers/document_handler.go @@ -70,7 +70,7 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error { } } - response, err := h.documentService.ListDocuments(user.ID, filter) + response, err := h.documentService.ListDocuments(c.Request().Context(), user.ID, filter) if err != nil { return err } @@ -88,7 +88,7 @@ func (h *DocumentHandler) GetDocument(c echo.Context) error { return apperrors.BadRequest("error.invalid_document_id") } - response, err := h.documentService.GetDocument(uint(documentID), user.ID) + response, err := h.documentService.GetDocument(c.Request().Context(), uint(documentID), user.ID) if err != nil { return err } @@ -101,7 +101,7 @@ func (h *DocumentHandler) ListWarranties(c echo.Context) error { if err != nil { return err } - response, err := h.documentService.ListWarranties(user.ID) + response, err := h.documentService.ListWarranties(c.Request().Context(), user.ID) if err != nil { return apperrors.Internal(err) } @@ -222,7 +222,7 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error { return err } - response, err := h.documentService.CreateDocument(&req, user.ID) + response, err := h.documentService.CreateDocument(c.Request().Context(), &req, user.ID) if err != nil { return err } @@ -248,7 +248,7 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error { return err } - response, err := h.documentService.UpdateDocument(uint(documentID), user.ID, &req) + response, err := h.documentService.UpdateDocument(c.Request().Context(), uint(documentID), user.ID, &req) if err != nil { return err } @@ -266,7 +266,7 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error { return apperrors.BadRequest("error.invalid_document_id") } - err = h.documentService.DeleteDocument(uint(documentID), user.ID) + err = h.documentService.DeleteDocument(c.Request().Context(), uint(documentID), user.ID) if err != nil { return err } @@ -284,7 +284,7 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error { return apperrors.BadRequest("error.invalid_document_id") } - response, err := h.documentService.ActivateDocument(uint(documentID), user.ID) + response, err := h.documentService.ActivateDocument(c.Request().Context(), uint(documentID), user.ID) if err != nil { return err } @@ -302,7 +302,7 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error { return apperrors.BadRequest("error.invalid_document_id") } - response, err := h.documentService.DeactivateDocument(uint(documentID), user.ID) + response, err := h.documentService.DeactivateDocument(c.Request().Context(), uint(documentID), user.ID) if err != nil { return err } @@ -349,7 +349,7 @@ func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error { caption := c.FormValue("caption") - response, err := h.documentService.UploadDocumentImage(uint(documentID), user.ID, result.URL, caption) + response, err := h.documentService.UploadDocumentImage(c.Request().Context(), uint(documentID), user.ID, result.URL, caption) if err != nil { return err } @@ -372,7 +372,7 @@ func (h *DocumentHandler) DeleteDocumentImage(c echo.Context) error { return apperrors.BadRequest("error.invalid_image_id") } - response, err := h.documentService.DeleteDocumentImage(uint(documentID), uint(imageID), user.ID) + response, err := h.documentService.DeleteDocumentImage(c.Request().Context(), uint(documentID), uint(imageID), user.ID) if err != nil { return err } diff --git a/internal/handlers/notification_handler.go b/internal/handlers/notification_handler.go index 04d5dad..f3f2c82 100644 --- a/internal/handlers/notification_handler.go +++ b/internal/handlers/notification_handler.go @@ -46,7 +46,7 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error { } } - notifications, err := h.notificationService.GetNotifications(user.ID, limit, offset) + notifications, err := h.notificationService.GetNotifications(c.Request().Context(), user.ID, limit, offset) if err != nil { return err } @@ -64,7 +64,7 @@ func (h *NotificationHandler) GetUnreadCount(c echo.Context) error { return err } - count, err := h.notificationService.GetUnreadCount(user.ID) + count, err := h.notificationService.GetUnreadCount(c.Request().Context(), user.ID) if err != nil { return err } @@ -84,7 +84,7 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error { return apperrors.BadRequest("error.invalid_notification_id") } - err = h.notificationService.MarkAsRead(uint(notificationID), user.ID) + err = h.notificationService.MarkAsRead(c.Request().Context(), uint(notificationID), user.ID) if err != nil { return err } @@ -99,7 +99,7 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error { return err } - err = h.notificationService.MarkAllAsRead(user.ID) + err = h.notificationService.MarkAllAsRead(c.Request().Context(), user.ID) if err != nil { return err } @@ -114,7 +114,7 @@ func (h *NotificationHandler) GetPreferences(c echo.Context) error { return err } - prefs, err := h.notificationService.GetPreferences(user.ID) + prefs, err := h.notificationService.GetPreferences(c.Request().Context(), user.ID) if err != nil { return err } @@ -137,7 +137,7 @@ func (h *NotificationHandler) UpdatePreferences(c echo.Context) error { return err } - prefs, err := h.notificationService.UpdatePreferences(user.ID, &req) + prefs, err := h.notificationService.UpdatePreferences(c.Request().Context(), user.ID, &req) if err != nil { return err } @@ -160,7 +160,7 @@ func (h *NotificationHandler) RegisterDevice(c echo.Context) error { return err } - device, err := h.notificationService.RegisterDevice(user.ID, &req) + device, err := h.notificationService.RegisterDevice(c.Request().Context(), user.ID, &req) if err != nil { return err } @@ -175,7 +175,7 @@ func (h *NotificationHandler) ListDevices(c echo.Context) error { return err } - devices, err := h.notificationService.ListDevices(user.ID) + devices, err := h.notificationService.ListDevices(c.Request().Context(), user.ID) if err != nil { return err } @@ -208,7 +208,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error { return apperrors.BadRequest("error.invalid_platform") } - err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID) + err = h.notificationService.UnregisterDevice(c.Request().Context(), req.RegistrationID, req.Platform, user.ID) if err != nil { return err } @@ -236,7 +236,7 @@ func (h *NotificationHandler) DeleteDevice(c echo.Context) error { return apperrors.BadRequest("error.invalid_platform") } - err = h.notificationService.DeleteDevice(uint(deviceID), platform, user.ID) + err = h.notificationService.DeleteDevice(c.Request().Context(), uint(deviceID), platform, user.ID) if err != nil { return err } diff --git a/internal/handlers/static_data_handler.go b/internal/handlers/static_data_handler.go index 46ddbe3..ef902c5 100644 --- a/internal/handlers/static_data_handler.go +++ b/internal/handlers/static_data_handler.go @@ -106,7 +106,7 @@ func (h *StaticDataHandler) GetStaticData(c echo.Context) error { return err } - contractorSpecialties, err := h.contractorService.GetSpecialties() + contractorSpecialties, err := h.contractorService.GetSpecialties(c.Request().Context()) if err != nil { return err } diff --git a/internal/handlers/subscription_handler.go b/internal/handlers/subscription_handler.go index 80699b2..faa5a48 100644 --- a/internal/handlers/subscription_handler.go +++ b/internal/handlers/subscription_handler.go @@ -32,7 +32,7 @@ func (h *SubscriptionHandler) GetSubscription(c echo.Context) error { return err } - subscription, err := h.subscriptionService.GetSubscription(user.ID) + subscription, err := h.subscriptionService.GetSubscription(c.Request().Context(), user.ID) if err != nil { return err } @@ -47,7 +47,7 @@ func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error { return err } - status, err := h.subscriptionService.GetSubscriptionStatus(user.ID) + status, err := h.subscriptionService.GetSubscriptionStatus(c.Request().Context(), user.ID) if err != nil { return err } @@ -59,7 +59,7 @@ func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error { func (h *SubscriptionHandler) GetUpgradeTrigger(c echo.Context) error { key := c.Param("key") - trigger, err := h.subscriptionService.GetUpgradeTrigger(key) + trigger, err := h.subscriptionService.GetUpgradeTrigger(c.Request().Context(), key) if err != nil { return err } @@ -69,7 +69,7 @@ func (h *SubscriptionHandler) GetUpgradeTrigger(c echo.Context) error { // GetAllUpgradeTriggers handles GET /api/subscription/upgrade-triggers/ func (h *SubscriptionHandler) GetAllUpgradeTriggers(c echo.Context) error { - triggers, err := h.subscriptionService.GetAllUpgradeTriggers() + triggers, err := h.subscriptionService.GetAllUpgradeTriggers(c.Request().Context()) if err != nil { return err } @@ -79,7 +79,7 @@ func (h *SubscriptionHandler) GetAllUpgradeTriggers(c echo.Context) error { // GetFeatureBenefits handles GET /api/subscription/features/ func (h *SubscriptionHandler) GetFeatureBenefits(c echo.Context) error { - benefits, err := h.subscriptionService.GetFeatureBenefits() + benefits, err := h.subscriptionService.GetFeatureBenefits(c.Request().Context()) if err != nil { return err } @@ -94,7 +94,7 @@ func (h *SubscriptionHandler) GetPromotions(c echo.Context) error { return err } - promotions, err := h.subscriptionService.GetActivePromotions(user.ID) + promotions, err := h.subscriptionService.GetActivePromotions(c.Request().Context(), user.ID) if err != nil { return err } @@ -125,12 +125,12 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error { if req.TransactionID == "" && req.ReceiptData == "" { return apperrors.BadRequest("error.receipt_data_required") } - subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID) + subscription, err = h.subscriptionService.ProcessApplePurchase(c.Request().Context(), user.ID, req.ReceiptData, req.TransactionID) case "android": if req.PurchaseToken == "" { return apperrors.BadRequest("error.purchase_token_required") } - subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID) + subscription, err = h.subscriptionService.ProcessGooglePurchase(c.Request().Context(), user.ID, req.PurchaseToken, req.ProductID) default: return apperrors.BadRequest("error.invalid_platform") } @@ -152,7 +152,7 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error { return err } - subscription, err := h.subscriptionService.CancelSubscription(user.ID) + subscription, err := h.subscriptionService.CancelSubscription(c.Request().Context(), user.ID) if err != nil { return err } @@ -187,12 +187,12 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error { if req.ReceiptData == "" && req.TransactionID == "" { return apperrors.BadRequest("error.receipt_data_required") } - subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID) + subscription, err = h.subscriptionService.ProcessApplePurchase(c.Request().Context(), user.ID, req.ReceiptData, req.TransactionID) case "android": if req.PurchaseToken == "" { return apperrors.BadRequest("error.purchase_token_required") } - subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID) + subscription, err = h.subscriptionService.ProcessGooglePurchase(c.Request().Context(), user.ID, req.PurchaseToken, req.ProductID) default: return apperrors.BadRequest("error.invalid_platform") } @@ -220,7 +220,7 @@ func (h *SubscriptionHandler) CreateCheckoutSession(c echo.Context) error { } // Check if already Pro from another platform - alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(user.ID, "stripe") + alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(c.Request().Context(), user.ID, "stripe") if err != nil { return err } diff --git a/internal/handlers/task_handler.go b/internal/handlers/task_handler.go index b061404..5d5d9e9 100644 --- a/internal/handlers/task_handler.go +++ b/internal/handlers/task_handler.go @@ -42,7 +42,7 @@ func (h *TaskHandler) ListTasks(c echo.Context) error { if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" { cachedTZ, _ := c.Get("user_timezone").(string) if cachedTZ != tzHeader { - h.taskService.UpdateUserTimezone(user.ID, tzHeader) + h.taskService.UpdateUserTimezone(c.Request().Context(), user.ID, tzHeader) c.Set("user_timezone", tzHeader) } } diff --git a/internal/services/auth_refresh_test.go b/internal/services/auth_refresh_test.go index 7966732..94e7f44 100644 --- a/internal/services/auth_refresh_test.go +++ b/internal/services/auth_refresh_test.go @@ -1,6 +1,7 @@ package services import ( + "context" "testing" "time" @@ -74,7 +75,7 @@ func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) { svc := newTestAuthService(db) - resp, err := svc.RefreshToken(token.Key, user.ID) + resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID) require.NoError(t, err) assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token") assert.Contains(t, resp.Message, "still valid") @@ -87,7 +88,7 @@ func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) { svc := newTestAuthService(db) - resp, err := svc.RefreshToken(token.Key, user.ID) + resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID) require.NoError(t, err) assert.NotEqual(t, token.Key, resp.Token, "should return a new token") assert.Contains(t, resp.Message, "refreshed") @@ -114,7 +115,7 @@ func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) { svc := newTestAuthService(db) - resp, err := svc.RefreshToken(token.Key, user.ID) + resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID) require.Error(t, err) assert.Nil(t, resp) assert.Contains(t, err.Error(), "error.token_expired") @@ -129,7 +130,7 @@ func TestRefreshToken_AtExactBoundary60Days(t *testing.T) { svc := newTestAuthService(db) - resp, err := svc.RefreshToken(token.Key, user.ID) + resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID) require.NoError(t, err) assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed") } @@ -140,7 +141,7 @@ func TestRefreshToken_InvalidToken_Returns401(t *testing.T) { svc := newTestAuthService(db) - resp, err := svc.RefreshToken("nonexistent-token-key", user.ID) + resp, err := svc.RefreshToken(context.Background(), "nonexistent-token-key", user.ID) require.Error(t, err) assert.Nil(t, resp) assert.Contains(t, err.Error(), "error.invalid_token") @@ -154,7 +155,7 @@ func TestRefreshToken_WrongUser_Returns401(t *testing.T) { svc := newTestAuthService(db) // Try to refresh with a different user ID - resp, err := svc.RefreshToken(token.Key, user.ID+999) + resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID+999) require.Error(t, err) assert.Nil(t, resp) assert.Contains(t, err.Error(), "error.invalid_token") @@ -167,7 +168,7 @@ func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) { svc := newTestAuthService(db) - resp, err := svc.RefreshToken(token.Key, user.ID) + resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID) require.NoError(t, err) assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed") } diff --git a/internal/services/auth_service.go b/internal/services/auth_service.go index 75bd21c..2910076 100644 --- a/internal/services/auth_service.go +++ b/internal/services/auth_service.go @@ -57,14 +57,14 @@ func (s *AuthService) SetNotificationRepository(notificationRepo *repositories.N } // Login authenticates a user and returns a token -func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginResponse, error) { +func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest) (*responses.LoginResponse, error) { // Find user by username or email identifier := req.Username if identifier == "" { identifier = req.Email } - user, err := s.userRepo.FindByUsernameOrEmail(identifier) + user, err := s.userRepo.WithContext(ctx).FindByUsernameOrEmail(identifier) if err != nil { if errors.Is(err, repositories.ErrUserNotFound) { return nil, apperrors.Unauthorized("error.invalid_credentials") @@ -83,13 +83,13 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons } // Get or create auth token - token, err := s.userRepo.GetOrCreateToken(user.ID) + token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID) if err != nil { return nil, apperrors.Internal(err) } // Update last login - if err := s.userRepo.UpdateLastLogin(user.ID); err != nil { + if err := s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID); err != nil { // Log error but don't fail the login log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login") } @@ -103,9 +103,9 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons // Register creates a new user account. // F-10: User creation, profile creation, notification preferences, and confirmation code // are wrapped in a transaction for atomicity. -func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) { +func (s *AuthService) Register(ctx context.Context, req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) { // Check if username exists - exists, err := s.userRepo.ExistsByUsername(req.Username) + exists, err := s.userRepo.WithContext(ctx).ExistsByUsername(req.Username) if err != nil { return nil, "", apperrors.Internal(err) } @@ -114,7 +114,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist } // Check if email exists - exists, err = s.userRepo.ExistsByEmail(req.Email) + exists, err = s.userRepo.WithContext(ctx).ExistsByEmail(req.Email) if err != nil { return nil, "", apperrors.Internal(err) } @@ -146,7 +146,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry) // Wrap user creation + profile + notification preferences + confirmation code in a transaction - txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error { + txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error { // Save user if err := txRepo.Create(user); err != nil { return err @@ -159,7 +159,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist // Create notification preferences with all options enabled if s.notificationRepo != nil { - if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil { + if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil { log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration") } } @@ -176,7 +176,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist } // Create auth token (outside transaction since token generation is idempotent) - token, err := s.userRepo.GetOrCreateToken(user.ID) + token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID) if err != nil { return nil, "", apperrors.Internal(err) } @@ -192,7 +192,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist // - If token is expired (> expiryDays old), returns error (must re-login). // - If token is in the renewal window (> refreshDays old), generates a new token. // - If token is still fresh (< refreshDays old), returns the existing token (no-op). -func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) { +func (s *AuthService) RefreshToken(ctx context.Context, tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) { expiryDays := s.cfg.Security.TokenExpiryDays if expiryDays <= 0 { expiryDays = 90 @@ -203,7 +203,7 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref } // Look up the token - authToken, err := s.userRepo.FindTokenByKey(tokenKey) + authToken, err := s.userRepo.WithContext(ctx).FindTokenByKey(tokenKey) if err != nil { return nil, apperrors.Unauthorized("error.invalid_token") } @@ -232,12 +232,12 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref // Token is in the renewal window — generate a new one // Delete the old token - if err := s.userRepo.DeleteToken(tokenKey); err != nil { + if err := s.userRepo.WithContext(ctx).DeleteToken(tokenKey); err != nil { log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh") } // Create a new token - newToken, err := s.userRepo.CreateToken(userID) + newToken, err := s.userRepo.WithContext(ctx).CreateToken(userID) if err != nil { return nil, apperrors.Internal(err) } @@ -249,18 +249,18 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref } // Logout invalidates a user's token -func (s *AuthService) Logout(token string) error { - return s.userRepo.DeleteToken(token) +func (s *AuthService) Logout(ctx context.Context, token string) error { + return s.userRepo.WithContext(ctx).DeleteToken(token) } // GetCurrentUser returns the current authenticated user with profile -func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserResponse, error) { - user, err := s.userRepo.FindByIDWithProfile(userID) +func (s *AuthService) GetCurrentUser(ctx context.Context, userID uint) (*responses.CurrentUserResponse, error) { + user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(userID) if err != nil { return nil, err } - authProvider, err := s.userRepo.FindAuthProvider(userID) + authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID) if err != nil { // Log but don't fail - default to "email" log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider") @@ -275,9 +275,9 @@ func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserRespons // For email auth users, password verification is required. // For social auth users, confirmation string "DELETE" is required. // Returns a list of file URLs that need to be deleted from disk. -func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string) ([]string, error) { +func (s *AuthService) DeleteAccount(ctx context.Context, userID uint, password, confirmation *string) ([]string, error) { // Fetch user - user, err := s.userRepo.FindByID(userID) + user, err := s.userRepo.WithContext(ctx).FindByID(userID) if err != nil { if errors.Is(err, repositories.ErrUserNotFound) { return nil, apperrors.NotFound("error.user_not_found") @@ -286,7 +286,7 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string) } // Determine auth provider - authProvider, err := s.userRepo.FindAuthProvider(userID) + authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID) if err != nil { return nil, apperrors.Internal(err) } @@ -308,7 +308,7 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string) // Start transaction and cascade delete var fileURLs []string - txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error { + txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error { urls, err := txRepo.DeleteUserCascade(userID) if err != nil { return err @@ -324,15 +324,15 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string) } // UpdateProfile updates a user's profile -func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) { - user, err := s.userRepo.FindByID(userID) +func (s *AuthService) UpdateProfile(ctx context.Context, userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) { + user, err := s.userRepo.WithContext(ctx).FindByID(userID) if err != nil { return nil, err } // Check if new email is taken (if email is being changed) if req.Email != nil && *req.Email != user.Email { - exists, err := s.userRepo.ExistsByEmail(*req.Email) + exists, err := s.userRepo.WithContext(ctx).ExistsByEmail(*req.Email) if err != nil { return nil, apperrors.Internal(err) } @@ -349,17 +349,17 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ user.LastName = *req.LastName } - if err := s.userRepo.Update(user); err != nil { + if err := s.userRepo.WithContext(ctx).Update(user); err != nil { return nil, apperrors.Internal(err) } // Reload with profile - user, err = s.userRepo.FindByIDWithProfile(userID) + user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(userID) if err != nil { return nil, err } - authProvider, err := s.userRepo.FindAuthProvider(userID) + authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID) if err != nil { log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider") authProvider = "email" @@ -370,9 +370,9 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ } // VerifyEmail verifies a user's email with a confirmation code -func (s *AuthService) VerifyEmail(userID uint, code string) error { +func (s *AuthService) VerifyEmail(ctx context.Context, userID uint, code string) error { // Get user profile - profile, err := s.userRepo.GetOrCreateProfile(userID) + profile, err := s.userRepo.WithContext(ctx).GetOrCreateProfile(userID) if err != nil { return apperrors.Internal(err) } @@ -384,14 +384,14 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error { // Check for test code when DEBUG_FIXED_CODES is enabled if s.cfg.Server.DebugFixedCodes && code == "123456" { - if err := s.userRepo.SetProfileVerified(userID, true); err != nil { + if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil { return apperrors.Internal(err) } return nil } // Find and validate confirmation code - confirmCode, err := s.userRepo.FindConfirmationCode(userID, code) + confirmCode, err := s.userRepo.WithContext(ctx).FindConfirmationCode(userID, code) if err != nil { if errors.Is(err, repositories.ErrCodeNotFound) { return apperrors.BadRequest("error.invalid_verification_code") @@ -403,12 +403,12 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error { } // Mark code as used - if err := s.userRepo.MarkConfirmationCodeUsed(confirmCode.ID); err != nil { + if err := s.userRepo.WithContext(ctx).MarkConfirmationCodeUsed(confirmCode.ID); err != nil { return apperrors.Internal(err) } // Set profile as verified - if err := s.userRepo.SetProfileVerified(userID, true); err != nil { + if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil { return apperrors.Internal(err) } @@ -416,9 +416,9 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error { } // ResendVerificationCode creates and returns a new verification code -func (s *AuthService) ResendVerificationCode(userID uint) (string, error) { +func (s *AuthService) ResendVerificationCode(ctx context.Context, userID uint) (string, error) { // Get user profile - profile, err := s.userRepo.GetOrCreateProfile(userID) + profile, err := s.userRepo.WithContext(ctx).GetOrCreateProfile(userID) if err != nil { return "", apperrors.Internal(err) } @@ -437,7 +437,7 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) { } expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry) - if _, err := s.userRepo.CreateConfirmationCode(userID, code, expiresAt); err != nil { + if _, err := s.userRepo.WithContext(ctx).CreateConfirmationCode(userID, code, expiresAt); err != nil { return "", apperrors.Internal(err) } @@ -445,9 +445,9 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) { } // ForgotPassword initiates the password reset process -func (s *AuthService) ForgotPassword(email string) (string, *models.User, error) { +func (s *AuthService) ForgotPassword(ctx context.Context, email string) (string, *models.User, error) { // Find user by email - user, err := s.userRepo.FindByEmail(email) + user, err := s.userRepo.WithContext(ctx).FindByEmail(email) if err != nil { if errors.Is(err, repositories.ErrUserNotFound) { // Don't reveal that the email doesn't exist @@ -457,7 +457,7 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error) } // Check rate limit - count, err := s.userRepo.CountRecentPasswordResetRequests(user.ID) + count, err := s.userRepo.WithContext(ctx).CountRecentPasswordResetRequests(user.ID) if err != nil { return "", nil, apperrors.Internal(err) } @@ -481,7 +481,7 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error) return "", nil, apperrors.Internal(err) } - if _, err := s.userRepo.CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil { + if _, err := s.userRepo.WithContext(ctx).CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil { return "", nil, apperrors.Internal(err) } @@ -489,9 +489,9 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error) } // VerifyResetCode verifies a password reset code and returns a reset token -func (s *AuthService) VerifyResetCode(email, code string) (string, error) { +func (s *AuthService) VerifyResetCode(ctx context.Context, email, code string) (string, error) { // Find the reset code - resetCode, user, err := s.userRepo.FindPasswordResetCodeByEmail(email) + resetCode, user, err := s.userRepo.WithContext(ctx).FindPasswordResetCodeByEmail(email) if err != nil { if errors.Is(err, repositories.ErrUserNotFound) || errors.Is(err, repositories.ErrCodeNotFound) { return "", apperrors.BadRequest("error.invalid_verification_code") @@ -507,7 +507,7 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) { // Verify the code if !resetCode.CheckCode(code) { // Increment attempts - s.userRepo.IncrementResetCodeAttempts(resetCode.ID) + s.userRepo.WithContext(ctx).IncrementResetCodeAttempts(resetCode.ID) return "", apperrors.BadRequest("error.invalid_verification_code") } @@ -528,9 +528,9 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) { } // ResetPassword resets the user's password using a reset token -func (s *AuthService) ResetPassword(resetToken, newPassword string) error { +func (s *AuthService) ResetPassword(ctx context.Context, resetToken, newPassword string) error { // Find the reset code by token - resetCode, err := s.userRepo.FindPasswordResetCodeByToken(resetToken) + resetCode, err := s.userRepo.WithContext(ctx).FindPasswordResetCodeByToken(resetToken) if err != nil { if errors.Is(err, repositories.ErrCodeNotFound) || errors.Is(err, repositories.ErrCodeExpired) { return apperrors.BadRequest("error.invalid_reset_token") @@ -539,7 +539,7 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error { } // Get the user - user, err := s.userRepo.FindByID(resetCode.UserID) + user, err := s.userRepo.WithContext(ctx).FindByID(resetCode.UserID) if err != nil { return apperrors.Internal(err) } @@ -549,18 +549,18 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error { return apperrors.Internal(err) } - if err := s.userRepo.Update(user); err != nil { + if err := s.userRepo.WithContext(ctx).Update(user); err != nil { return apperrors.Internal(err) } // Mark reset code as used - if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil { + if err := s.userRepo.WithContext(ctx).MarkPasswordResetCodeUsed(resetCode.ID); err != nil { // Log error but don't fail log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used") } // Invalidate all existing tokens for this user (security measure) - if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil { + if err := s.userRepo.WithContext(ctx).DeleteTokenByUserID(user.ID); err != nil { // Log error but don't fail log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset") } @@ -583,10 +583,10 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi } // 2. Check if this Apple ID is already linked to an account - existingAuth, err := s.userRepo.FindByAppleID(appleID) + existingAuth, err := s.userRepo.WithContext(ctx).FindByAppleID(appleID) if err == nil && existingAuth != nil { // User already linked with this Apple ID - log them in - user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID) + user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(existingAuth.UserID) if err != nil { return nil, apperrors.Internal(err) } @@ -596,13 +596,13 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi } // Get or create token - token, err := s.userRepo.GetOrCreateToken(user.ID) + token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID) if err != nil { return nil, apperrors.Internal(err) } // Update last login - _ = s.userRepo.UpdateLastLogin(user.ID) + _ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID) return &responses.AppleSignInResponse{ Token: token.Key, @@ -614,7 +614,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi // 3. Check if email matches an existing user (for account linking) email := getEmailFromRequest(req.Email, claims.Email) if email != "" { - existingUser, err := s.userRepo.FindByEmail(email) + existingUser, err := s.userRepo.WithContext(ctx).FindByEmail(email) if err == nil && existingUser != nil { // S-06: Log auto-linking of social account to existing user log.Warn(). @@ -630,24 +630,24 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi Email: email, IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(), } - if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil { + if err := s.userRepo.WithContext(ctx).CreateAppleSocialAuth(appleAuthRecord); err != nil { return nil, apperrors.Internal(err) } // Mark as verified since Apple verified the email - _ = s.userRepo.SetProfileVerified(existingUser.ID, true) + _ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true) // Get or create token - token, err := s.userRepo.GetOrCreateToken(existingUser.ID) + token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID) if err != nil { return nil, apperrors.Internal(err) } // Update last login - _ = s.userRepo.UpdateLastLogin(existingUser.ID) + _ = s.userRepo.WithContext(ctx).UpdateLastLogin(existingUser.ID) // B-08: Check error from FindByIDWithProfile - existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID) + existingUser, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(existingUser.ID) if err != nil { return nil, apperrors.Internal(err) } @@ -675,19 +675,19 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi randomPassword := generateResetToken() _ = user.SetPassword(randomPassword) - if err := s.userRepo.Create(user); err != nil { + if err := s.userRepo.WithContext(ctx).Create(user); err != nil { return nil, apperrors.Internal(err) } // Create profile (already verified since Apple verified) - profile, _ := s.userRepo.GetOrCreateProfile(user.ID) + profile, _ := s.userRepo.WithContext(ctx).GetOrCreateProfile(user.ID) if profile != nil { - _ = s.userRepo.SetProfileVerified(user.ID, true) + _ = s.userRepo.WithContext(ctx).SetProfileVerified(user.ID, true) } // Create notification preferences with all options enabled if s.notificationRepo != nil { - if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil { + if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil { log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user") } } @@ -699,18 +699,18 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi Email: getEmailOrDefault(email), IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(), } - if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil { + if err := s.userRepo.WithContext(ctx).CreateAppleSocialAuth(appleAuthRecord); err != nil { return nil, apperrors.Internal(err) } // Create token - token, err := s.userRepo.GetOrCreateToken(user.ID) + token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID) if err != nil { return nil, apperrors.Internal(err) } // B-08: Check error from FindByIDWithProfile - user, err = s.userRepo.FindByIDWithProfile(user.ID) + user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(user.ID) if err != nil { return nil, apperrors.Internal(err) } @@ -736,10 +736,10 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe } // 2. Check if this Google ID is already linked to an account - existingAuth, err := s.userRepo.FindByGoogleID(googleID) + existingAuth, err := s.userRepo.WithContext(ctx).FindByGoogleID(googleID) if err == nil && existingAuth != nil { // User already linked with this Google ID - log them in - user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID) + user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(existingAuth.UserID) if err != nil { return nil, apperrors.Internal(err) } @@ -749,13 +749,13 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe } // Get or create token - token, err := s.userRepo.GetOrCreateToken(user.ID) + token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID) if err != nil { return nil, apperrors.Internal(err) } // Update last login - _ = s.userRepo.UpdateLastLogin(user.ID) + _ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID) return &responses.GoogleSignInResponse{ Token: token.Key, @@ -767,7 +767,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe // 3. Check if email matches an existing user (for account linking) email := tokenInfo.Email if email != "" { - existingUser, err := s.userRepo.FindByEmail(email) + existingUser, err := s.userRepo.WithContext(ctx).FindByEmail(email) if err == nil && existingUser != nil { // S-06: Log auto-linking of social account to existing user log.Warn(). @@ -784,26 +784,26 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe Name: tokenInfo.Name, Picture: tokenInfo.Picture, } - if err := s.userRepo.CreateGoogleSocialAuth(googleAuthRecord); err != nil { + if err := s.userRepo.WithContext(ctx).CreateGoogleSocialAuth(googleAuthRecord); err != nil { return nil, apperrors.Internal(err) } // Mark as verified since Google verified the email if tokenInfo.IsEmailVerified() { - _ = s.userRepo.SetProfileVerified(existingUser.ID, true) + _ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true) } // Get or create token - token, err := s.userRepo.GetOrCreateToken(existingUser.ID) + token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID) if err != nil { return nil, apperrors.Internal(err) } // Update last login - _ = s.userRepo.UpdateLastLogin(existingUser.ID) + _ = s.userRepo.WithContext(ctx).UpdateLastLogin(existingUser.ID) // B-08: Check error from FindByIDWithProfile - existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID) + existingUser, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(existingUser.ID) if err != nil { return nil, apperrors.Internal(err) } @@ -831,19 +831,19 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe randomPassword := generateResetToken() _ = user.SetPassword(randomPassword) - if err := s.userRepo.Create(user); err != nil { + if err := s.userRepo.WithContext(ctx).Create(user); err != nil { return nil, apperrors.Internal(err) } // Create profile (already verified if Google verified email) - profile, _ := s.userRepo.GetOrCreateProfile(user.ID) + profile, _ := s.userRepo.WithContext(ctx).GetOrCreateProfile(user.ID) if profile != nil && tokenInfo.IsEmailVerified() { - _ = s.userRepo.SetProfileVerified(user.ID, true) + _ = s.userRepo.WithContext(ctx).SetProfileVerified(user.ID, true) } // Create notification preferences with all options enabled if s.notificationRepo != nil { - if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil { + if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil { log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user") } } @@ -856,18 +856,18 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe Name: tokenInfo.Name, Picture: tokenInfo.Picture, } - if err := s.userRepo.CreateGoogleSocialAuth(googleAuthRecord); err != nil { + if err := s.userRepo.WithContext(ctx).CreateGoogleSocialAuth(googleAuthRecord); err != nil { return nil, apperrors.Internal(err) } // Create token - token, err := s.userRepo.GetOrCreateToken(user.ID) + token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID) if err != nil { return nil, apperrors.Internal(err) } // B-08: Check error from FindByIDWithProfile - user, err = s.userRepo.FindByIDWithProfile(user.ID) + user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(user.ID) if err != nil { return nil, apperrors.Internal(err) } diff --git a/internal/services/auth_service_test.go b/internal/services/auth_service_test.go index c80c245..1ab3638 100644 --- a/internal/services/auth_service_test.go +++ b/internal/services/auth_service_test.go @@ -1,6 +1,7 @@ package services import ( + "context" "net/http" "testing" "time" @@ -53,7 +54,7 @@ func TestAuthService_Login(t *testing.T) { Password: "Password123", } - resp, err := service.Login(req) + resp, err := service.Login(context.Background(), req) require.NoError(t, err) assert.NotEmpty(t, resp.Token) assert.Equal(t, "testuser", resp.User.Username) @@ -74,7 +75,7 @@ func TestAuthService_Login_ByEmail(t *testing.T) { Password: "Password123", } - resp, err := service.Login(req) + resp, err := service.Login(context.Background(), req) require.NoError(t, err) assert.NotEmpty(t, resp.Token) } @@ -94,7 +95,7 @@ func TestAuthService_Login_InvalidCredentials(t *testing.T) { Password: "WrongPassword1", } - _, err := service.Login(req) + _, err := service.Login(context.Background(), req) testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") } @@ -111,7 +112,7 @@ func TestAuthService_Login_UserNotFound(t *testing.T) { Password: "Password123", } - _, err := service.Login(req) + _, err := service.Login(context.Background(), req) testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") } @@ -133,7 +134,7 @@ func TestAuthService_Login_InactiveUser(t *testing.T) { Password: "Password123", } - _, err := service.Login(req) + _, err := service.Login(context.Background(), req) testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive") } @@ -148,7 +149,7 @@ func TestAuthService_Register(t *testing.T) { Password: "Password123", } - resp, code, err := service.Register(req) + resp, code, err := service.Register(context.Background(), req) require.NoError(t, err) assert.NotEmpty(t, resp.Token) assert.Equal(t, "newuser", resp.User.Username) @@ -172,7 +173,7 @@ func TestAuthService_Register_DuplicateUsername(t *testing.T) { Password: "Password123", } - _, _, err := service.Register(req) + _, _, err := service.Register(context.Background(), req) testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken") } @@ -193,7 +194,7 @@ func TestAuthService_Register_DuplicateEmail(t *testing.T) { Password: "Password123", } - _, _, err := service.Register(req) + _, _, err := service.Register(context.Background(), req) testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken") } @@ -211,7 +212,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) { // Create profile userRepo.GetOrCreateProfile(user.ID) - resp, err := service.GetCurrentUser(user.ID) + resp, err := service.GetCurrentUser(context.Background(), user.ID) require.NoError(t, err) assert.Equal(t, "testuser", resp.Username) assert.Equal(t, "test@test.com", resp.Email) @@ -238,7 +239,7 @@ func TestAuthService_UpdateProfile(t *testing.T) { LastName: &newLast, } - resp, err := service.UpdateProfile(user.ID, req) + resp, err := service.UpdateProfile(context.Background(), user.ID, req) require.NoError(t, err) assert.Equal(t, "John", resp.FirstName) assert.Equal(t, "Doe", resp.LastName) @@ -261,7 +262,7 @@ func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) { Email: &takenEmail, } - _, err := service.UpdateProfile(user2.ID, req) + _, err := service.UpdateProfile(context.Background(), user2.ID, req) testutil.AssertAppError(t, err, http.StatusConflict, "error.email_already_taken") } @@ -282,7 +283,7 @@ func TestAuthService_UpdateProfile_SameEmail(t *testing.T) { } // Same email should not trigger duplicate error - resp, err := service.UpdateProfile(user.ID, req) + resp, err := service.UpdateProfile(context.Background(), user.ID, req) require.NoError(t, err) assert.Equal(t, "test@test.com", resp.Email) } @@ -298,7 +299,7 @@ func TestAuthService_VerifyEmail(t *testing.T) { Email: "new@test.com", Password: "Password123", } - _, _, err := service.Register(req) + _, _, err := service.Register(context.Background(), req) require.NoError(t, err) // Get the user ID @@ -306,11 +307,11 @@ func TestAuthService_VerifyEmail(t *testing.T) { require.NoError(t, err) // Verify with the debug code - err = service.VerifyEmail(user.ID, "123456") + err = service.VerifyEmail(context.Background(), user.ID, "123456") require.NoError(t, err) // Verify again — should get already verified error - err = service.VerifyEmail(user.ID, "123456") + err = service.VerifyEmail(context.Background(), user.ID, "123456") testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified") } @@ -323,7 +324,7 @@ func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) { Email: "new@test.com", Password: "Password123", } - _, _, err := service.Register(req) + _, _, err := service.Register(context.Background(), req) require.NoError(t, err) user, err := service.userRepo.FindByEmail("new@test.com") @@ -331,7 +332,7 @@ func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) { // Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup, // but a wrong code should use the normal path - err = service.VerifyEmail(user.ID, "000000") + err = service.VerifyEmail(context.Background(), user.ID, "000000") assert.Error(t, err) } @@ -346,13 +347,13 @@ func TestAuthService_ResendVerificationCode(t *testing.T) { Email: "new@test.com", Password: "Password123", } - _, _, err := service.Register(req) + _, _, err := service.Register(context.Background(), req) require.NoError(t, err) user, err := service.userRepo.FindByEmail("new@test.com") require.NoError(t, err) - code, err := service.ResendVerificationCode(user.ID) + code, err := service.ResendVerificationCode(context.Background(), user.ID) require.NoError(t, err) assert.Equal(t, "123456", code) // DebugFixedCodes } @@ -366,16 +367,16 @@ func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) { Email: "new@test.com", Password: "Password123", } - _, _, err := service.Register(req) + _, _, err := service.Register(context.Background(), req) require.NoError(t, err) user, err := service.userRepo.FindByEmail("new@test.com") require.NoError(t, err) - err = service.VerifyEmail(user.ID, "123456") + err = service.VerifyEmail(context.Background(), user.ID, "123456") require.NoError(t, err) - _, err = service.ResendVerificationCode(user.ID) + _, err = service.ResendVerificationCode(context.Background(), user.ID) testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified") } @@ -390,10 +391,10 @@ func TestAuthService_ForgotPassword(t *testing.T) { Email: "test@test.com", Password: "Password123", } - _, _, err := service.Register(registerReq) + _, _, err := service.Register(context.Background(), registerReq) require.NoError(t, err) - code, user, err := service.ForgotPassword("test@test.com") + code, user, err := service.ForgotPassword(context.Background(), "test@test.com") require.NoError(t, err) assert.Equal(t, "123456", code) // DebugFixedCodes assert.NotNil(t, user) @@ -404,7 +405,7 @@ func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) { service, _ := setupAuthService(t) // Should not reveal that email doesn't exist - code, user, err := service.ForgotPassword("nonexistent@test.com") + code, user, err := service.ForgotPassword(context.Background(), "nonexistent@test.com") require.NoError(t, err) assert.Empty(t, code) assert.Nil(t, user) @@ -421,20 +422,20 @@ func TestAuthService_ResetPassword(t *testing.T) { Email: "test@test.com", Password: "Password123", } - _, _, err := service.Register(registerReq) + _, _, err := service.Register(context.Background(), registerReq) require.NoError(t, err) // Forgot password - _, _, err = service.ForgotPassword("test@test.com") + _, _, err = service.ForgotPassword(context.Background(), "test@test.com") require.NoError(t, err) // Verify reset code to get the token - resetToken, err := service.VerifyResetCode("test@test.com", "123456") + resetToken, err := service.VerifyResetCode(context.Background(), "test@test.com", "123456") require.NoError(t, err) assert.NotEmpty(t, resetToken) // Reset password - err = service.ResetPassword(resetToken, "NewPassword123") + err = service.ResetPassword(context.Background(), resetToken, "NewPassword123") require.NoError(t, err) // Login with new password @@ -442,7 +443,7 @@ func TestAuthService_ResetPassword(t *testing.T) { Username: "testuser", Password: "NewPassword123", } - loginResp, err := service.Login(loginReq) + loginResp, err := service.Login(context.Background(), loginReq) require.NoError(t, err) assert.NotEmpty(t, loginResp.Token) } @@ -450,7 +451,7 @@ func TestAuthService_ResetPassword(t *testing.T) { func TestAuthService_ResetPassword_InvalidToken(t *testing.T) { service, _ := setupAuthService(t) - err := service.ResetPassword("invalid-token", "NewPassword123") + err := service.ResetPassword(context.Background(), "invalid-token", "NewPassword123") testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token") } @@ -471,15 +472,15 @@ func TestAuthService_Logout(t *testing.T) { Username: "testuser", Password: "Password123", } - loginResp, err := service.Login(loginReq) + loginResp, err := service.Login(context.Background(), loginReq) require.NoError(t, err) // Logout - err = service.Logout(loginResp.Token) + err = service.Logout(context.Background(), loginResp.Token) require.NoError(t, err) // Token should be deleted — refreshing should fail - _, err = service.RefreshToken(loginResp.Token, user.ID) + _, err = service.RefreshToken(context.Background(), loginResp.Token, user.ID) assert.Error(t, err) } @@ -494,14 +495,14 @@ func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) { Email: "test@test.com", Password: "Password123", } - _, _, err := service.Register(registerReq) + _, _, err := service.Register(context.Background(), registerReq) require.NoError(t, err) user, err := service.userRepo.FindByEmail("test@test.com") require.NoError(t, err) password := "Password123" - _, err = service.DeleteAccount(user.ID, &password, nil) + _, err = service.DeleteAccount(context.Background(), user.ID, &password, nil) require.NoError(t, err) } @@ -513,14 +514,14 @@ func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) { Email: "test@test.com", Password: "Password123", } - _, _, err := service.Register(registerReq) + _, _, err := service.Register(context.Background(), registerReq) require.NoError(t, err) user, err := service.userRepo.FindByEmail("test@test.com") require.NoError(t, err) wrongPassword := "WrongPassword1" - _, err = service.DeleteAccount(user.ID, &wrongPassword, nil) + _, err = service.DeleteAccount(context.Background(), user.ID, &wrongPassword, nil) testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") } @@ -532,13 +533,13 @@ func TestAuthService_DeleteAccount_NoPassword(t *testing.T) { Email: "test@test.com", Password: "Password123", } - _, _, err := service.Register(registerReq) + _, _, err := service.Register(context.Background(), registerReq) require.NoError(t, err) user, err := service.userRepo.FindByEmail("test@test.com") require.NoError(t, err) - _, err = service.DeleteAccount(user.ID, nil, nil) + _, err = service.DeleteAccount(context.Background(), user.ID, nil, nil) testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required") } @@ -546,7 +547,7 @@ func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) { service, _ := setupAuthService(t) password := "Password123" - _, err := service.DeleteAccount(99999, &password, nil) + _, err := service.DeleteAccount(context.Background(), 99999, &password, nil) testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found") } @@ -658,7 +659,7 @@ func TestAuthService_Login_EmptyPassword(t *testing.T) { Password: "", } - _, err := service.Login(req) + _, err := service.Login(context.Background(), req) testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") } @@ -672,17 +673,17 @@ func TestAuthService_ForgotPassword_RateLimit(t *testing.T) { Email: "test@test.com", Password: "Password123", } - _, _, err := service.Register(registerReq) + _, _, err := service.Register(context.Background(), registerReq) require.NoError(t, err) // Make max allowed reset requests (3 based on setup) for i := 0; i < 3; i++ { - _, _, err := service.ForgotPassword("test@test.com") + _, _, err := service.ForgotPassword(context.Background(), "test@test.com") require.NoError(t, err) } // The 4th should be rate limited - _, _, err = service.ForgotPassword("test@test.com") + _, _, err = service.ForgotPassword(context.Background(), "test@test.com") assert.Error(t, err) } @@ -696,14 +697,14 @@ func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) { Email: "test@test.com", Password: "Password123", } - _, _, err := service.Register(registerReq) + _, _, err := service.Register(context.Background(), registerReq) require.NoError(t, err) - _, _, err = service.ForgotPassword("test@test.com") + _, _, err = service.ForgotPassword(context.Background(), "test@test.com") require.NoError(t, err) // Wrong code but with debug mode, "123456" works, "000000" should fail - _, err = service.VerifyResetCode("test@test.com", "000000") + _, err = service.VerifyResetCode(context.Background(), "test@test.com", "000000") assert.Error(t, err) } @@ -712,7 +713,7 @@ func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) { func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) { service, _ := setupAuthService(t) - _, err := service.VerifyResetCode("nonexistent@test.com", "123456") + _, err := service.VerifyResetCode(context.Background(), "nonexistent@test.com", "123456") assert.Error(t, err) } @@ -734,7 +735,7 @@ func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) { Email: &newEmail, } - resp, err := service.UpdateProfile(user.ID, req) + resp, err := service.UpdateProfile(context.Background(), user.ID, req) require.NoError(t, err) assert.Equal(t, "newemail@test.com", resp.Email) } @@ -749,14 +750,14 @@ func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) { Email: "test@test.com", Password: "Password123", } - _, _, err := service.Register(registerReq) + _, _, err := service.Register(context.Background(), registerReq) require.NoError(t, err) user, err := service.userRepo.FindByEmail("test@test.com") require.NoError(t, err) emptyPw := "" - _, err = service.DeleteAccount(user.ID, &emptyPw, nil) + _, err = service.DeleteAccount(context.Background(), user.ID, &emptyPw, nil) testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required") } @@ -789,7 +790,7 @@ func TestAuthService_Register_CreatesProfile(t *testing.T) { LastName: "Doe", } - resp, _, err := service.Register(req) + resp, _, err := service.Register(context.Background(), req) require.NoError(t, err) assert.Equal(t, "profileuser", resp.User.Username) diff --git a/internal/services/contractor_service.go b/internal/services/contractor_service.go index 565e40e..0dee23e 100644 --- a/internal/services/contractor_service.go +++ b/internal/services/contractor_service.go @@ -1,6 +1,7 @@ package services import ( + "context" "errors" "gorm.io/gorm" @@ -33,8 +34,8 @@ func NewContractorService(contractorRepo *repositories.ContractorRepository, res } // GetContractor gets a contractor by ID with access check -func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses.ContractorResponse, error) { - contractor, err := s.contractorRepo.FindByID(contractorID) +func (s *ContractorService) GetContractor(ctx context.Context, contractorID, userID uint) (*responses.ContractorResponse, error) { + contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.contractor_not_found") @@ -43,7 +44,7 @@ func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses } // Check access - if !s.hasContractorAccess(contractor, userID) { + if !s.hasContractorAccess(ctx, contractor, userID) { return nil, apperrors.Forbidden("error.contractor_access_denied") } @@ -55,14 +56,14 @@ func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses // Access rules: // - If contractor has no residence: only the creator has access // - If contractor has a residence: all users with access to that residence -func (s *ContractorService) hasContractorAccess(contractor *models.Contractor, userID uint) bool { +func (s *ContractorService) hasContractorAccess(ctx context.Context, contractor *models.Contractor, userID uint) bool { if contractor.ResidenceID == nil { // Personal contractor - only creator has access return contractor.CreatedByID == userID } // Residence contractor - check residence access - hasAccess, err := s.residenceRepo.HasAccess(*contractor.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*contractor.ResidenceID, userID) if err != nil { return false } @@ -70,15 +71,15 @@ func (s *ContractorService) hasContractorAccess(contractor *models.Contractor, u } // ListContractors lists all contractors accessible to a user -func (s *ContractorService) ListContractors(userID uint) ([]responses.ContractorResponse, error) { +func (s *ContractorService) ListContractors(ctx context.Context, userID uint) ([]responses.ContractorResponse, error) { // Get residence IDs (lightweight - no preloads) - residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID) + residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID) if err != nil { return nil, apperrors.Internal(err) } // FindByUser now handles both personal and residence contractors - contractors, err := s.contractorRepo.FindByUser(userID, residenceIDs) + contractors, err := s.contractorRepo.WithContext(ctx).FindByUser(userID, residenceIDs) if err != nil { return nil, apperrors.Internal(err) } @@ -87,10 +88,10 @@ func (s *ContractorService) ListContractors(userID uint) ([]responses.Contractor } // CreateContractor creates a new contractor -func (s *ContractorService) CreateContractor(req *requests.CreateContractorRequest, userID uint) (*responses.ContractorResponse, error) { +func (s *ContractorService) CreateContractor(ctx context.Context, req *requests.CreateContractorRequest, userID uint) (*responses.ContractorResponse, error) { // If residence is provided, check access if req.ResidenceID != nil { - hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*req.ResidenceID, userID) if err != nil { return nil, apperrors.Internal(err) } @@ -122,19 +123,19 @@ func (s *ContractorService) CreateContractor(req *requests.CreateContractorReque IsActive: true, } - if err := s.contractorRepo.Create(contractor); err != nil { + if err := s.contractorRepo.WithContext(ctx).Create(contractor); err != nil { return nil, apperrors.Internal(err) } // Set specialties if provided if len(req.SpecialtyIDs) > 0 { - if err := s.contractorRepo.SetSpecialties(contractor.ID, req.SpecialtyIDs); err != nil { + if err := s.contractorRepo.WithContext(ctx).SetSpecialties(contractor.ID, req.SpecialtyIDs); err != nil { return nil, apperrors.Internal(err) } } // Reload with relations - contractor, reloadErr := s.contractorRepo.FindByID(contractor.ID) + contractor, reloadErr := s.contractorRepo.WithContext(ctx).FindByID(contractor.ID) if reloadErr != nil { return nil, apperrors.Internal(reloadErr) } @@ -144,8 +145,8 @@ func (s *ContractorService) CreateContractor(req *requests.CreateContractorReque } // UpdateContractor updates a contractor -func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *requests.UpdateContractorRequest) (*responses.ContractorResponse, error) { - contractor, err := s.contractorRepo.FindByID(contractorID) +func (s *ContractorService) UpdateContractor(ctx context.Context, contractorID, userID uint, req *requests.UpdateContractorRequest) (*responses.ContractorResponse, error) { + contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.contractor_not_found") @@ -154,7 +155,7 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req } // Check access - if !s.hasContractorAccess(contractor, userID) { + if !s.hasContractorAccess(ctx, contractor, userID) { return nil, apperrors.Forbidden("error.contractor_access_denied") } @@ -198,7 +199,7 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req // If residence_id is provided, verify the user has access to the NEW residence. // This prevents an attacker from reassigning a contractor to someone else's residence. if req.ResidenceID != nil { - hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*req.ResidenceID, userID) if err != nil { return nil, apperrors.Internal(err) } @@ -211,19 +212,19 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req // removed the residence association - contractor becomes personal contractor.ResidenceID = req.ResidenceID - if err := s.contractorRepo.Update(contractor); err != nil { + if err := s.contractorRepo.WithContext(ctx).Update(contractor); err != nil { return nil, apperrors.Internal(err) } // Update specialties if provided if req.SpecialtyIDs != nil { - if err := s.contractorRepo.SetSpecialties(contractorID, req.SpecialtyIDs); err != nil { + if err := s.contractorRepo.WithContext(ctx).SetSpecialties(contractorID, req.SpecialtyIDs); err != nil { return nil, apperrors.Internal(err) } } // Reload - contractor, err = s.contractorRepo.FindByID(contractorID) + contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID) if err != nil { return nil, apperrors.Internal(err) } @@ -233,8 +234,8 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req } // DeleteContractor soft-deletes a contractor -func (s *ContractorService) DeleteContractor(contractorID, userID uint) error { - contractor, err := s.contractorRepo.FindByID(contractorID) +func (s *ContractorService) DeleteContractor(ctx context.Context, contractorID, userID uint) error { + contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return apperrors.NotFound("error.contractor_not_found") @@ -243,11 +244,11 @@ func (s *ContractorService) DeleteContractor(contractorID, userID uint) error { } // Check access - if !s.hasContractorAccess(contractor, userID) { + if !s.hasContractorAccess(ctx, contractor, userID) { return apperrors.Forbidden("error.contractor_access_denied") } - if err := s.contractorRepo.Delete(contractorID); err != nil { + if err := s.contractorRepo.WithContext(ctx).Delete(contractorID); err != nil { return apperrors.Internal(err) } @@ -255,8 +256,8 @@ func (s *ContractorService) DeleteContractor(contractorID, userID uint) error { } // ToggleFavorite toggles the favorite status of a contractor and returns the updated contractor -func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*responses.ContractorResponse, error) { - contractor, err := s.contractorRepo.FindByID(contractorID) +func (s *ContractorService) ToggleFavorite(ctx context.Context, contractorID, userID uint) (*responses.ContractorResponse, error) { + contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.contractor_not_found") @@ -265,17 +266,17 @@ func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*response } // Check access - if !s.hasContractorAccess(contractor, userID) { + if !s.hasContractorAccess(ctx, contractor, userID) { return nil, apperrors.Forbidden("error.contractor_access_denied") } - _, err = s.contractorRepo.ToggleFavorite(contractorID) + _, err = s.contractorRepo.WithContext(ctx).ToggleFavorite(contractorID) if err != nil { return nil, apperrors.Internal(err) } // Re-fetch the contractor to get the updated state with all relations - contractor, err = s.contractorRepo.FindByID(contractorID) + contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID) if err != nil { return nil, apperrors.Internal(err) } @@ -285,8 +286,8 @@ func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*response } // GetContractorTasks gets all tasks for a contractor -func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]responses.TaskResponse, error) { - contractor, err := s.contractorRepo.FindByID(contractorID) +func (s *ContractorService) GetContractorTasks(ctx context.Context, contractorID, userID uint) ([]responses.TaskResponse, error) { + contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.contractor_not_found") @@ -295,11 +296,11 @@ func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]res } // Check access - if !s.hasContractorAccess(contractor, userID) { + if !s.hasContractorAccess(ctx, contractor, userID) { return nil, apperrors.Forbidden("error.contractor_access_denied") } - tasks, err := s.contractorRepo.GetTasksForContractor(contractorID) + tasks, err := s.contractorRepo.WithContext(ctx).GetTasksForContractor(contractorID) if err != nil { return nil, apperrors.Internal(err) } @@ -308,9 +309,9 @@ func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]res } // ListContractorsByResidence lists all contractors for a specific residence -func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint) ([]responses.ContractorResponse, error) { +func (s *ContractorService) ListContractorsByResidence(ctx context.Context, residenceID, userID uint) ([]responses.ContractorResponse, error) { // Check user has access to the residence - hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(residenceID, userID) if err != nil { return nil, apperrors.Internal(err) } @@ -318,7 +319,7 @@ func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint) return nil, apperrors.Forbidden("error.residence_access_denied") } - contractors, err := s.contractorRepo.FindByResidence(residenceID) + contractors, err := s.contractorRepo.WithContext(ctx).FindByResidence(residenceID) if err != nil { return nil, apperrors.Internal(err) } @@ -327,8 +328,8 @@ func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint) } // GetSpecialties returns all contractor specialties -func (s *ContractorService) GetSpecialties() ([]responses.ContractorSpecialtyResponse, error) { - specialties, err := s.contractorRepo.GetAllSpecialties() +func (s *ContractorService) GetSpecialties(ctx context.Context) ([]responses.ContractorSpecialtyResponse, error) { + specialties, err := s.contractorRepo.WithContext(ctx).GetAllSpecialties() if err != nil { return nil, apperrors.Internal(err) } diff --git a/internal/services/contractor_service_test.go b/internal/services/contractor_service_test.go index 5dd67af..e51f884 100644 --- a/internal/services/contractor_service_test.go +++ b/internal/services/contractor_service_test.go @@ -1,6 +1,7 @@ package services import ( + "context" "net/http" "testing" @@ -41,7 +42,7 @@ func TestContractorService_CreateContractor(t *testing.T) { Email: "bob@plumbing.com", } - resp, err := service.CreateContractor(req, user.ID) + resp, err := service.CreateContractor(context.Background(), req, user.ID) require.NoError(t, err) assert.NotNil(t, resp) assert.Equal(t, "Bob's Plumbing", resp.Name) @@ -63,7 +64,7 @@ func TestContractorService_CreateContractor_Personal(t *testing.T) { Name: "Personal Handyman", } - resp, err := service.CreateContractor(req, user.ID) + resp, err := service.CreateContractor(context.Background(), req, user.ID) require.NoError(t, err) assert.Equal(t, "Personal Handyman", resp.Name) } @@ -84,7 +85,7 @@ func TestContractorService_CreateContractor_AccessDenied(t *testing.T) { Name: "Unauthorized Contractor", } - _, err := service.CreateContractor(req, other.ID) + _, err := service.CreateContractor(context.Background(), req, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied") } @@ -105,7 +106,7 @@ func TestContractorService_CreateContractor_WithFavorite(t *testing.T) { IsFavorite: &isFav, } - resp, err := service.CreateContractor(req, user.ID) + resp, err := service.CreateContractor(context.Background(), req, user.ID) require.NoError(t, err) assert.True(t, resp.IsFavorite) } @@ -123,7 +124,7 @@ func TestContractorService_GetContractor(t *testing.T) { residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor") - resp, err := service.GetContractor(contractor.ID, user.ID) + resp, err := service.GetContractor(context.Background(), contractor.ID, user.ID) require.NoError(t, err) assert.Equal(t, contractor.ID, resp.ID) assert.Equal(t, "Test Contractor", resp.Name) @@ -138,7 +139,7 @@ func TestContractorService_GetContractor_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - _, err := service.GetContractor(9999, user.ID) + _, err := service.GetContractor(context.Background(), 9999, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found") } @@ -154,7 +155,7 @@ func TestContractorService_GetContractor_AccessDenied(t *testing.T) { residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor") - _, err := service.GetContractor(contractor.ID, other.ID) + _, err := service.GetContractor(context.Background(), contractor.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") } @@ -171,7 +172,7 @@ func TestContractorService_GetContractor_SharedUserHasAccess(t *testing.T) { residenceRepo.AddUser(residence.ID, shared.ID) contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Shared Contractor") - resp, err := service.GetContractor(contractor.ID, shared.ID) + resp, err := service.GetContractor(context.Background(), contractor.ID, shared.ID) require.NoError(t, err) assert.Equal(t, "Shared Contractor", resp.Name) } @@ -190,7 +191,7 @@ func TestContractorService_ListContractors(t *testing.T) { testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 1") testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 2") - resp, err := service.ListContractors(user.ID) + resp, err := service.ListContractors(context.Background(), user.ID) require.NoError(t, err) assert.Len(t, resp, 2) } @@ -208,11 +209,11 @@ func TestContractorService_DeleteContractor(t *testing.T) { residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "To Delete") - err := service.DeleteContractor(contractor.ID, user.ID) + err := service.DeleteContractor(context.Background(), contractor.ID, user.ID) require.NoError(t, err) // Should not be found after deletion - _, err = service.GetContractor(contractor.ID, user.ID) + _, err = service.GetContractor(context.Background(), contractor.ID, user.ID) assert.Error(t, err) } @@ -225,7 +226,7 @@ func TestContractorService_DeleteContractor_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - err := service.DeleteContractor(9999, user.ID) + err := service.DeleteContractor(context.Background(), 9999, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found") } @@ -241,7 +242,7 @@ func TestContractorService_DeleteContractor_AccessDenied(t *testing.T) { residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor") - err := service.DeleteContractor(contractor.ID, other.ID) + err := service.DeleteContractor(context.Background(), contractor.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") } @@ -259,17 +260,17 @@ func TestContractorService_ToggleFavorite(t *testing.T) { contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor") // Initially not favorite - resp, err := service.GetContractor(contractor.ID, user.ID) + resp, err := service.GetContractor(context.Background(), contractor.ID, user.ID) require.NoError(t, err) assert.False(t, resp.IsFavorite) // Toggle to favorite - resp, err = service.ToggleFavorite(contractor.ID, user.ID) + resp, err = service.ToggleFavorite(context.Background(), contractor.ID, user.ID) require.NoError(t, err) assert.True(t, resp.IsFavorite) // Toggle back - resp, err = service.ToggleFavorite(contractor.ID, user.ID) + resp, err = service.ToggleFavorite(context.Background(), contractor.ID, user.ID) require.NoError(t, err) assert.False(t, resp.IsFavorite) } @@ -283,7 +284,7 @@ func TestContractorService_ToggleFavorite_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - _, err := service.ToggleFavorite(9999, user.ID) + _, err := service.ToggleFavorite(context.Background(), 9999, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found") } @@ -299,7 +300,7 @@ func TestContractorService_ToggleFavorite_AccessDenied(t *testing.T) { residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor") - _, err := service.ToggleFavorite(contractor.ID, other.ID) + _, err := service.ToggleFavorite(context.Background(), contractor.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") } @@ -317,7 +318,7 @@ func TestContractorService_ListContractorsByResidence(t *testing.T) { testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor A") testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor B") - resp, err := service.ListContractorsByResidence(residence.ID, user.ID) + resp, err := service.ListContractorsByResidence(context.Background(), residence.ID, user.ID) require.NoError(t, err) assert.Len(t, resp, 2) } @@ -333,7 +334,7 @@ func TestContractorService_ListContractorsByResidence_AccessDenied(t *testing.T) other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123") residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") - _, err := service.ListContractorsByResidence(residence.ID, other.ID) + _, err := service.ListContractorsByResidence(context.Background(), residence.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied") } @@ -348,7 +349,7 @@ func TestContractorService_GetContractorTasks_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - _, err := service.GetContractorTasks(9999, user.ID) + _, err := service.GetContractorTasks(context.Background(), 9999, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found") } @@ -364,7 +365,7 @@ func TestContractorService_GetContractorTasks_AccessDenied(t *testing.T) { residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor") - _, err := service.GetContractorTasks(contractor.ID, other.ID) + _, err := service.GetContractorTasks(context.Background(), contractor.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") } @@ -379,7 +380,7 @@ func TestContractorService_GetContractorTasks_Empty(t *testing.T) { residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor") - resp, err := service.GetContractorTasks(contractor.ID, user.ID) + resp, err := service.GetContractorTasks(context.Background(), contractor.ID, user.ID) require.NoError(t, err) assert.Empty(t, resp) } @@ -393,7 +394,7 @@ func TestContractorService_GetSpecialties(t *testing.T) { residenceRepo := repositories.NewResidenceRepository(db) service := NewContractorService(contractorRepo, residenceRepo) - resp, err := service.GetSpecialties() + resp, err := service.GetSpecialties(context.Background()) require.NoError(t, err) // SeedLookupData creates 4 specialties assert.Len(t, resp, 4) @@ -413,7 +414,7 @@ func TestContractorService_UpdateContractor_NotFound(t *testing.T) { newName := "Won't Work" req := &requests.UpdateContractorRequest{Name: &newName} - _, err := service.UpdateContractor(9999, user.ID, req) + _, err := service.UpdateContractor(context.Background(), 9999, user.ID, req) testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found") } @@ -432,7 +433,7 @@ func TestContractorService_UpdateContractor_AccessDenied(t *testing.T) { newName := "Hacked" req := &requests.UpdateContractorRequest{Name: &newName} - _, err := service.UpdateContractor(contractor.ID, other.ID, req) + _, err := service.UpdateContractor(context.Background(), contractor.ID, other.ID, req) testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") } @@ -461,7 +462,7 @@ func TestUpdateContractor_CrossResidence_Returns403(t *testing.T) { ResidenceID: &newResidenceID, } - _, err := service.UpdateContractor(contractor.ID, attacker.ID, req) + _, err := service.UpdateContractor(context.Background(), contractor.ID, attacker.ID, req) require.Error(t, err, "should not allow reassigning contractor to a residence the user has no access to") testutil.AssertAppErrorCode(t, err, http.StatusForbidden) } @@ -486,7 +487,7 @@ func TestUpdateContractor_SameResidence_Succeeds(t *testing.T) { ResidenceID: &newResidenceID, } - resp, err := service.UpdateContractor(contractor.ID, owner.ID, req) + resp, err := service.UpdateContractor(context.Background(), contractor.ID, owner.ID, req) require.NoError(t, err, "should allow reassigning contractor to a residence the user owns") require.NotNil(t, resp) require.Equal(t, "Updated Contractor", resp.Name) @@ -508,7 +509,7 @@ func TestUpdateContractor_RemoveResidence_Succeeds(t *testing.T) { ResidenceID: nil, } - resp, err := service.UpdateContractor(contractor.ID, owner.ID, req) + resp, err := service.UpdateContractor(context.Background(), contractor.ID, owner.ID, req) require.NoError(t, err, "should allow removing residence association") require.NotNil(t, resp) } @@ -555,7 +556,7 @@ func TestContractorService_UpdateContractor_PartialUpdate(t *testing.T) { ResidenceID: &residence.ID, } - resp, err := service.UpdateContractor(contractor.ID, user.ID, req) + resp, err := service.UpdateContractor(context.Background(), contractor.ID, user.ID, req) require.NoError(t, err) assert.Equal(t, "Updated Plumber", resp.Name) assert.Equal(t, "555-9999", resp.Phone) @@ -588,7 +589,7 @@ func TestContractorService_UpdateContractor_WithSpecialties(t *testing.T) { ResidenceID: &residence.ID, } - resp, err := service.UpdateContractor(contractor.ID, user.ID, req) + resp, err := service.UpdateContractor(context.Background(), contractor.ID, user.ID, req) require.NoError(t, err) assert.NotNil(t, resp) } @@ -615,7 +616,7 @@ func TestContractorService_CreateContractor_WithSpecialties(t *testing.T) { SpecialtyIDs: []uint{specialties[0].ID}, } - resp, err := service.CreateContractor(req, user.ID) + resp, err := service.CreateContractor(context.Background(), req, user.ID) require.NoError(t, err) assert.Equal(t, "Specialized Plumber", resp.Name) } @@ -631,7 +632,7 @@ func TestContractorService_ListContractors_Empty(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") // No residence, no contractors - resp, err := service.ListContractors(user.ID) + resp, err := service.ListContractors(context.Background(), user.ID) require.NoError(t, err) assert.Empty(t, resp) } @@ -648,7 +649,7 @@ func TestContractorService_ListContractorsByResidence_Empty(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") residence := testutil.CreateTestResidence(t, db, user.ID, "Empty House") - resp, err := service.ListContractorsByResidence(residence.ID, user.ID) + resp, err := service.ListContractorsByResidence(context.Background(), residence.ID, user.ID) require.NoError(t, err) assert.Empty(t, resp) } @@ -669,14 +670,14 @@ func TestContractorService_PersonalContractor_OnlyCreatorAccess(t *testing.T) { req := &requests.CreateContractorRequest{ Name: "Personal Plumber", } - resp, err := service.CreateContractor(req, creator.ID) + resp, err := service.CreateContractor(context.Background(), req, creator.ID) require.NoError(t, err) // Creator can access - _, err = service.GetContractor(resp.ID, creator.ID) + _, err = service.GetContractor(context.Background(), resp.ID, creator.ID) require.NoError(t, err) // Other user cannot - _, err = service.GetContractor(resp.ID, other.ID) + _, err = service.GetContractor(context.Background(), resp.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied") } diff --git a/internal/services/document_service.go b/internal/services/document_service.go index 40d90bf..174047e 100644 --- a/internal/services/document_service.go +++ b/internal/services/document_service.go @@ -1,6 +1,7 @@ package services import ( + "context" "errors" "gorm.io/gorm" @@ -34,8 +35,8 @@ func NewDocumentService(documentRepo *repositories.DocumentRepository, residence } // GetDocument gets a document by ID with access check -func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.DocumentResponse, error) { - document, err := s.documentRepo.FindByID(documentID) +func (s *DocumentService) GetDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) { + document, err := s.documentRepo.WithContext(ctx).FindByID(documentID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.document_not_found") @@ -44,7 +45,7 @@ func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.Docum } // Check access via residence - hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID) if err != nil { return nil, apperrors.Internal(err) } @@ -57,9 +58,9 @@ func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.Docum } // ListDocuments lists all documents accessible to a user, with optional filters. -func (s *DocumentService) ListDocuments(userID uint, filter *repositories.DocumentFilter) ([]responses.DocumentResponse, error) { +func (s *DocumentService) ListDocuments(ctx context.Context, userID uint, filter *repositories.DocumentFilter) ([]responses.DocumentResponse, error) { // Get residence IDs (lightweight - no preloads) - residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID) + residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID) if err != nil { return nil, apperrors.Internal(err) } @@ -83,7 +84,7 @@ func (s *DocumentService) ListDocuments(userID uint, filter *repositories.Docume residenceIDs = []uint{*filter.ResidenceID} } - documents, err := s.documentRepo.FindByUserFiltered(residenceIDs, filter) + documents, err := s.documentRepo.WithContext(ctx).FindByUserFiltered(residenceIDs, filter) if err != nil { return nil, apperrors.Internal(err) } @@ -92,9 +93,9 @@ func (s *DocumentService) ListDocuments(userID uint, filter *repositories.Docume } // ListWarranties lists all warranty documents -func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentResponse, error) { +func (s *DocumentService) ListWarranties(ctx context.Context, userID uint) ([]responses.DocumentResponse, error) { // Get residence IDs (lightweight - no preloads) - residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID) + residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID) if err != nil { return nil, apperrors.Internal(err) } @@ -103,7 +104,7 @@ func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentRespo return []responses.DocumentResponse{}, nil } - documents, err := s.documentRepo.FindWarranties(residenceIDs) + documents, err := s.documentRepo.WithContext(ctx).FindWarranties(residenceIDs) if err != nil { return nil, apperrors.Internal(err) } @@ -112,9 +113,9 @@ func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentRespo } // CreateDocument creates a new document -func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, userID uint) (*responses.DocumentResponse, error) { +func (s *DocumentService) CreateDocument(ctx context.Context, req *requests.CreateDocumentRequest, userID uint) (*responses.DocumentResponse, error) { // Check residence access - hasAccess, err := s.residenceRepo.HasAccess(req.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(req.ResidenceID, userID) if err != nil { return nil, apperrors.Internal(err) } @@ -147,7 +148,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us IsActive: true, } - if err := s.documentRepo.Create(document); err != nil { + if err := s.documentRepo.WithContext(ctx).Create(document); err != nil { return nil, apperrors.Internal(err) } @@ -158,7 +159,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us DocumentID: document.ID, ImageURL: imageURL, } - if err := s.documentRepo.CreateDocumentImage(img); err != nil { + if err := s.documentRepo.WithContext(ctx).CreateDocumentImage(img); err != nil { // Log but don't fail the whole operation continue } @@ -166,7 +167,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us } // Reload with relations - document, err = s.documentRepo.FindByID(document.ID) + document, err = s.documentRepo.WithContext(ctx).FindByID(document.ID) if err != nil { return nil, apperrors.Internal(err) } @@ -176,8 +177,8 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us } // UpdateDocument updates a document -func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.UpdateDocumentRequest) (*responses.DocumentResponse, error) { - document, err := s.documentRepo.FindByID(documentID) +func (s *DocumentService) UpdateDocument(ctx context.Context, documentID, userID uint, req *requests.UpdateDocumentRequest) (*responses.DocumentResponse, error) { + document, err := s.documentRepo.WithContext(ctx).FindByID(documentID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.document_not_found") @@ -186,7 +187,7 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests. } // Check access - hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID) if err != nil { return nil, apperrors.Internal(err) } @@ -238,12 +239,12 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests. document.TaskID = req.TaskID } - if err := s.documentRepo.Update(document); err != nil { + if err := s.documentRepo.WithContext(ctx).Update(document); err != nil { return nil, apperrors.Internal(err) } // Reload - document, err = s.documentRepo.FindByID(documentID) + document, err = s.documentRepo.WithContext(ctx).FindByID(documentID) if err != nil { return nil, apperrors.Internal(err) } @@ -253,8 +254,8 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests. } // DeleteDocument soft-deletes a document -func (s *DocumentService) DeleteDocument(documentID, userID uint) error { - document, err := s.documentRepo.FindByID(documentID) +func (s *DocumentService) DeleteDocument(ctx context.Context, documentID, userID uint) error { + document, err := s.documentRepo.WithContext(ctx).FindByID(documentID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return apperrors.NotFound("error.document_not_found") @@ -263,7 +264,7 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error { } // Check access - hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID) if err != nil { return apperrors.Internal(err) } @@ -271,7 +272,7 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error { return apperrors.Forbidden("error.document_access_denied") } - if err := s.documentRepo.Delete(documentID); err != nil { + if err := s.documentRepo.WithContext(ctx).Delete(documentID); err != nil { return apperrors.Internal(err) } @@ -279,15 +280,15 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error { } // ActivateDocument activates a document -func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.DocumentResponse, error) { +func (s *DocumentService) ActivateDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) { // First check if document exists (even if inactive) var document models.Document - if err := s.documentRepo.FindByIDIncludingInactive(documentID, &document); err != nil { + if err := s.documentRepo.WithContext(ctx).FindByIDIncludingInactive(documentID, &document); err != nil { return nil, apperrors.NotFound("error.document_not_found") } // Check access - hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID) if err != nil { return nil, apperrors.Internal(err) } @@ -295,12 +296,12 @@ func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses. return nil, apperrors.Forbidden("error.document_access_denied") } - if err := s.documentRepo.Activate(documentID); err != nil { + if err := s.documentRepo.WithContext(ctx).Activate(documentID); err != nil { return nil, apperrors.Internal(err) } // Reload - doc, err := s.documentRepo.FindByID(documentID) + doc, err := s.documentRepo.WithContext(ctx).FindByID(documentID) if err != nil { return nil, apperrors.Internal(err) } @@ -310,8 +311,8 @@ func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses. } // DeactivateDocument deactivates a document -func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*responses.DocumentResponse, error) { - document, err := s.documentRepo.FindByID(documentID) +func (s *DocumentService) DeactivateDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) { + document, err := s.documentRepo.WithContext(ctx).FindByID(documentID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.document_not_found") @@ -320,7 +321,7 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response } // Check access - hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID) if err != nil { return nil, apperrors.Internal(err) } @@ -328,7 +329,7 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response return nil, apperrors.Forbidden("error.document_access_denied") } - if err := s.documentRepo.Deactivate(documentID); err != nil { + if err := s.documentRepo.WithContext(ctx).Deactivate(documentID); err != nil { return nil, apperrors.Internal(err) } @@ -338,8 +339,8 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response } // UploadDocumentImage adds an image to an existing document -func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL, caption string) (*responses.DocumentResponse, error) { - document, err := s.documentRepo.FindByID(documentID) +func (s *DocumentService) UploadDocumentImage(ctx context.Context, documentID, userID uint, imageURL, caption string) (*responses.DocumentResponse, error) { + document, err := s.documentRepo.WithContext(ctx).FindByID(documentID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.document_not_found") @@ -348,7 +349,7 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL, } // Check access via residence - hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID) if err != nil { return nil, apperrors.Internal(err) } @@ -361,12 +362,12 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL, ImageURL: imageURL, Caption: caption, } - if err := s.documentRepo.CreateDocumentImage(img); err != nil { + if err := s.documentRepo.WithContext(ctx).CreateDocumentImage(img); err != nil { return nil, apperrors.Internal(err) } // Reload with relations - document, err = s.documentRepo.FindByID(documentID) + document, err = s.documentRepo.WithContext(ctx).FindByID(documentID) if err != nil { return nil, apperrors.Internal(err) } @@ -376,9 +377,9 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL, } // DeleteDocumentImage removes an image from a document -func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint) (*responses.DocumentResponse, error) { +func (s *DocumentService) DeleteDocumentImage(ctx context.Context, documentID, imageID, userID uint) (*responses.DocumentResponse, error) { // Find the image first - image, err := s.documentRepo.FindImageByID(imageID) + image, err := s.documentRepo.WithContext(ctx).FindImageByID(imageID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.document_image_not_found") @@ -392,7 +393,7 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint) } // Find parent document to check access - document, err := s.documentRepo.FindByID(documentID) + document, err := s.documentRepo.WithContext(ctx).FindByID(documentID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.document_not_found") @@ -401,7 +402,7 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint) } // Check access via residence - hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID) + hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID) if err != nil { return nil, apperrors.Internal(err) } @@ -409,12 +410,12 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint) return nil, apperrors.Forbidden("error.document_access_denied") } - if err := s.documentRepo.DeleteDocumentImage(imageID); err != nil { + if err := s.documentRepo.WithContext(ctx).DeleteDocumentImage(imageID); err != nil { return nil, apperrors.Internal(err) } // Reload with relations - document, err = s.documentRepo.FindByID(documentID) + document, err = s.documentRepo.WithContext(ctx).FindByID(documentID) if err != nil { return nil, apperrors.Internal(err) } diff --git a/internal/services/document_service_test.go b/internal/services/document_service_test.go index 7b6e80b..d02455d 100644 --- a/internal/services/document_service_test.go +++ b/internal/services/document_service_test.go @@ -1,6 +1,7 @@ package services import ( + "context" "net/http" "testing" @@ -42,7 +43,7 @@ func TestDocumentService_CreateDocument(t *testing.T) { FileName: "manual.pdf", } - resp, err := service.CreateDocument(req, user.ID) + resp, err := service.CreateDocument(context.Background(), req, user.ID) require.NoError(t, err) assert.NotNil(t, resp) assert.Equal(t, "Furnace Manual", resp.Title) @@ -64,7 +65,7 @@ func TestDocumentService_CreateDocument_DefaultType(t *testing.T) { // DocumentType not set — should default to "general" } - resp, err := service.CreateDocument(req, user.ID) + resp, err := service.CreateDocument(context.Background(), req, user.ID) require.NoError(t, err) assert.Equal(t, models.DocumentTypeGeneral, resp.DocumentType) } @@ -84,7 +85,7 @@ func TestDocumentService_CreateDocument_WithImages(t *testing.T) { ImageURLs: []string{"https://example.com/img1.jpg", "https://example.com/img2.jpg"}, } - resp, err := service.CreateDocument(req, user.ID) + resp, err := service.CreateDocument(context.Background(), req, user.ID) require.NoError(t, err) assert.NotNil(t, resp) assert.Equal(t, "Receipt with photos", resp.Title) @@ -105,7 +106,7 @@ func TestDocumentService_CreateDocument_AccessDenied(t *testing.T) { Title: "Unauthorized Doc", } - _, err := service.CreateDocument(req, other.ID) + _, err := service.CreateDocument(context.Background(), req, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied") } @@ -121,7 +122,7 @@ func TestDocumentService_GetDocument(t *testing.T) { residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc") - resp, err := service.GetDocument(doc.ID, user.ID) + resp, err := service.GetDocument(context.Background(), doc.ID, user.ID) require.NoError(t, err) assert.Equal(t, doc.ID, resp.ID) assert.Equal(t, "Test Doc", resp.Title) @@ -135,7 +136,7 @@ func TestDocumentService_GetDocument_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - _, err := service.GetDocument(9999, user.ID) + _, err := service.GetDocument(context.Background(), 9999, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") } @@ -150,7 +151,7 @@ func TestDocumentService_GetDocument_AccessDenied(t *testing.T) { residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc") - _, err := service.GetDocument(doc.ID, other.ID) + _, err := service.GetDocument(context.Background(), doc.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") } @@ -173,7 +174,7 @@ func TestDocumentService_UpdateDocument(t *testing.T) { Description: &newDesc, } - resp, err := service.UpdateDocument(doc.ID, user.ID, req) + resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req) require.NoError(t, err) assert.Equal(t, "Updated Title", resp.Title) assert.Equal(t, "Updated description", resp.Description) @@ -190,7 +191,7 @@ func TestDocumentService_UpdateDocument_NotFound(t *testing.T) { newTitle := "Won't Work" req := &requests.UpdateDocumentRequest{Title: &newTitle} - _, err := service.UpdateDocument(9999, user.ID, req) + _, err := service.UpdateDocument(context.Background(), 9999, user.ID, req) testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") } @@ -208,7 +209,7 @@ func TestDocumentService_UpdateDocument_AccessDenied(t *testing.T) { newTitle := "Hacked" req := &requests.UpdateDocumentRequest{Title: &newTitle} - _, err := service.UpdateDocument(doc.ID, other.ID, req) + _, err := service.UpdateDocument(context.Background(), doc.ID, other.ID, req) testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") } @@ -225,7 +226,7 @@ func TestDocumentService_UpdateDocument_ChangeType(t *testing.T) { newType := models.DocumentTypeWarranty req := &requests.UpdateDocumentRequest{DocumentType: &newType} - resp, err := service.UpdateDocument(doc.ID, user.ID, req) + resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req) require.NoError(t, err) assert.Equal(t, models.DocumentTypeWarranty, resp.DocumentType) } @@ -242,11 +243,11 @@ func TestDocumentService_DeleteDocument(t *testing.T) { residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Delete") - err := service.DeleteDocument(doc.ID, user.ID) + err := service.DeleteDocument(context.Background(), doc.ID, user.ID) require.NoError(t, err) // Should not be found after deletion - _, err = service.GetDocument(doc.ID, user.ID) + _, err = service.GetDocument(context.Background(), doc.ID, user.ID) assert.Error(t, err) } @@ -258,7 +259,7 @@ func TestDocumentService_DeleteDocument_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - err := service.DeleteDocument(9999, user.ID) + err := service.DeleteDocument(context.Background(), 9999, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") } @@ -273,7 +274,7 @@ func TestDocumentService_DeleteDocument_AccessDenied(t *testing.T) { residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc") - err := service.DeleteDocument(doc.ID, other.ID) + err := service.DeleteDocument(context.Background(), doc.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") } @@ -290,7 +291,7 @@ func TestDocumentService_ListDocuments(t *testing.T) { testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1") testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2") - resp, err := service.ListDocuments(user.ID, nil) + resp, err := service.ListDocuments(context.Background(), user.ID, nil) require.NoError(t, err) assert.Len(t, resp, 2) } @@ -303,7 +304,7 @@ func TestDocumentService_ListDocuments_NoResidences(t *testing.T) { user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123") - resp, err := service.ListDocuments(user.ID, nil) + resp, err := service.ListDocuments(context.Background(), user.ID, nil) require.NoError(t, err) assert.Empty(t, resp) } @@ -321,7 +322,7 @@ func TestDocumentService_ListDocuments_FilterByResidence(t *testing.T) { testutil.CreateTestDocument(t, db, residence2.ID, user.ID, "Doc B") filter := &repositories.DocumentFilter{ResidenceID: &residence1.ID} - resp, err := service.ListDocuments(user.ID, filter) + resp, err := service.ListDocuments(context.Background(), user.ID, filter) require.NoError(t, err) assert.Len(t, resp, 1) assert.Equal(t, "Doc A", resp[0].Title) @@ -340,7 +341,7 @@ func TestDocumentService_ListDocuments_FilterByResidence_AccessDenied(t *testing testutil.CreateTestResidence(t, db, other.ID, "Other House") filter := &repositories.DocumentFilter{ResidenceID: &residence.ID} - _, err := service.ListDocuments(other.ID, filter) + _, err := service.ListDocuments(context.Background(), other.ID, filter) testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied") } @@ -369,7 +370,7 @@ func TestDocumentService_ListWarranties(t *testing.T) { // Create a general doc testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc") - resp, err := service.ListWarranties(user.ID) + resp, err := service.ListWarranties(context.Background(), user.ID) require.NoError(t, err) assert.Len(t, resp, 1) assert.Equal(t, "HVAC Warranty", resp[0].Title) @@ -383,7 +384,7 @@ func TestDocumentService_ListWarranties_NoResidences(t *testing.T) { user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123") - resp, err := service.ListWarranties(user.ID) + resp, err := service.ListWarranties(context.Background(), user.ID) require.NoError(t, err) assert.Empty(t, resp) } @@ -400,7 +401,7 @@ func TestDocumentService_DeactivateDocument(t *testing.T) { residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Deactivate") - resp, err := service.DeactivateDocument(doc.ID, user.ID) + resp, err := service.DeactivateDocument(context.Background(), doc.ID, user.ID) require.NoError(t, err) assert.False(t, resp.IsActive) } @@ -413,7 +414,7 @@ func TestDocumentService_DeactivateDocument_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - _, err := service.DeactivateDocument(9999, user.ID) + _, err := service.DeactivateDocument(context.Background(), 9999, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") } @@ -428,7 +429,7 @@ func TestDocumentService_DeactivateDocument_AccessDenied(t *testing.T) { residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc") - _, err := service.DeactivateDocument(doc.ID, other.ID) + _, err := service.DeactivateDocument(context.Background(), doc.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") } @@ -444,7 +445,7 @@ func TestDocumentService_UploadDocumentImage(t *testing.T) { residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc") - resp, err := service.UploadDocumentImage(doc.ID, user.ID, "https://example.com/photo.jpg", "Front view") + resp, err := service.UploadDocumentImage(context.Background(), doc.ID, user.ID, "https://example.com/photo.jpg", "Front view") require.NoError(t, err) assert.NotNil(t, resp) } @@ -457,7 +458,7 @@ func TestDocumentService_UploadDocumentImage_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - _, err := service.UploadDocumentImage(9999, user.ID, "https://example.com/photo.jpg", "") + _, err := service.UploadDocumentImage(context.Background(), 9999, user.ID, "https://example.com/photo.jpg", "") testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") } @@ -472,7 +473,7 @@ func TestDocumentService_UploadDocumentImage_AccessDenied(t *testing.T) { residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc") - _, err := service.UploadDocumentImage(doc.ID, other.ID, "https://example.com/photo.jpg", "") + _, err := service.UploadDocumentImage(context.Background(), doc.ID, other.ID, "https://example.com/photo.jpg", "") testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") } @@ -496,7 +497,7 @@ func TestDocumentService_DeleteDocumentImage(t *testing.T) { err := db.Create(img).Error require.NoError(t, err) - resp, err := service.DeleteDocumentImage(doc.ID, img.ID, user.ID) + resp, err := service.DeleteDocumentImage(context.Background(), doc.ID, img.ID, user.ID) require.NoError(t, err) assert.NotNil(t, resp) } @@ -511,7 +512,7 @@ func TestDocumentService_DeleteDocumentImage_NotFound(t *testing.T) { residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc") - _, err := service.DeleteDocumentImage(doc.ID, 9999, user.ID) + _, err := service.DeleteDocumentImage(context.Background(), doc.ID, 9999, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found") } @@ -535,7 +536,7 @@ func TestDocumentService_DeleteDocumentImage_WrongDocument(t *testing.T) { require.NoError(t, err) // Try to delete the image specifying doc2 - _, err = service.DeleteDocumentImage(doc2.ID, img.ID, user.ID) + _, err = service.DeleteDocumentImage(context.Background(), doc2.ID, img.ID, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found") } @@ -557,7 +558,7 @@ func TestDocumentService_DeleteDocumentImage_AccessDenied(t *testing.T) { err := db.Create(img).Error require.NoError(t, err) - _, err = service.DeleteDocumentImage(doc.ID, img.ID, other.ID) + _, err = service.DeleteDocumentImage(context.Background(), doc.ID, img.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") } @@ -575,7 +576,7 @@ func TestDocumentService_GetDocument_SharedUserHasAccess(t *testing.T) { residenceRepo.AddUser(residence.ID, shared.ID) doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc") - resp, err := service.GetDocument(doc.ID, shared.ID) + resp, err := service.GetDocument(context.Background(), doc.ID, shared.ID) require.NoError(t, err) assert.Equal(t, "Shared Doc", resp.Title) } @@ -593,11 +594,11 @@ func TestDocumentService_ActivateDocument(t *testing.T) { doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Activate") // Deactivate first - _, err := service.DeactivateDocument(doc.ID, user.ID) + _, err := service.DeactivateDocument(context.Background(), doc.ID, user.ID) require.NoError(t, err) // Now activate - resp, err := service.ActivateDocument(doc.ID, user.ID) + resp, err := service.ActivateDocument(context.Background(), doc.ID, user.ID) require.NoError(t, err) assert.True(t, resp.IsActive) } @@ -610,7 +611,7 @@ func TestDocumentService_ActivateDocument_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - _, err := service.ActivateDocument(9999, user.ID) + _, err := service.ActivateDocument(context.Background(), 9999, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found") } @@ -625,7 +626,7 @@ func TestDocumentService_ActivateDocument_AccessDenied(t *testing.T) { residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House") doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc") - _, err := service.ActivateDocument(doc.ID, other.ID) + _, err := service.ActivateDocument(context.Background(), doc.ID, other.ID) testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied") } @@ -646,7 +647,7 @@ func TestDocumentService_CreateDocument_WithEmptyImageURL(t *testing.T) { ImageURLs: []string{"", "https://example.com/img.jpg", ""}, } - resp, err := service.CreateDocument(req, user.ID) + resp, err := service.CreateDocument(context.Background(), req, user.ID) require.NoError(t, err) assert.NotNil(t, resp) } @@ -687,7 +688,7 @@ func TestDocumentService_UpdateDocument_AllFields(t *testing.T) { ModelNumber: &newModel, } - resp, err := service.UpdateDocument(doc.ID, user.ID, req) + resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req) require.NoError(t, err) assert.Equal(t, "Updated", resp.Title) assert.Equal(t, "New description", resp.Description) @@ -720,7 +721,7 @@ func TestDocumentService_ListDocuments_FilterByType(t *testing.T) { testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc") filter := &repositories.DocumentFilter{DocumentType: string(models.DocumentTypeWarranty)} - resp, err := service.ListDocuments(user.ID, filter) + resp, err := service.ListDocuments(context.Background(), user.ID, filter) require.NoError(t, err) assert.Len(t, resp, 1) assert.Equal(t, "Warranty Doc", resp[0].Title) @@ -742,7 +743,7 @@ func TestDocumentService_SharedUser_CanUpdate(t *testing.T) { newTitle := "Updated by shared user" req := &requests.UpdateDocumentRequest{Title: &newTitle} - resp, err := service.UpdateDocument(doc.ID, shared.ID, req) + resp, err := service.UpdateDocument(context.Background(), doc.ID, shared.ID, req) require.NoError(t, err) assert.Equal(t, "Updated by shared user", resp.Title) } @@ -759,6 +760,6 @@ func TestDocumentService_SharedUser_CanDelete(t *testing.T) { residenceRepo.AddUser(residence.ID, shared.ID) doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc") - err := service.DeleteDocument(doc.ID, shared.ID) + err := service.DeleteDocument(context.Background(), doc.ID, shared.ID) require.NoError(t, err) } diff --git a/internal/services/notification_service.go b/internal/services/notification_service.go index 1ce4858..1bb7e1f 100644 --- a/internal/services/notification_service.go +++ b/internal/services/notification_service.go @@ -44,8 +44,8 @@ func NewNotificationService(notificationRepo *repositories.NotificationRepositor // === Notifications === // GetNotifications gets notifications for a user -func (s *NotificationService) GetNotifications(userID uint, limit, offset int) ([]NotificationResponse, error) { - notifications, err := s.notificationRepo.FindByUser(userID, limit, offset) +func (s *NotificationService) GetNotifications(ctx context.Context, userID uint, limit, offset int) ([]NotificationResponse, error) { + notifications, err := s.notificationRepo.WithContext(ctx).FindByUser(userID, limit, offset) if err != nil { return nil, apperrors.Internal(err) } @@ -58,8 +58,8 @@ func (s *NotificationService) GetNotifications(userID uint, limit, offset int) ( } // GetUnreadCount gets the count of unread notifications -func (s *NotificationService) GetUnreadCount(userID uint) (int64, error) { - count, err := s.notificationRepo.CountUnread(userID) +func (s *NotificationService) GetUnreadCount(ctx context.Context, userID uint) (int64, error) { + count, err := s.notificationRepo.WithContext(ctx).CountUnread(userID) if err != nil { return 0, apperrors.Internal(err) } @@ -67,8 +67,8 @@ func (s *NotificationService) GetUnreadCount(userID uint) (int64, error) { } // MarkAsRead marks a notification as read -func (s *NotificationService) MarkAsRead(notificationID, userID uint) error { - notification, err := s.notificationRepo.FindByID(notificationID) +func (s *NotificationService) MarkAsRead(ctx context.Context, notificationID, userID uint) error { + notification, err := s.notificationRepo.WithContext(ctx).FindByID(notificationID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return apperrors.NotFound("error.notification_not_found") @@ -80,15 +80,15 @@ func (s *NotificationService) MarkAsRead(notificationID, userID uint) error { return apperrors.NotFound("error.notification_not_found") } - if err := s.notificationRepo.MarkAsRead(notificationID); err != nil { + if err := s.notificationRepo.WithContext(ctx).MarkAsRead(notificationID); err != nil { return apperrors.Internal(err) } return nil } // MarkAllAsRead marks all notifications as read -func (s *NotificationService) MarkAllAsRead(userID uint) error { - if err := s.notificationRepo.MarkAllAsRead(userID); err != nil { +func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID uint) error { + if err := s.notificationRepo.WithContext(ctx).MarkAllAsRead(userID); err != nil { return apperrors.Internal(err) } return nil @@ -97,7 +97,7 @@ func (s *NotificationService) MarkAllAsRead(userID uint) error { // CreateAndSendNotification creates a notification and sends it via push func (s *NotificationService) CreateAndSendNotification(ctx context.Context, userID uint, notificationType models.NotificationType, title, body string, data map[string]interface{}) error { // Check user preferences - prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) + prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID) if err != nil { return apperrors.Internal(err) } @@ -117,12 +117,12 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use Data: string(dataJSON), } - if err := s.notificationRepo.Create(notification); err != nil { + if err := s.notificationRepo.WithContext(ctx).Create(notification); err != nil { return apperrors.Internal(err) } // Get device tokens - iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID) + iosTokens, androidTokens, err := s.notificationRepo.WithContext(ctx).GetActiveTokensForUser(userID) if err != nil { return apperrors.Internal(err) } @@ -144,12 +144,12 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use if s.pushClient != nil { err = s.pushClient.SendToAll(ctx, iosTokens, androidTokens, title, body, pushData) if err != nil { - s.notificationRepo.SetError(notification.ID, err.Error()) + s.notificationRepo.WithContext(ctx).SetError(notification.ID, err.Error()) return apperrors.Internal(err) } } - if err := s.notificationRepo.MarkAsSent(notification.ID); err != nil { + if err := s.notificationRepo.WithContext(ctx).MarkAsSent(notification.ID); err != nil { return apperrors.Internal(err) } return nil @@ -178,8 +178,8 @@ func (s *NotificationService) isNotificationEnabled(prefs *models.NotificationPr // === Notification Preferences === // GetPreferences gets notification preferences for a user -func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferencesResponse, error) { - prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) +func (s *NotificationService) GetPreferences(ctx context.Context, userID uint) (*NotificationPreferencesResponse, error) { + prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID) if err != nil { return nil, apperrors.Internal(err) } @@ -196,7 +196,7 @@ func validateHourField(val *int, fieldName string) error { } // UpdatePreferences updates notification preferences -func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) { +func (s *NotificationService) UpdatePreferences(ctx context.Context, userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) { // B-12: Validate hour fields are in range 0-23 hourFields := []struct { value *int @@ -213,7 +213,7 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen } } - prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) + prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID) if err != nil { return nil, apperrors.Internal(err) } @@ -258,7 +258,7 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen prefs.DailyDigestHour = req.DailyDigestHour } - if err := s.notificationRepo.UpdatePreferences(prefs); err != nil { + if err := s.notificationRepo.WithContext(ctx).UpdatePreferences(prefs); err != nil { return nil, apperrors.Internal(err) } @@ -268,14 +268,14 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen // UpdateUserTimezone updates the user's timezone for background job calculations. // This is called automatically when the user makes API calls (e.g., fetching tasks). // The timezone should be an IANA timezone name (e.g., "America/Los_Angeles"). -func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) { +func (s *NotificationService) UpdateUserTimezone(ctx context.Context, userID uint, timezone string) { // Validate timezone is a valid IANA name if _, err := time.LoadLocation(timezone); err != nil { return // Invalid timezone, skip silently } // Get or create preferences and update timezone - prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) + prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID) if err != nil { return // Skip silently on error } @@ -283,7 +283,7 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) { // Only update if timezone changed (avoid unnecessary DB writes) if prefs.Timezone == nil || *prefs.Timezone != timezone { prefs.Timezone = &timezone - if err := s.notificationRepo.UpdatePreferences(prefs); err != nil { + if err := s.notificationRepo.WithContext(ctx).UpdatePreferences(prefs); err != nil { log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone). Msg("Failed to update user timezone in notification preferences") } @@ -293,27 +293,27 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) { // === Device Registration === // RegisterDevice registers a device for push notifications -func (s *NotificationService) RegisterDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) { +func (s *NotificationService) RegisterDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) { switch req.Platform { case push.PlatformIOS: - return s.registerAPNSDevice(userID, req) + return s.registerAPNSDevice(ctx, userID, req) case push.PlatformAndroid: - return s.registerGCMDevice(userID, req) + return s.registerGCMDevice(ctx, userID, req) default: return nil, apperrors.BadRequest("error.invalid_platform") } } -func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) { +func (s *NotificationService) registerAPNSDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) { // Check if device exists - existing, err := s.notificationRepo.FindAPNSDeviceByToken(req.RegistrationID) + existing, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(req.RegistrationID) if err == nil { // Update existing device existing.UserID = &userID existing.Active = true existing.Name = req.Name existing.DeviceID = req.DeviceID - if err := s.notificationRepo.UpdateAPNSDevice(existing); err != nil { + if err := s.notificationRepo.WithContext(ctx).UpdateAPNSDevice(existing); err != nil { return nil, apperrors.Internal(err) } return NewAPNSDeviceResponse(existing), nil @@ -327,22 +327,22 @@ func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDevic RegistrationID: req.RegistrationID, Active: true, } - if err := s.notificationRepo.CreateAPNSDevice(device); err != nil { + if err := s.notificationRepo.WithContext(ctx).CreateAPNSDevice(device); err != nil { return nil, apperrors.Internal(err) } return NewAPNSDeviceResponse(device), nil } -func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) { +func (s *NotificationService) registerGCMDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) { // Check if device exists - existing, err := s.notificationRepo.FindGCMDeviceByToken(req.RegistrationID) + existing, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(req.RegistrationID) if err == nil { // Update existing device existing.UserID = &userID existing.Active = true existing.Name = req.Name existing.DeviceID = req.DeviceID - if err := s.notificationRepo.UpdateGCMDevice(existing); err != nil { + if err := s.notificationRepo.WithContext(ctx).UpdateGCMDevice(existing); err != nil { return nil, apperrors.Internal(err) } return NewGCMDeviceResponse(existing), nil @@ -357,20 +357,20 @@ func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDevice CloudMessageType: "FCM", Active: true, } - if err := s.notificationRepo.CreateGCMDevice(device); err != nil { + if err := s.notificationRepo.WithContext(ctx).CreateGCMDevice(device); err != nil { return nil, apperrors.Internal(err) } return NewGCMDeviceResponse(device), nil } // ListDevices lists all devices for a user -func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error) { - iosDevices, err := s.notificationRepo.FindAPNSDevicesByUser(userID) +func (s *NotificationService) ListDevices(ctx context.Context, userID uint) ([]DeviceResponse, error) { + iosDevices, err := s.notificationRepo.WithContext(ctx).FindAPNSDevicesByUser(userID) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.Internal(err) } - androidDevices, err := s.notificationRepo.FindGCMDevicesByUser(userID) + androidDevices, err := s.notificationRepo.WithContext(ctx).FindGCMDevicesByUser(userID) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.Internal(err) } @@ -387,10 +387,10 @@ func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error) // DeleteDevice deactivates a device after verifying it belongs to the requesting user. // Without ownership verification, an attacker could deactivate push notifications for other users. -func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userID uint) error { +func (s *NotificationService) DeleteDevice(ctx context.Context, deviceID uint, platform string, userID uint) error { switch platform { case push.PlatformIOS: - device, err := s.notificationRepo.FindAPNSDeviceByID(deviceID) + device, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByID(deviceID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return apperrors.NotFound("error.device_not_found") @@ -401,11 +401,11 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI if device.UserID == nil || *device.UserID != userID { return apperrors.Forbidden("error.device_access_denied") } - if err := s.notificationRepo.DeactivateAPNSDevice(deviceID); err != nil { + if err := s.notificationRepo.WithContext(ctx).DeactivateAPNSDevice(deviceID); err != nil { return apperrors.Internal(err) } case push.PlatformAndroid: - device, err := s.notificationRepo.FindGCMDeviceByID(deviceID) + device, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByID(deviceID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return apperrors.NotFound("error.device_not_found") @@ -416,7 +416,7 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI if device.UserID == nil || *device.UserID != userID { return apperrors.Forbidden("error.device_access_denied") } - if err := s.notificationRepo.DeactivateGCMDevice(deviceID); err != nil { + if err := s.notificationRepo.WithContext(ctx).DeactivateGCMDevice(deviceID); err != nil { return apperrors.Internal(err) } default: @@ -426,10 +426,10 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI } // UnregisterDevice deactivates a device by its registration token -func (s *NotificationService) UnregisterDevice(registrationID, platform string, userID uint) error { +func (s *NotificationService) UnregisterDevice(ctx context.Context, registrationID, platform string, userID uint) error { switch platform { case push.PlatformIOS: - device, err := s.notificationRepo.FindAPNSDeviceByToken(registrationID) + device, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(registrationID) if err != nil { return apperrors.NotFound("error.device_not_found") } @@ -437,11 +437,11 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string, if device.UserID == nil || *device.UserID != userID { return apperrors.NotFound("error.device_not_found") } - if err := s.notificationRepo.DeactivateAPNSDevice(device.ID); err != nil { + if err := s.notificationRepo.WithContext(ctx).DeactivateAPNSDevice(device.ID); err != nil { return apperrors.Internal(err) } case push.PlatformAndroid: - device, err := s.notificationRepo.FindGCMDeviceByToken(registrationID) + device, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(registrationID) if err != nil { return apperrors.NotFound("error.device_not_found") } @@ -449,7 +449,7 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string, if device.UserID == nil || *device.UserID != userID { return apperrors.NotFound("error.device_not_found") } - if err := s.notificationRepo.DeactivateGCMDevice(device.ID); err != nil { + if err := s.notificationRepo.WithContext(ctx).DeactivateGCMDevice(device.ID); err != nil { return apperrors.Internal(err) } default: @@ -624,7 +624,7 @@ func (s *NotificationService) CreateAndSendTaskNotification( task *models.Task, ) error { // Check user notification preferences - prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) + prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID) if err != nil { return apperrors.Internal(err) } @@ -662,12 +662,12 @@ func (s *NotificationService) CreateAndSendTaskNotification( TaskID: &task.ID, } - if err := s.notificationRepo.Create(notification); err != nil { + if err := s.notificationRepo.WithContext(ctx).Create(notification); err != nil { return apperrors.Internal(err) } // Get device tokens - iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID) + iosTokens, androidTokens, err := s.notificationRepo.WithContext(ctx).GetActiveTokensForUser(userID) if err != nil { return apperrors.Internal(err) } @@ -691,12 +691,12 @@ func (s *NotificationService) CreateAndSendTaskNotification( if s.pushClient != nil { err = s.pushClient.SendActionableNotification(ctx, iosTokens, androidTokens, title, body, pushData, iosCategoryID) if err != nil { - s.notificationRepo.SetError(notification.ID, err.Error()) + s.notificationRepo.WithContext(ctx).SetError(notification.ID, err.Error()) return apperrors.Internal(err) } } - if err := s.notificationRepo.MarkAsSent(notification.ID); err != nil { + if err := s.notificationRepo.WithContext(ctx).MarkAsSent(notification.ID); err != nil { return apperrors.Internal(err) } return nil diff --git a/internal/services/notification_service_test.go b/internal/services/notification_service_test.go index c74994d..35d27ec 100644 --- a/internal/services/notification_service_test.go +++ b/internal/services/notification_service_test.go @@ -44,7 +44,7 @@ func TestNotificationService_GetNotifications(t *testing.T) { require.NoError(t, err) } - resp, err := service.GetNotifications(user.ID, 10, 0) + resp, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Len(t, resp, 3) } @@ -56,7 +56,7 @@ func TestNotificationService_GetNotifications_Empty(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - resp, err := service.GetNotifications(user.ID, 10, 0) + resp, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Empty(t, resp) } @@ -80,7 +80,7 @@ func TestNotificationService_GetNotifications_Pagination(t *testing.T) { } // Get first 2 - resp, err := service.GetNotifications(user.ID, 2, 0) + resp, err := service.GetNotifications(context.Background(), user.ID, 2, 0) require.NoError(t, err) assert.Len(t, resp, 2) } @@ -107,7 +107,7 @@ func TestNotificationService_GetUnreadCount(t *testing.T) { require.NoError(t, err) } - count, err := service.GetUnreadCount(user.ID) + count, err := service.GetUnreadCount(context.Background(), user.ID) require.NoError(t, err) assert.Equal(t, int64(3), count) } @@ -119,7 +119,7 @@ func TestNotificationService_GetUnreadCount_Zero(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - count, err := service.GetUnreadCount(user.ID) + count, err := service.GetUnreadCount(context.Background(), user.ID) require.NoError(t, err) assert.Equal(t, int64(0), count) } @@ -142,11 +142,11 @@ func TestNotificationService_MarkAsRead(t *testing.T) { err := db.Create(notif).Error require.NoError(t, err) - err = service.MarkAsRead(notif.ID, user.ID) + err = service.MarkAsRead(context.Background(), notif.ID, user.ID) require.NoError(t, err) // Verify unread count is 0 - count, err := service.GetUnreadCount(user.ID) + count, err := service.GetUnreadCount(context.Background(), user.ID) require.NoError(t, err) assert.Equal(t, int64(0), count) } @@ -168,7 +168,7 @@ func TestNotificationService_MarkAsRead_WrongUser(t *testing.T) { err := db.Create(notif).Error require.NoError(t, err) - err = service.MarkAsRead(notif.ID, other.ID) + err = service.MarkAsRead(context.Background(), notif.ID, other.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found") } @@ -179,7 +179,7 @@ func TestNotificationService_MarkAsRead_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - err := service.MarkAsRead(9999, user.ID) + err := service.MarkAsRead(context.Background(), 9999, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found") } @@ -204,10 +204,10 @@ func TestNotificationService_MarkAllAsRead(t *testing.T) { require.NoError(t, err) } - err := service.MarkAllAsRead(user.ID) + err := service.MarkAllAsRead(context.Background(), user.ID) require.NoError(t, err) - count, err := service.GetUnreadCount(user.ID) + count, err := service.GetUnreadCount(context.Background(), user.ID) require.NoError(t, err) assert.Equal(t, int64(0), count) } @@ -229,7 +229,7 @@ func TestNotificationService_CreateAndSendNotification(t *testing.T) { require.NoError(t, err) // Verify notification was created - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Len(t, notifs, 1) assert.Equal(t, "Due Soon", notifs[0].Title) @@ -254,7 +254,7 @@ func TestNotificationService_CreateAndSendNotification_DisabledPreference(t *tes require.NoError(t, err) // Verify no notification was created (silently skipped) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Empty(t, notifs) } @@ -268,7 +268,7 @@ func TestNotificationService_GetPreferences(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - resp, err := service.GetPreferences(user.ID) + resp, err := service.GetPreferences(context.Background(), user.ID) require.NoError(t, err) assert.NotNil(t, resp) // Defaults should all be true @@ -289,7 +289,7 @@ func TestNotificationService_UpdatePreferences(t *testing.T) { TaskDueSoon: &falseVal, } - resp, err := service.UpdatePreferences(user.ID, req) + resp, err := service.UpdatePreferences(context.Background(), user.ID, req) require.NoError(t, err) assert.False(t, resp.TaskDueSoon) assert.True(t, resp.TaskOverdue) // unchanged @@ -307,7 +307,7 @@ func TestNotificationService_UpdatePreferences_InvalidHour(t *testing.T) { TaskDueSoonHour: &invalidHour, } - _, err := service.UpdatePreferences(user.ID, req) + _, err := service.UpdatePreferences(context.Background(), user.ID, req) testutil.AssertAppErrorCode(t, err, http.StatusBadRequest) } @@ -323,7 +323,7 @@ func TestNotificationService_UpdatePreferences_ValidHour(t *testing.T) { TaskDueSoonHour: &hour, } - resp, err := service.UpdatePreferences(user.ID, req) + resp, err := service.UpdatePreferences(context.Background(), user.ID, req) require.NoError(t, err) assert.Equal(t, 9, *resp.TaskDueSoonHour) } @@ -344,7 +344,7 @@ func TestNotificationService_RegisterDevice_iOS(t *testing.T) { Platform: push.PlatformIOS, } - resp, err := service.RegisterDevice(user.ID, req) + resp, err := service.RegisterDevice(context.Background(), user.ID, req) require.NoError(t, err) assert.Equal(t, "iPhone 15", resp.Name) assert.Equal(t, push.PlatformIOS, resp.Platform) @@ -365,7 +365,7 @@ func TestNotificationService_RegisterDevice_Android(t *testing.T) { Platform: push.PlatformAndroid, } - resp, err := service.RegisterDevice(user.ID, req) + resp, err := service.RegisterDevice(context.Background(), user.ID, req) require.NoError(t, err) assert.Equal(t, "Pixel 8", resp.Name) assert.Equal(t, push.PlatformAndroid, resp.Platform) @@ -386,7 +386,7 @@ func TestNotificationService_RegisterDevice_InvalidPlatform(t *testing.T) { Platform: "windows", } - _, err := service.RegisterDevice(user.ID, req) + _, err := service.RegisterDevice(context.Background(), user.ID, req) testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform") } @@ -404,12 +404,12 @@ func TestNotificationService_RegisterDevice_UpdateExisting(t *testing.T) { RegistrationID: "token-xyz", Platform: push.PlatformIOS, } - _, err := service.RegisterDevice(user.ID, req) + _, err := service.RegisterDevice(context.Background(), user.ID, req) require.NoError(t, err) // Re-register with same token (should update, not duplicate) req.Name = "iPhone 15 Pro" - resp, err := service.RegisterDevice(user.ID, req) + resp, err := service.RegisterDevice(context.Background(), user.ID, req) require.NoError(t, err) assert.Equal(t, "iPhone 15 Pro", resp.Name) } @@ -445,7 +445,7 @@ func TestNotificationService_ListDevices(t *testing.T) { err = db.Create(androidDevice).Error require.NoError(t, err) - resp, err := service.ListDevices(user.ID) + resp, err := service.ListDevices(context.Background(), user.ID) require.NoError(t, err) assert.Len(t, resp, 2) } @@ -457,7 +457,7 @@ func TestNotificationService_ListDevices_Empty(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - resp, err := service.ListDevices(user.ID) + resp, err := service.ListDevices(context.Background(), user.ID) require.NoError(t, err) assert.Empty(t, resp) } @@ -471,7 +471,7 @@ func TestDeleteDevice_InvalidPlatform_Returns400(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - err := service.DeleteDevice(1, "windows", user.ID) + err := service.DeleteDevice(context.Background(), 1, "windows", user.ID) testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform") } @@ -494,7 +494,7 @@ func TestNotificationService_UnregisterDevice_iOS(t *testing.T) { err := db.Create(device).Error require.NoError(t, err) - err = service.UnregisterDevice("reg-token-ios", push.PlatformIOS, user.ID) + err = service.UnregisterDevice(context.Background(), "reg-token-ios", push.PlatformIOS, user.ID) require.NoError(t, err) // Verify device is deactivated @@ -522,7 +522,7 @@ func TestNotificationService_UnregisterDevice_Android(t *testing.T) { err := db.Create(device).Error require.NoError(t, err) - err = service.UnregisterDevice("reg-token-android", push.PlatformAndroid, user.ID) + err = service.UnregisterDevice(context.Background(), "reg-token-android", push.PlatformAndroid, user.ID) require.NoError(t, err) var found models.GCMDevice @@ -538,7 +538,7 @@ func TestNotificationService_UnregisterDevice_NotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - err := service.UnregisterDevice("nonexistent-token", push.PlatformIOS, user.ID) + err := service.UnregisterDevice(context.Background(), "nonexistent-token", push.PlatformIOS, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found") } @@ -560,7 +560,7 @@ func TestNotificationService_UnregisterDevice_WrongUser(t *testing.T) { err := db.Create(device).Error require.NoError(t, err) - err = service.UnregisterDevice("owner-token", push.PlatformIOS, attacker.ID) + err = service.UnregisterDevice(context.Background(), "owner-token", push.PlatformIOS, attacker.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found") } @@ -571,7 +571,7 @@ func TestNotificationService_UnregisterDevice_InvalidPlatform(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - err := service.UnregisterDevice("some-token", "windows", user.ID) + err := service.UnregisterDevice(context.Background(), "some-token", "windows", user.ID) testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform") } @@ -585,7 +585,7 @@ func TestNotificationService_UpdateUserTimezone(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") // Should not panic, just silently update - service.UpdateUserTimezone(user.ID, "America/Los_Angeles") + service.UpdateUserTimezone(context.Background(), user.ID, "America/Los_Angeles") // Verify timezone was stored prefs, err := notifRepo.GetOrCreatePreferences(user.ID) @@ -602,7 +602,7 @@ func TestNotificationService_UpdateUserTimezone_Invalid(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") // Invalid timezone should be silently ignored - service.UpdateUserTimezone(user.ID, "Invalid/Timezone") + service.UpdateUserTimezone(context.Background(), user.ID, "Invalid/Timezone") prefs, err := notifRepo.GetOrCreatePreferences(user.ID) require.NoError(t, err) @@ -617,10 +617,10 @@ func TestNotificationService_UpdateUserTimezone_NoChangeSkipsWrite(t *testing.T) user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") // Set timezone - service.UpdateUserTimezone(user.ID, "America/New_York") + service.UpdateUserTimezone(context.Background(), user.ID, "America/New_York") // Set same timezone again — should be a no-op - service.UpdateUserTimezone(user.ID, "America/New_York") + service.UpdateUserTimezone(context.Background(), user.ID, "America/New_York") prefs, err := notifRepo.GetOrCreatePreferences(user.ID) require.NoError(t, err) @@ -648,7 +648,7 @@ func TestDeleteDevice_WrongUser_Returns403(t *testing.T) { require.NoError(t, err) // Attacker tries to deactivate the owner's device - err = service.DeleteDevice(device.ID, push.PlatformIOS, attacker.ID) + err = service.DeleteDevice(context.Background(), device.ID, push.PlatformIOS, attacker.ID) require.Error(t, err, "should not allow deleting another user's device") testutil.AssertAppErrorCode(t, err, http.StatusForbidden) @@ -678,7 +678,7 @@ func TestDeleteDevice_CorrectUser_Succeeds(t *testing.T) { require.NoError(t, err) // Owner deactivates their own device - err = service.DeleteDevice(device.ID, push.PlatformIOS, owner.ID) + err = service.DeleteDevice(context.Background(), device.ID, push.PlatformIOS, owner.ID) require.NoError(t, err, "owner should be able to deactivate their own device") // Verify the device is now inactive @@ -709,7 +709,7 @@ func TestDeleteDevice_WrongUser_Android_Returns403(t *testing.T) { require.NoError(t, err) // Attacker tries to deactivate the owner's Android device - err = service.DeleteDevice(device.ID, push.PlatformAndroid, attacker.ID) + err = service.DeleteDevice(context.Background(), device.ID, push.PlatformAndroid, attacker.ID) require.Error(t, err, "should not allow deleting another user's Android device") testutil.AssertAppErrorCode(t, err, http.StatusForbidden) @@ -727,7 +727,7 @@ func TestDeleteDevice_NonExistent_Returns404(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") - err := service.DeleteDevice(99999, push.PlatformIOS, user.ID) + err := service.DeleteDevice(context.Background(), 99999, push.PlatformIOS, user.ID) require.Error(t, err, "should return error for non-existent device") testutil.AssertAppErrorCode(t, err, http.StatusNotFound) } @@ -744,7 +744,7 @@ func TestNotificationService_CreateAndSend_TaskOverdue(t *testing.T) { err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Task is overdue", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Len(t, notifs, 1) assert.Equal(t, "Overdue", notifs[0].Title) @@ -760,7 +760,7 @@ func TestNotificationService_CreateAndSend_TaskCompleted(t *testing.T) { err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Len(t, notifs, 1) } @@ -775,7 +775,7 @@ func TestNotificationService_CreateAndSend_TaskAssigned(t *testing.T) { err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned to you", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Len(t, notifs, 1) } @@ -790,7 +790,7 @@ func TestNotificationService_CreateAndSend_ResidenceShared(t *testing.T) { err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Someone shared a home", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Len(t, notifs, 1) } @@ -805,7 +805,7 @@ func TestNotificationService_CreateAndSend_WarrantyExpiring(t *testing.T) { err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Expiring", "Warranty expiring soon", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Len(t, notifs, 1) } @@ -828,7 +828,7 @@ func TestNotificationService_DisabledPrefs_TaskOverdue(t *testing.T) { err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Overdue task", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Empty(t, notifs) } @@ -849,7 +849,7 @@ func TestNotificationService_DisabledPrefs_TaskCompleted(t *testing.T) { err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Empty(t, notifs) } @@ -870,7 +870,7 @@ func TestNotificationService_DisabledPrefs_TaskAssigned(t *testing.T) { err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Empty(t, notifs) } @@ -891,7 +891,7 @@ func TestNotificationService_DisabledPrefs_ResidenceShared(t *testing.T) { err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Home shared", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Empty(t, notifs) } @@ -912,7 +912,7 @@ func TestNotificationService_DisabledPrefs_WarrantyExpiring(t *testing.T) { err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Warranty", "Expiring", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Empty(t, notifs) } @@ -929,7 +929,7 @@ func TestNotificationService_CreateAndSend_UnknownTypeDefaultsEnabled(t *testing err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationType("unknown_type"), "Unknown", "Unknown notification", nil) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Len(t, notifs, 1) } @@ -953,7 +953,7 @@ func TestNotificationService_CreateAndSend_WithMixedDataTypes(t *testing.T) { err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskDueSoon, "Due Soon", "Fix faucet", data) require.NoError(t, err) - notifs, err := service.GetNotifications(user.ID, 10, 0) + notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0) require.NoError(t, err) assert.Len(t, notifs, 1) assert.NotNil(t, notifs[0].Data) @@ -985,7 +985,7 @@ func TestNotificationService_UpdatePreferences_MultipleFields(t *testing.T) { TaskOverdueHour: &hour14, } - resp, err := service.UpdatePreferences(user.ID, req) + resp, err := service.UpdatePreferences(context.Background(), user.ID, req) require.NoError(t, err) assert.False(t, resp.TaskDueSoon) assert.False(t, resp.TaskOverdue) @@ -1013,7 +1013,7 @@ func TestNotificationService_UpdatePreferences_NegativeHour(t *testing.T) { TaskOverdueHour: &negHour, } - _, err := service.UpdatePreferences(user.ID, req) + _, err := service.UpdatePreferences(context.Background(), user.ID, req) testutil.AssertAppErrorCode(t, err, http.StatusBadRequest) } @@ -1032,12 +1032,12 @@ func TestNotificationService_RegisterDevice_UpdateExistingAndroid(t *testing.T) RegistrationID: "token-android-1", Platform: push.PlatformAndroid, } - _, err := service.RegisterDevice(user.ID, req) + _, err := service.RegisterDevice(context.Background(), user.ID, req) require.NoError(t, err) // Re-register with same token but new name req.Name = "Pixel 8 Pro" - resp, err := service.RegisterDevice(user.ID, req) + resp, err := service.RegisterDevice(context.Background(), user.ID, req) require.NoError(t, err) assert.Equal(t, "Pixel 8 Pro", resp.Name) assert.Equal(t, push.PlatformAndroid, resp.Platform) @@ -1052,7 +1052,7 @@ func TestDeleteDevice_AndroidNotFound_Returns404(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - err := service.DeleteDevice(99999, push.PlatformAndroid, user.ID) + err := service.DeleteDevice(context.Background(), 99999, push.PlatformAndroid, user.ID) testutil.AssertAppErrorCode(t, err, http.StatusNotFound) } @@ -1076,7 +1076,7 @@ func TestDeleteDevice_CorrectUser_Android_Succeeds(t *testing.T) { err := db.Create(device).Error require.NoError(t, err) - err = service.DeleteDevice(device.ID, push.PlatformAndroid, owner.ID) + err = service.DeleteDevice(context.Background(), device.ID, push.PlatformAndroid, owner.ID) require.NoError(t, err) var found models.GCMDevice @@ -1106,7 +1106,7 @@ func TestNotificationService_UnregisterDevice_WrongUser_Android(t *testing.T) { err := db.Create(device).Error require.NoError(t, err) - err = service.UnregisterDevice("owner-android-token", push.PlatformAndroid, attacker.ID) + err = service.UnregisterDevice(context.Background(), "owner-android-token", push.PlatformAndroid, attacker.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found") } @@ -1119,7 +1119,7 @@ func TestNotificationService_UnregisterDevice_AndroidNotFound(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - err := service.UnregisterDevice("nonexistent-android", push.PlatformAndroid, user.ID) + err := service.UnregisterDevice(context.Background(), "nonexistent-android", push.PlatformAndroid, user.ID) testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found") } diff --git a/internal/services/residence_service.go b/internal/services/residence_service.go index 385ab4c..d507983 100644 --- a/internal/services/residence_service.go +++ b/internal/services/residence_service.go @@ -186,7 +186,7 @@ func (s *ResidenceService) getSummaryForUser(_ uint) responses.TotalSummary { func (s *ResidenceService) CreateResidence(ctx context.Context, req *requests.CreateResidenceRequest, ownerID uint) (*responses.ResidenceWithSummaryResponse, error) { // Check subscription tier limits (if subscription service is wired up) if s.subscriptionService != nil { - if err := s.subscriptionService.CheckLimit(ownerID, "properties"); err != nil { + if err := s.subscriptionService.CheckLimit(ctx, ownerID, "properties"); err != nil { return nil, err } } diff --git a/internal/services/subscription_service.go b/internal/services/subscription_service.go index 52e8db8..84ff5fe 100644 --- a/internal/services/subscription_service.go +++ b/internal/services/subscription_service.go @@ -98,8 +98,8 @@ func NewSubscriptionService( } // GetSubscription gets the subscription for a user -func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionResponse, error) { - sub, err := s.subscriptionRepo.GetOrCreate(userID) +func (s *SubscriptionService) GetSubscription(ctx context.Context, userID uint) (*SubscriptionResponse, error) { + sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID) if err != nil { return nil, apperrors.Internal(err) } @@ -107,13 +107,13 @@ func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionRespons } // GetSubscriptionStatus gets detailed subscription status including limits -func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionStatusResponse, error) { - sub, err := s.subscriptionRepo.GetOrCreate(userID) +func (s *SubscriptionService) GetSubscriptionStatus(ctx context.Context, userID uint) (*SubscriptionStatusResponse, error) { + sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID) if err != nil { return nil, apperrors.Internal(err) } - settings, err := s.subscriptionRepo.GetSettings() + settings, err := s.subscriptionRepo.WithContext(ctx).GetSettings() if err != nil { return nil, apperrors.Internal(err) } @@ -122,18 +122,18 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled { now := time.Now().UTC() trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour) - if err := s.subscriptionRepo.SetTrialDates(userID, now, trialEnd); err != nil { + if err := s.subscriptionRepo.WithContext(ctx).SetTrialDates(userID, now, trialEnd); err != nil { return nil, apperrors.Internal(err) } // Re-fetch after starting trial so response reflects the new state - sub, err = s.subscriptionRepo.FindByUserID(userID) + sub, err = s.subscriptionRepo.WithContext(ctx).FindByUserID(userID) if err != nil { return nil, apperrors.Internal(err) } } // Get all tier limits and build a map - allLimits, err := s.subscriptionRepo.GetAllTierLimits() + allLimits, err := s.subscriptionRepo.WithContext(ctx).GetAllTierLimits() if err != nil { return nil, apperrors.Internal(err) } @@ -154,7 +154,7 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS } // Get current usage - usage, err := s.getUserUsage(userID) + usage, err := s.getUserUsage(ctx, userID) if err != nil { return nil, err } @@ -204,31 +204,31 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS // getUserUsage calculates current usage for a user. // P-10: Uses CountByOwner for properties count instead of loading all owned residences. // Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)). -func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) { +func (s *SubscriptionService) getUserUsage(ctx context.Context, userID uint) (*UsageResponse, error) { // P-10: Use CountByOwner for an efficient COUNT query instead of loading all records - propertiesCount, err := s.residenceRepo.CountByOwner(userID) + propertiesCount, err := s.residenceRepo.WithContext(ctx).CountByOwner(userID) if err != nil { return nil, apperrors.Internal(err) } // Still need residence IDs for batch counting tasks/contractors/documents - residenceIDs, err := s.residenceRepo.FindResidenceIDsByOwner(userID) + residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByOwner(userID) if err != nil { return nil, apperrors.Internal(err) } // Count tasks, contractors, and documents across all residences with single queries each - tasksCount, err := s.taskRepo.CountByResidenceIDs(residenceIDs) + tasksCount, err := s.taskRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs) if err != nil { return nil, apperrors.Internal(err) } - contractorsCount, err := s.contractorRepo.CountByResidenceIDs(residenceIDs) + contractorsCount, err := s.contractorRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs) if err != nil { return nil, apperrors.Internal(err) } - documentsCount, err := s.documentRepo.CountByResidenceIDs(residenceIDs) + documentsCount, err := s.documentRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs) if err != nil { return nil, apperrors.Internal(err) } @@ -242,8 +242,8 @@ func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) } // CheckLimit checks if a user has exceeded a specific limit -func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error { - settings, err := s.subscriptionRepo.GetSettings() +func (s *SubscriptionService) CheckLimit(ctx context.Context, userID uint, limitType string) error { + settings, err := s.subscriptionRepo.WithContext(ctx).GetSettings() if err != nil { return apperrors.Internal(err) } @@ -253,7 +253,7 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error { return nil } - sub, err := s.subscriptionRepo.GetOrCreate(userID) + sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID) if err != nil { return apperrors.Internal(err) } @@ -268,12 +268,12 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error { return nil } - limits, err := s.subscriptionRepo.GetTierLimits(sub.Tier) + limits, err := s.subscriptionRepo.WithContext(ctx).GetTierLimits(sub.Tier) if err != nil { return apperrors.Internal(err) } - usage, err := s.getUserUsage(userID) + usage, err := s.getUserUsage(ctx, userID) if err != nil { return err } @@ -301,8 +301,8 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error { } // GetUpgradeTrigger gets an upgrade trigger by key -func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResponse, error) { - trigger, err := s.subscriptionRepo.GetUpgradeTrigger(key) +func (s *SubscriptionService) GetUpgradeTrigger(ctx context.Context, key string) (*UpgradeTriggerResponse, error) { + trigger, err := s.subscriptionRepo.WithContext(ctx).GetUpgradeTrigger(key) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, apperrors.NotFound("error.upgrade_trigger_not_found") @@ -314,8 +314,8 @@ func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResp // GetAllUpgradeTriggers gets all active upgrade triggers as a map keyed by trigger_key // KMM client expects Map -func (s *SubscriptionService) GetAllUpgradeTriggers() (map[string]*UpgradeTriggerDataResponse, error) { - triggers, err := s.subscriptionRepo.GetAllUpgradeTriggers() +func (s *SubscriptionService) GetAllUpgradeTriggers(ctx context.Context) (map[string]*UpgradeTriggerDataResponse, error) { + triggers, err := s.subscriptionRepo.WithContext(ctx).GetAllUpgradeTriggers() if err != nil { return nil, apperrors.Internal(err) } @@ -328,8 +328,8 @@ func (s *SubscriptionService) GetAllUpgradeTriggers() (map[string]*UpgradeTrigge } // GetFeatureBenefits gets all feature benefits -func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, error) { - benefits, err := s.subscriptionRepo.GetFeatureBenefits() +func (s *SubscriptionService) GetFeatureBenefits(ctx context.Context) ([]FeatureBenefitResponse, error) { + benefits, err := s.subscriptionRepo.WithContext(ctx).GetFeatureBenefits() if err != nil { return nil, apperrors.Internal(err) } @@ -342,13 +342,13 @@ func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, er } // GetActivePromotions gets active promotions for a user -func (s *SubscriptionService) GetActivePromotions(userID uint) ([]PromotionResponse, error) { - sub, err := s.subscriptionRepo.GetOrCreate(userID) +func (s *SubscriptionService) GetActivePromotions(ctx context.Context, userID uint) ([]PromotionResponse, error) { + sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID) if err != nil { return nil, apperrors.Internal(err) } - promotions, err := s.subscriptionRepo.GetActivePromotions(sub.Tier) + promotions, err := s.subscriptionRepo.WithContext(ctx).GetActivePromotions(sub.Tier) if err != nil { return nil, apperrors.Internal(err) } @@ -362,13 +362,13 @@ func (s *SubscriptionService) GetActivePromotions(userID uint) ([]PromotionRespo // ProcessApplePurchase processes an Apple IAP purchase // Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID) -func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) { +func (s *SubscriptionService) ProcessApplePurchase(ctx context.Context, userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) { // Store receipt/transaction data dataToStore := receiptData if dataToStore == "" { dataToStore = transactionID } - if err := s.subscriptionRepo.UpdateReceiptData(userID, dataToStore); err != nil { + if err := s.subscriptionRepo.WithContext(ctx).UpdateReceiptData(userID, dataToStore); err != nil { return nil, apperrors.Internal(err) } @@ -406,18 +406,18 @@ func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData stri log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated") // Upgrade to Pro with the validated expiration - if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil { + if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "ios"); err != nil { return nil, apperrors.Internal(err) } - return s.GetSubscription(userID) + return s.GetSubscription(ctx, userID) } // ProcessGooglePurchase processes a Google Play purchase // productID is optional but helps validate the specific subscription -func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) { +func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) { // Store purchase token first - if err := s.subscriptionRepo.UpdatePurchaseToken(userID, purchaseToken); err != nil { + if err := s.subscriptionRepo.WithContext(ctx).UpdatePurchaseToken(userID, purchaseToken); err != nil { return nil, apperrors.Internal(err) } @@ -463,25 +463,25 @@ func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken s } // Upgrade to Pro with the validated expiration - if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "android"); err != nil { + if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "android"); err != nil { return nil, apperrors.Internal(err) } - return s.GetSubscription(userID) + return s.GetSubscription(ctx, userID) } // CancelSubscription cancels a subscription (downgrades to free at end of period) -func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResponse, error) { - if err := s.subscriptionRepo.SetAutoRenew(userID, false); err != nil { +func (s *SubscriptionService) CancelSubscription(ctx context.Context, userID uint) (*SubscriptionResponse, error) { + if err := s.subscriptionRepo.WithContext(ctx).SetAutoRenew(userID, false); err != nil { return nil, apperrors.Internal(err) } - return s.GetSubscription(userID) + return s.GetSubscription(ctx, userID) } // IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription // from a different platform than the one being requested. Returns (conflict, existingPlatform, error). -func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(userID uint, requestedPlatform string) (bool, string, error) { - sub, err := s.subscriptionRepo.GetOrCreate(userID) +func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(ctx context.Context, userID uint, requestedPlatform string) (bool, string, error) { + sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID) if err != nil { return false, "", apperrors.Internal(err) } diff --git a/internal/services/subscription_service_test.go b/internal/services/subscription_service_test.go index b1c365c..c5b721d 100644 --- a/internal/services/subscription_service_test.go +++ b/internal/services/subscription_service_test.go @@ -1,6 +1,7 @@ package services import ( + "context" "testing" "time" @@ -70,7 +71,7 @@ func TestProcessApplePurchase_ClientNil_ReturnsError(t *testing.T) { googleClient: nil, } - _, err := svc.ProcessApplePurchase(user.ID, "fake-receipt", "") + _, err := svc.ProcessApplePurchase(context.Background(), user.ID, "fake-receipt", "") assert.Error(t, err, "ProcessApplePurchase should return error when Apple IAP client is nil") // Verify user was NOT upgraded to Pro @@ -109,7 +110,7 @@ func TestProcessApplePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) { } // Neither receipt data nor transaction ID - should still not grant Pro - _, err := svc.ProcessApplePurchase(user.ID, "", "") + _, err := svc.ProcessApplePurchase(context.Background(), user.ID, "", "") assert.Error(t, err, "ProcessApplePurchase should return error when client is nil, even with empty data") // Verify no upgrade happened @@ -140,7 +141,7 @@ func TestProcessGooglePurchase_ClientNil_ReturnsError(t *testing.T) { googleClient: nil, // Not configured } - _, err := svc.ProcessGooglePurchase(user.ID, "fake-token", "com.tt.honeyDue.pro.monthly") + _, err := svc.ProcessGooglePurchase(context.Background(), user.ID, "fake-token", "com.tt.honeyDue.pro.monthly") assert.Error(t, err, "ProcessGooglePurchase should return error when Google IAP client is nil") // Verify user was NOT upgraded to Pro @@ -172,7 +173,7 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) { } // With empty token - _, err := svc.ProcessGooglePurchase(user.ID, "", "") + _, err := svc.ProcessGooglePurchase(context.Background(), user.ID, "", "") assert.Error(t, err, "ProcessGooglePurchase should return error when client is nil") // Verify no upgrade happened @@ -202,7 +203,7 @@ func TestSubscriptionService_GetSubscription(t *testing.T) { user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123") - resp, err := svc.GetSubscription(user.ID) + resp, err := svc.GetSubscription(context.Background(), user.ID) require.NoError(t, err) assert.Equal(t, "free", resp.Tier) assert.False(t, resp.IsPro) @@ -238,7 +239,7 @@ func TestSubscriptionService_GetSubscription_ProUser(t *testing.T) { err := db.Create(sub).Error require.NoError(t, err) - resp, err := svc.GetSubscription(user.ID) + resp, err := svc.GetSubscription(context.Background(), user.ID) require.NoError(t, err) assert.Equal(t, "pro", resp.Tier) assert.True(t, resp.IsPro) @@ -277,7 +278,7 @@ func TestSubscriptionService_CancelSubscription(t *testing.T) { err := db.Create(sub).Error require.NoError(t, err) - resp, err := svc.CancelSubscription(user.ID) + resp, err := svc.CancelSubscription(context.Background(), user.ID) require.NoError(t, err) assert.False(t, resp.AutoRenew) } @@ -365,7 +366,7 @@ func TestIsAlreadyProFromOtherPlatform(t *testing.T) { err := db.Create(sub).Error require.NoError(t, err) - conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(user.ID, tt.requestedPlatform) + conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(context.Background(), user.ID, tt.requestedPlatform) require.NoError(t, err) assert.Equal(t, tt.wantConflict, conflict) assert.Equal(t, tt.wantPlatform, existingPlatform) diff --git a/internal/services/task_service.go b/internal/services/task_service.go index 07a1a92..193ce5a 100644 --- a/internal/services/task_service.go +++ b/internal/services/task_service.go @@ -898,7 +898,7 @@ func (s *TaskService) sendTaskCompletedNotification(ctx context.Context, task *m // Send email notification (to everyone INCLUDING the person who completed it) // Check user's email notification preferences first if s.emailService != nil && user.Email != "" && s.notificationService != nil { - prefs, prefsErr := s.notificationService.GetPreferences(user.ID) + prefs, prefsErr := s.notificationService.GetPreferences(ctx, user.ID) // LE-06: Log fail-open behavior when preferences cannot be loaded if prefsErr != nil { log.Warn(). @@ -1264,8 +1264,8 @@ func (s *TaskService) GetFrequencies(ctx context.Context) ([]responses.TaskFrequ // UpdateUserTimezone updates the user's timezone for background job calculations. // This is called from handlers when the X-Timezone header is present. // Delegates to NotificationService since timezone is stored in notification preferences. -func (s *TaskService) UpdateUserTimezone(userID uint, timezone string) { +func (s *TaskService) UpdateUserTimezone(ctx context.Context, userID uint, timezone string) { if s.notificationService != nil { - s.notificationService.UpdateUserTimezone(userID, timezone) + s.notificationService.UpdateUserTimezone(ctx, userID, timezone) } }