Migrate Auth/Contractor/Document/Notification/Subscription services to ctx
Backend CI / Test (push) Has been cancelled
Backend CI / Contract Tests (push) Has been cancelled
Backend CI / Build (push) Has been cancelled
Backend CI / Lint (push) Has been cancelled
Backend CI / Secret Scanning (push) Has been cancelled

Every public method on these five services now takes ctx context.Context as
the first arg and routes its repo calls through .WithContext(ctx). With
TaskService and ResidenceService already migrated, this means every
in-process service that touches Postgres now produces a flame graph in
Jaeger where the SQL spans nest under the parent HTTP request span.

Endpoints now fully traced (HTTP → service → SQL):
- /api/auth/login, /register, /logout, /me, /verify-email, /resend-verification
- /api/auth/forgot-password, /verify-reset, /reset-password, /update-profile
- /api/contractors/* (CRUD + favorite + by-residence + tasks)
- /api/documents/* (CRUD + activate/deactivate + image upload/delete)
- /api/notifications/* (list, count, mark-read, prefs, devices)
- /api/subscription/* (status, purchase, cancel, triggers, promotions)
- All previously-migrated /api/tasks/* and /api/residences/* paths

Internal helpers also threaded:
- TaskService.sendTaskCompletedNotification → forwards ctx
- TaskService.UpdateUserTimezone → forwards ctx to NotificationService
- ResidenceService.CreateResidence → forwards ctx to SubscriptionService.CheckLimit
- NotificationService.registerAPNSDevice / registerGCMDevice → both take ctx

~75 method signatures, ~120 handler/test call sites updated. Tests pass
green; the only failure is the pre-existing flaky TaskHandler_QuickComplete
SQLite race that fails ~60% of runs on master.

Step 3 of the observability plan is now genuinely complete: every API
endpoint backed by a Go service emits a per-request flame graph with
HTTP → service → SQL spans, plus B2/APNs/FCM/asynq spans where applicable.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-04-25 16:26:21 -05:00
parent 65a9aae4e5
commit e881d37de0
20 changed files with 529 additions and 522 deletions
+12 -12
View File
@@ -65,7 +65,7 @@ func (h *AuthHandler) Login(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
response, err := h.authService.Login(&req)
response, err := h.authService.Login(c.Request().Context(), &req)
if err != nil {
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
if h.auditService != nil {
@@ -94,7 +94,7 @@ func (h *AuthHandler) Register(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
response, confirmationCode, err := h.authService.Register(&req)
response, confirmationCode, err := h.authService.Register(c.Request().Context(), &req)
if err != nil {
log.Debug().Err(err).Msg("Registration failed")
return err
@@ -141,7 +141,7 @@ func (h *AuthHandler) Logout(c echo.Context) error {
}
// Invalidate token in database
if err := h.authService.Logout(token); err != nil {
if err := h.authService.Logout(c.Request().Context(), token); err != nil {
log.Warn().Err(err).Msg("Failed to delete token from database")
}
@@ -162,7 +162,7 @@ func (h *AuthHandler) CurrentUser(c echo.Context) error {
return err
}
response, err := h.authService.GetCurrentUser(user.ID)
response, err := h.authService.GetCurrentUser(c.Request().Context(), user.ID)
if err != nil {
log.Error().Err(err).Uint("user_id", user.ID).Msg("Failed to get current user")
return err
@@ -186,7 +186,7 @@ func (h *AuthHandler) UpdateProfile(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
response, err := h.authService.UpdateProfile(user.ID, &req)
response, err := h.authService.UpdateProfile(c.Request().Context(), user.ID, &req)
if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to update profile")
return err
@@ -210,7 +210,7 @@ func (h *AuthHandler) VerifyEmail(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
err = h.authService.VerifyEmail(user.ID, req.Code)
err = h.authService.VerifyEmail(c.Request().Context(), user.ID, req.Code)
if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Email verification failed")
return err
@@ -243,7 +243,7 @@ func (h *AuthHandler) ResendVerification(c echo.Context) error {
return err
}
code, err := h.authService.ResendVerificationCode(user.ID)
code, err := h.authService.ResendVerificationCode(c.Request().Context(), user.ID)
if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to resend verification")
return err
@@ -276,7 +276,7 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
code, user, err := h.authService.ForgotPassword(req.Email)
code, user, err := h.authService.ForgotPassword(c.Request().Context(), req.Email)
if err != nil {
var appErr *apperrors.AppError
if errors.As(err, &appErr) && appErr.Code == http.StatusTooManyRequests {
@@ -324,7 +324,7 @@ func (h *AuthHandler) VerifyResetCode(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
resetToken, err := h.authService.VerifyResetCode(req.Email, req.Code)
resetToken, err := h.authService.VerifyResetCode(c.Request().Context(), req.Email, req.Code)
if err != nil {
log.Debug().Err(err).Str("email", req.Email).Msg("Verify reset code failed")
return err
@@ -346,7 +346,7 @@ func (h *AuthHandler) ResetPassword(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
err := h.authService.ResetPassword(req.ResetToken, req.NewPassword)
err := h.authService.ResetPassword(c.Request().Context(), req.ResetToken, req.NewPassword)
if err != nil {
log.Debug().Err(err).Msg("Password reset failed")
return err
@@ -469,7 +469,7 @@ func (h *AuthHandler) RefreshToken(c echo.Context) error {
return apperrors.Unauthorized("error.not_authenticated")
}
response, err := h.authService.RefreshToken(token, user.ID)
response, err := h.authService.RefreshToken(c.Request().Context(), token, user.ID)
if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed")
return err
@@ -497,7 +497,7 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error {
return apperrors.BadRequest("error.invalid_request")
}
fileURLs, err := h.authService.DeleteAccount(user.ID, req.Password, req.Confirmation)
fileURLs, err := h.authService.DeleteAccount(c.Request().Context(), user.ID, req.Password, req.Confirmation)
if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Account deletion failed")
return err
+9 -9
View File
@@ -30,7 +30,7 @@ func (h *ContractorHandler) ListContractors(c echo.Context) error {
if err != nil {
return err
}
response, err := h.contractorService.ListContractors(user.ID)
response, err := h.contractorService.ListContractors(c.Request().Context(), user.ID)
if err != nil {
return apperrors.Internal(err)
}
@@ -48,7 +48,7 @@ func (h *ContractorHandler) GetContractor(c echo.Context) error {
return apperrors.BadRequest("error.invalid_contractor_id")
}
response, err := h.contractorService.GetContractor(uint(contractorID), user.ID)
response, err := h.contractorService.GetContractor(c.Request().Context(), uint(contractorID), user.ID)
if err != nil {
return err
}
@@ -69,7 +69,7 @@ func (h *ContractorHandler) CreateContractor(c echo.Context) error {
return err
}
response, err := h.contractorService.CreateContractor(&req, user.ID)
response, err := h.contractorService.CreateContractor(c.Request().Context(), &req, user.ID)
if err != nil {
return err
}
@@ -95,7 +95,7 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
return err
}
response, err := h.contractorService.UpdateContractor(uint(contractorID), user.ID, &req)
response, err := h.contractorService.UpdateContractor(c.Request().Context(), uint(contractorID), user.ID, &req)
if err != nil {
return err
}
@@ -113,7 +113,7 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
return apperrors.BadRequest("error.invalid_contractor_id")
}
err = h.contractorService.DeleteContractor(uint(contractorID), user.ID)
err = h.contractorService.DeleteContractor(c.Request().Context(), uint(contractorID), user.ID)
if err != nil {
return err
}
@@ -131,7 +131,7 @@ func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
return apperrors.BadRequest("error.invalid_contractor_id")
}
response, err := h.contractorService.ToggleFavorite(uint(contractorID), user.ID)
response, err := h.contractorService.ToggleFavorite(c.Request().Context(), uint(contractorID), user.ID)
if err != nil {
return err
}
@@ -149,7 +149,7 @@ func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
return apperrors.BadRequest("error.invalid_contractor_id")
}
response, err := h.contractorService.GetContractorTasks(uint(contractorID), user.ID)
response, err := h.contractorService.GetContractorTasks(c.Request().Context(), uint(contractorID), user.ID)
if err != nil {
return err
}
@@ -167,7 +167,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
return apperrors.BadRequest("error.invalid_residence_id")
}
response, err := h.contractorService.ListContractorsByResidence(uint(residenceID), user.ID)
response, err := h.contractorService.ListContractorsByResidence(c.Request().Context(), uint(residenceID), user.ID)
if err != nil {
return err
}
@@ -176,7 +176,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
// GetSpecialties handles GET /api/contractors/specialties/
func (h *ContractorHandler) GetSpecialties(c echo.Context) error {
specialties, err := h.contractorService.GetSpecialties()
specialties, err := h.contractorService.GetSpecialties(c.Request().Context())
if err != nil {
return apperrors.Internal(err)
}
+10 -10
View File
@@ -70,7 +70,7 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error {
}
}
response, err := h.documentService.ListDocuments(user.ID, filter)
response, err := h.documentService.ListDocuments(c.Request().Context(), user.ID, filter)
if err != nil {
return err
}
@@ -88,7 +88,7 @@ func (h *DocumentHandler) GetDocument(c echo.Context) error {
return apperrors.BadRequest("error.invalid_document_id")
}
response, err := h.documentService.GetDocument(uint(documentID), user.ID)
response, err := h.documentService.GetDocument(c.Request().Context(), uint(documentID), user.ID)
if err != nil {
return err
}
@@ -101,7 +101,7 @@ func (h *DocumentHandler) ListWarranties(c echo.Context) error {
if err != nil {
return err
}
response, err := h.documentService.ListWarranties(user.ID)
response, err := h.documentService.ListWarranties(c.Request().Context(), user.ID)
if err != nil {
return apperrors.Internal(err)
}
@@ -222,7 +222,7 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
return err
}
response, err := h.documentService.CreateDocument(&req, user.ID)
response, err := h.documentService.CreateDocument(c.Request().Context(), &req, user.ID)
if err != nil {
return err
}
@@ -248,7 +248,7 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
return err
}
response, err := h.documentService.UpdateDocument(uint(documentID), user.ID, &req)
response, err := h.documentService.UpdateDocument(c.Request().Context(), uint(documentID), user.ID, &req)
if err != nil {
return err
}
@@ -266,7 +266,7 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
return apperrors.BadRequest("error.invalid_document_id")
}
err = h.documentService.DeleteDocument(uint(documentID), user.ID)
err = h.documentService.DeleteDocument(c.Request().Context(), uint(documentID), user.ID)
if err != nil {
return err
}
@@ -284,7 +284,7 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
return apperrors.BadRequest("error.invalid_document_id")
}
response, err := h.documentService.ActivateDocument(uint(documentID), user.ID)
response, err := h.documentService.ActivateDocument(c.Request().Context(), uint(documentID), user.ID)
if err != nil {
return err
}
@@ -302,7 +302,7 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
return apperrors.BadRequest("error.invalid_document_id")
}
response, err := h.documentService.DeactivateDocument(uint(documentID), user.ID)
response, err := h.documentService.DeactivateDocument(c.Request().Context(), uint(documentID), user.ID)
if err != nil {
return err
}
@@ -349,7 +349,7 @@ func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
caption := c.FormValue("caption")
response, err := h.documentService.UploadDocumentImage(uint(documentID), user.ID, result.URL, caption)
response, err := h.documentService.UploadDocumentImage(c.Request().Context(), uint(documentID), user.ID, result.URL, caption)
if err != nil {
return err
}
@@ -372,7 +372,7 @@ func (h *DocumentHandler) DeleteDocumentImage(c echo.Context) error {
return apperrors.BadRequest("error.invalid_image_id")
}
response, err := h.documentService.DeleteDocumentImage(uint(documentID), uint(imageID), user.ID)
response, err := h.documentService.DeleteDocumentImage(c.Request().Context(), uint(documentID), uint(imageID), user.ID)
if err != nil {
return err
}
+10 -10
View File
@@ -46,7 +46,7 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
}
}
notifications, err := h.notificationService.GetNotifications(user.ID, limit, offset)
notifications, err := h.notificationService.GetNotifications(c.Request().Context(), user.ID, limit, offset)
if err != nil {
return err
}
@@ -64,7 +64,7 @@ func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
return err
}
count, err := h.notificationService.GetUnreadCount(user.ID)
count, err := h.notificationService.GetUnreadCount(c.Request().Context(), user.ID)
if err != nil {
return err
}
@@ -84,7 +84,7 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
return apperrors.BadRequest("error.invalid_notification_id")
}
err = h.notificationService.MarkAsRead(uint(notificationID), user.ID)
err = h.notificationService.MarkAsRead(c.Request().Context(), uint(notificationID), user.ID)
if err != nil {
return err
}
@@ -99,7 +99,7 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
return err
}
err = h.notificationService.MarkAllAsRead(user.ID)
err = h.notificationService.MarkAllAsRead(c.Request().Context(), user.ID)
if err != nil {
return err
}
@@ -114,7 +114,7 @@ func (h *NotificationHandler) GetPreferences(c echo.Context) error {
return err
}
prefs, err := h.notificationService.GetPreferences(user.ID)
prefs, err := h.notificationService.GetPreferences(c.Request().Context(), user.ID)
if err != nil {
return err
}
@@ -137,7 +137,7 @@ func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
return err
}
prefs, err := h.notificationService.UpdatePreferences(user.ID, &req)
prefs, err := h.notificationService.UpdatePreferences(c.Request().Context(), user.ID, &req)
if err != nil {
return err
}
@@ -160,7 +160,7 @@ func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
return err
}
device, err := h.notificationService.RegisterDevice(user.ID, &req)
device, err := h.notificationService.RegisterDevice(c.Request().Context(), user.ID, &req)
if err != nil {
return err
}
@@ -175,7 +175,7 @@ func (h *NotificationHandler) ListDevices(c echo.Context) error {
return err
}
devices, err := h.notificationService.ListDevices(user.ID)
devices, err := h.notificationService.ListDevices(c.Request().Context(), user.ID)
if err != nil {
return err
}
@@ -208,7 +208,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
return apperrors.BadRequest("error.invalid_platform")
}
err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
err = h.notificationService.UnregisterDevice(c.Request().Context(), req.RegistrationID, req.Platform, user.ID)
if err != nil {
return err
}
@@ -236,7 +236,7 @@ func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
return apperrors.BadRequest("error.invalid_platform")
}
err = h.notificationService.DeleteDevice(uint(deviceID), platform, user.ID)
err = h.notificationService.DeleteDevice(c.Request().Context(), uint(deviceID), platform, user.ID)
if err != nil {
return err
}
+1 -1
View File
@@ -106,7 +106,7 @@ func (h *StaticDataHandler) GetStaticData(c echo.Context) error {
return err
}
contractorSpecialties, err := h.contractorService.GetSpecialties()
contractorSpecialties, err := h.contractorService.GetSpecialties(c.Request().Context())
if err != nil {
return err
}
+12 -12
View File
@@ -32,7 +32,7 @@ func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
return err
}
subscription, err := h.subscriptionService.GetSubscription(user.ID)
subscription, err := h.subscriptionService.GetSubscription(c.Request().Context(), user.ID)
if err != nil {
return err
}
@@ -47,7 +47,7 @@ func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error {
return err
}
status, err := h.subscriptionService.GetSubscriptionStatus(user.ID)
status, err := h.subscriptionService.GetSubscriptionStatus(c.Request().Context(), user.ID)
if err != nil {
return err
}
@@ -59,7 +59,7 @@ func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error {
func (h *SubscriptionHandler) GetUpgradeTrigger(c echo.Context) error {
key := c.Param("key")
trigger, err := h.subscriptionService.GetUpgradeTrigger(key)
trigger, err := h.subscriptionService.GetUpgradeTrigger(c.Request().Context(), key)
if err != nil {
return err
}
@@ -69,7 +69,7 @@ func (h *SubscriptionHandler) GetUpgradeTrigger(c echo.Context) error {
// GetAllUpgradeTriggers handles GET /api/subscription/upgrade-triggers/
func (h *SubscriptionHandler) GetAllUpgradeTriggers(c echo.Context) error {
triggers, err := h.subscriptionService.GetAllUpgradeTriggers()
triggers, err := h.subscriptionService.GetAllUpgradeTriggers(c.Request().Context())
if err != nil {
return err
}
@@ -79,7 +79,7 @@ func (h *SubscriptionHandler) GetAllUpgradeTriggers(c echo.Context) error {
// GetFeatureBenefits handles GET /api/subscription/features/
func (h *SubscriptionHandler) GetFeatureBenefits(c echo.Context) error {
benefits, err := h.subscriptionService.GetFeatureBenefits()
benefits, err := h.subscriptionService.GetFeatureBenefits(c.Request().Context())
if err != nil {
return err
}
@@ -94,7 +94,7 @@ func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
return err
}
promotions, err := h.subscriptionService.GetActivePromotions(user.ID)
promotions, err := h.subscriptionService.GetActivePromotions(c.Request().Context(), user.ID)
if err != nil {
return err
}
@@ -125,12 +125,12 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
if req.TransactionID == "" && req.ReceiptData == "" {
return apperrors.BadRequest("error.receipt_data_required")
}
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID)
subscription, err = h.subscriptionService.ProcessApplePurchase(c.Request().Context(), user.ID, req.ReceiptData, req.TransactionID)
case "android":
if req.PurchaseToken == "" {
return apperrors.BadRequest("error.purchase_token_required")
}
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID)
subscription, err = h.subscriptionService.ProcessGooglePurchase(c.Request().Context(), user.ID, req.PurchaseToken, req.ProductID)
default:
return apperrors.BadRequest("error.invalid_platform")
}
@@ -152,7 +152,7 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
return err
}
subscription, err := h.subscriptionService.CancelSubscription(user.ID)
subscription, err := h.subscriptionService.CancelSubscription(c.Request().Context(), user.ID)
if err != nil {
return err
}
@@ -187,12 +187,12 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
if req.ReceiptData == "" && req.TransactionID == "" {
return apperrors.BadRequest("error.receipt_data_required")
}
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID)
subscription, err = h.subscriptionService.ProcessApplePurchase(c.Request().Context(), user.ID, req.ReceiptData, req.TransactionID)
case "android":
if req.PurchaseToken == "" {
return apperrors.BadRequest("error.purchase_token_required")
}
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID)
subscription, err = h.subscriptionService.ProcessGooglePurchase(c.Request().Context(), user.ID, req.PurchaseToken, req.ProductID)
default:
return apperrors.BadRequest("error.invalid_platform")
}
@@ -220,7 +220,7 @@ func (h *SubscriptionHandler) CreateCheckoutSession(c echo.Context) error {
}
// Check if already Pro from another platform
alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(user.ID, "stripe")
alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(c.Request().Context(), user.ID, "stripe")
if err != nil {
return err
}
+1 -1
View File
@@ -42,7 +42,7 @@ func (h *TaskHandler) ListTasks(c echo.Context) error {
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
cachedTZ, _ := c.Get("user_timezone").(string)
if cachedTZ != tzHeader {
h.taskService.UpdateUserTimezone(user.ID, tzHeader)
h.taskService.UpdateUserTimezone(c.Request().Context(), user.ID, tzHeader)
c.Set("user_timezone", tzHeader)
}
}
+8 -7
View File
@@ -1,6 +1,7 @@
package services
import (
"context"
"testing"
"time"
@@ -74,7 +75,7 @@ func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
require.NoError(t, err)
assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token")
assert.Contains(t, resp.Message, "still valid")
@@ -87,7 +88,7 @@ func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
require.NoError(t, err)
assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
assert.Contains(t, resp.Message, "refreshed")
@@ -114,7 +115,7 @@ func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.token_expired")
@@ -129,7 +130,7 @@ func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
require.NoError(t, err)
assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed")
}
@@ -140,7 +141,7 @@ func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken("nonexistent-token-key", user.ID)
resp, err := svc.RefreshToken(context.Background(), "nonexistent-token-key", user.ID)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token")
@@ -154,7 +155,7 @@ func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
svc := newTestAuthService(db)
// Try to refresh with a different user ID
resp, err := svc.RefreshToken(token.Key, user.ID+999)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID+999)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token")
@@ -167,7 +168,7 @@ func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
require.NoError(t, err)
assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed")
}
+85 -85
View File
@@ -57,14 +57,14 @@ func (s *AuthService) SetNotificationRepository(notificationRepo *repositories.N
}
// Login authenticates a user and returns a token
func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginResponse, error) {
func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest) (*responses.LoginResponse, error) {
// Find user by username or email
identifier := req.Username
if identifier == "" {
identifier = req.Email
}
user, err := s.userRepo.FindByUsernameOrEmail(identifier)
user, err := s.userRepo.WithContext(ctx).FindByUsernameOrEmail(identifier)
if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) {
return nil, apperrors.Unauthorized("error.invalid_credentials")
@@ -83,13 +83,13 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
}
// Get or create auth token
token, err := s.userRepo.GetOrCreateToken(user.ID)
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
// Update last login
if err := s.userRepo.UpdateLastLogin(user.ID); err != nil {
if err := s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID); err != nil {
// Log error but don't fail the login
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login")
}
@@ -103,9 +103,9 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
// Register creates a new user account.
// F-10: User creation, profile creation, notification preferences, and confirmation code
// are wrapped in a transaction for atomicity.
func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
func (s *AuthService) Register(ctx context.Context, req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
// Check if username exists
exists, err := s.userRepo.ExistsByUsername(req.Username)
exists, err := s.userRepo.WithContext(ctx).ExistsByUsername(req.Username)
if err != nil {
return nil, "", apperrors.Internal(err)
}
@@ -114,7 +114,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
}
// Check if email exists
exists, err = s.userRepo.ExistsByEmail(req.Email)
exists, err = s.userRepo.WithContext(ctx).ExistsByEmail(req.Email)
if err != nil {
return nil, "", apperrors.Internal(err)
}
@@ -146,7 +146,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
// Wrap user creation + profile + notification preferences + confirmation code in a transaction
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error {
txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error {
// Save user
if err := txRepo.Create(user); err != nil {
return err
@@ -159,7 +159,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration")
}
}
@@ -176,7 +176,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
}
// Create auth token (outside transaction since token generation is idempotent)
token, err := s.userRepo.GetOrCreateToken(user.ID)
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil {
return nil, "", apperrors.Internal(err)
}
@@ -192,7 +192,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
// - If token is expired (> expiryDays old), returns error (must re-login).
// - If token is in the renewal window (> refreshDays old), generates a new token.
// - If token is still fresh (< refreshDays old), returns the existing token (no-op).
func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) {
func (s *AuthService) RefreshToken(ctx context.Context, tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) {
expiryDays := s.cfg.Security.TokenExpiryDays
if expiryDays <= 0 {
expiryDays = 90
@@ -203,7 +203,7 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref
}
// Look up the token
authToken, err := s.userRepo.FindTokenByKey(tokenKey)
authToken, err := s.userRepo.WithContext(ctx).FindTokenByKey(tokenKey)
if err != nil {
return nil, apperrors.Unauthorized("error.invalid_token")
}
@@ -232,12 +232,12 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref
// Token is in the renewal window — generate a new one
// Delete the old token
if err := s.userRepo.DeleteToken(tokenKey); err != nil {
if err := s.userRepo.WithContext(ctx).DeleteToken(tokenKey); err != nil {
log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh")
}
// Create a new token
newToken, err := s.userRepo.CreateToken(userID)
newToken, err := s.userRepo.WithContext(ctx).CreateToken(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -249,18 +249,18 @@ func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.Ref
}
// Logout invalidates a user's token
func (s *AuthService) Logout(token string) error {
return s.userRepo.DeleteToken(token)
func (s *AuthService) Logout(ctx context.Context, token string) error {
return s.userRepo.WithContext(ctx).DeleteToken(token)
}
// GetCurrentUser returns the current authenticated user with profile
func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserResponse, error) {
user, err := s.userRepo.FindByIDWithProfile(userID)
func (s *AuthService) GetCurrentUser(ctx context.Context, userID uint) (*responses.CurrentUserResponse, error) {
user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(userID)
if err != nil {
return nil, err
}
authProvider, err := s.userRepo.FindAuthProvider(userID)
authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID)
if err != nil {
// Log but don't fail - default to "email"
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider")
@@ -275,9 +275,9 @@ func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserRespons
// For email auth users, password verification is required.
// For social auth users, confirmation string "DELETE" is required.
// Returns a list of file URLs that need to be deleted from disk.
func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string) ([]string, error) {
func (s *AuthService) DeleteAccount(ctx context.Context, userID uint, password, confirmation *string) ([]string, error) {
// Fetch user
user, err := s.userRepo.FindByID(userID)
user, err := s.userRepo.WithContext(ctx).FindByID(userID)
if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) {
return nil, apperrors.NotFound("error.user_not_found")
@@ -286,7 +286,7 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string)
}
// Determine auth provider
authProvider, err := s.userRepo.FindAuthProvider(userID)
authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -308,7 +308,7 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string)
// Start transaction and cascade delete
var fileURLs []string
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error {
txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error {
urls, err := txRepo.DeleteUserCascade(userID)
if err != nil {
return err
@@ -324,15 +324,15 @@ func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string)
}
// UpdateProfile updates a user's profile
func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) {
user, err := s.userRepo.FindByID(userID)
func (s *AuthService) UpdateProfile(ctx context.Context, userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) {
user, err := s.userRepo.WithContext(ctx).FindByID(userID)
if err != nil {
return nil, err
}
// Check if new email is taken (if email is being changed)
if req.Email != nil && *req.Email != user.Email {
exists, err := s.userRepo.ExistsByEmail(*req.Email)
exists, err := s.userRepo.WithContext(ctx).ExistsByEmail(*req.Email)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -349,17 +349,17 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ
user.LastName = *req.LastName
}
if err := s.userRepo.Update(user); err != nil {
if err := s.userRepo.WithContext(ctx).Update(user); err != nil {
return nil, apperrors.Internal(err)
}
// Reload with profile
user, err = s.userRepo.FindByIDWithProfile(userID)
user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(userID)
if err != nil {
return nil, err
}
authProvider, err := s.userRepo.FindAuthProvider(userID)
authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID)
if err != nil {
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider")
authProvider = "email"
@@ -370,9 +370,9 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ
}
// VerifyEmail verifies a user's email with a confirmation code
func (s *AuthService) VerifyEmail(userID uint, code string) error {
func (s *AuthService) VerifyEmail(ctx context.Context, userID uint, code string) error {
// Get user profile
profile, err := s.userRepo.GetOrCreateProfile(userID)
profile, err := s.userRepo.WithContext(ctx).GetOrCreateProfile(userID)
if err != nil {
return apperrors.Internal(err)
}
@@ -384,14 +384,14 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
// Check for test code when DEBUG_FIXED_CODES is enabled
if s.cfg.Server.DebugFixedCodes && code == "123456" {
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil {
return apperrors.Internal(err)
}
return nil
}
// Find and validate confirmation code
confirmCode, err := s.userRepo.FindConfirmationCode(userID, code)
confirmCode, err := s.userRepo.WithContext(ctx).FindConfirmationCode(userID, code)
if err != nil {
if errors.Is(err, repositories.ErrCodeNotFound) {
return apperrors.BadRequest("error.invalid_verification_code")
@@ -403,12 +403,12 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
}
// Mark code as used
if err := s.userRepo.MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
if err := s.userRepo.WithContext(ctx).MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
return apperrors.Internal(err)
}
// Set profile as verified
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil {
return apperrors.Internal(err)
}
@@ -416,9 +416,9 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
}
// ResendVerificationCode creates and returns a new verification code
func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
func (s *AuthService) ResendVerificationCode(ctx context.Context, userID uint) (string, error) {
// Get user profile
profile, err := s.userRepo.GetOrCreateProfile(userID)
profile, err := s.userRepo.WithContext(ctx).GetOrCreateProfile(userID)
if err != nil {
return "", apperrors.Internal(err)
}
@@ -437,7 +437,7 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
}
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
if _, err := s.userRepo.CreateConfirmationCode(userID, code, expiresAt); err != nil {
if _, err := s.userRepo.WithContext(ctx).CreateConfirmationCode(userID, code, expiresAt); err != nil {
return "", apperrors.Internal(err)
}
@@ -445,9 +445,9 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
}
// ForgotPassword initiates the password reset process
func (s *AuthService) ForgotPassword(email string) (string, *models.User, error) {
func (s *AuthService) ForgotPassword(ctx context.Context, email string) (string, *models.User, error) {
// Find user by email
user, err := s.userRepo.FindByEmail(email)
user, err := s.userRepo.WithContext(ctx).FindByEmail(email)
if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) {
// Don't reveal that the email doesn't exist
@@ -457,7 +457,7 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
}
// Check rate limit
count, err := s.userRepo.CountRecentPasswordResetRequests(user.ID)
count, err := s.userRepo.WithContext(ctx).CountRecentPasswordResetRequests(user.ID)
if err != nil {
return "", nil, apperrors.Internal(err)
}
@@ -481,7 +481,7 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
return "", nil, apperrors.Internal(err)
}
if _, err := s.userRepo.CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil {
if _, err := s.userRepo.WithContext(ctx).CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil {
return "", nil, apperrors.Internal(err)
}
@@ -489,9 +489,9 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
}
// VerifyResetCode verifies a password reset code and returns a reset token
func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
func (s *AuthService) VerifyResetCode(ctx context.Context, email, code string) (string, error) {
// Find the reset code
resetCode, user, err := s.userRepo.FindPasswordResetCodeByEmail(email)
resetCode, user, err := s.userRepo.WithContext(ctx).FindPasswordResetCodeByEmail(email)
if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) || errors.Is(err, repositories.ErrCodeNotFound) {
return "", apperrors.BadRequest("error.invalid_verification_code")
@@ -507,7 +507,7 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
// Verify the code
if !resetCode.CheckCode(code) {
// Increment attempts
s.userRepo.IncrementResetCodeAttempts(resetCode.ID)
s.userRepo.WithContext(ctx).IncrementResetCodeAttempts(resetCode.ID)
return "", apperrors.BadRequest("error.invalid_verification_code")
}
@@ -528,9 +528,9 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
}
// ResetPassword resets the user's password using a reset token
func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
func (s *AuthService) ResetPassword(ctx context.Context, resetToken, newPassword string) error {
// Find the reset code by token
resetCode, err := s.userRepo.FindPasswordResetCodeByToken(resetToken)
resetCode, err := s.userRepo.WithContext(ctx).FindPasswordResetCodeByToken(resetToken)
if err != nil {
if errors.Is(err, repositories.ErrCodeNotFound) || errors.Is(err, repositories.ErrCodeExpired) {
return apperrors.BadRequest("error.invalid_reset_token")
@@ -539,7 +539,7 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
}
// Get the user
user, err := s.userRepo.FindByID(resetCode.UserID)
user, err := s.userRepo.WithContext(ctx).FindByID(resetCode.UserID)
if err != nil {
return apperrors.Internal(err)
}
@@ -549,18 +549,18 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
return apperrors.Internal(err)
}
if err := s.userRepo.Update(user); err != nil {
if err := s.userRepo.WithContext(ctx).Update(user); err != nil {
return apperrors.Internal(err)
}
// Mark reset code as used
if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
if err := s.userRepo.WithContext(ctx).MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
// Log error but don't fail
log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used")
}
// Invalidate all existing tokens for this user (security measure)
if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil {
if err := s.userRepo.WithContext(ctx).DeleteTokenByUserID(user.ID); err != nil {
// Log error but don't fail
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset")
}
@@ -583,10 +583,10 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
}
// 2. Check if this Apple ID is already linked to an account
existingAuth, err := s.userRepo.FindByAppleID(appleID)
existingAuth, err := s.userRepo.WithContext(ctx).FindByAppleID(appleID)
if err == nil && existingAuth != nil {
// User already linked with this Apple ID - log them in
user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID)
user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(existingAuth.UserID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -596,13 +596,13 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
}
// Get or create token
token, err := s.userRepo.GetOrCreateToken(user.ID)
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
// Update last login
_ = s.userRepo.UpdateLastLogin(user.ID)
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
return &responses.AppleSignInResponse{
Token: token.Key,
@@ -614,7 +614,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
// 3. Check if email matches an existing user (for account linking)
email := getEmailFromRequest(req.Email, claims.Email)
if email != "" {
existingUser, err := s.userRepo.FindByEmail(email)
existingUser, err := s.userRepo.WithContext(ctx).FindByEmail(email)
if err == nil && existingUser != nil {
// S-06: Log auto-linking of social account to existing user
log.Warn().
@@ -630,24 +630,24 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
Email: email,
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
}
if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil {
if err := s.userRepo.WithContext(ctx).CreateAppleSocialAuth(appleAuthRecord); err != nil {
return nil, apperrors.Internal(err)
}
// Mark as verified since Apple verified the email
_ = s.userRepo.SetProfileVerified(existingUser.ID, true)
_ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true)
// Get or create token
token, err := s.userRepo.GetOrCreateToken(existingUser.ID)
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
// Update last login
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(existingUser.ID)
// B-08: Check error from FindByIDWithProfile
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
existingUser, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(existingUser.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -675,19 +675,19 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
randomPassword := generateResetToken()
_ = user.SetPassword(randomPassword)
if err := s.userRepo.Create(user); err != nil {
if err := s.userRepo.WithContext(ctx).Create(user); err != nil {
return nil, apperrors.Internal(err)
}
// Create profile (already verified since Apple verified)
profile, _ := s.userRepo.GetOrCreateProfile(user.ID)
profile, _ := s.userRepo.WithContext(ctx).GetOrCreateProfile(user.ID)
if profile != nil {
_ = s.userRepo.SetProfileVerified(user.ID, true)
_ = s.userRepo.WithContext(ctx).SetProfileVerified(user.ID, true)
}
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user")
}
}
@@ -699,18 +699,18 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
Email: getEmailOrDefault(email),
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
}
if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil {
if err := s.userRepo.WithContext(ctx).CreateAppleSocialAuth(appleAuthRecord); err != nil {
return nil, apperrors.Internal(err)
}
// Create token
token, err := s.userRepo.GetOrCreateToken(user.ID)
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
// B-08: Check error from FindByIDWithProfile
user, err = s.userRepo.FindByIDWithProfile(user.ID)
user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -736,10 +736,10 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
}
// 2. Check if this Google ID is already linked to an account
existingAuth, err := s.userRepo.FindByGoogleID(googleID)
existingAuth, err := s.userRepo.WithContext(ctx).FindByGoogleID(googleID)
if err == nil && existingAuth != nil {
// User already linked with this Google ID - log them in
user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID)
user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(existingAuth.UserID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -749,13 +749,13 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
}
// Get or create token
token, err := s.userRepo.GetOrCreateToken(user.ID)
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
// Update last login
_ = s.userRepo.UpdateLastLogin(user.ID)
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
return &responses.GoogleSignInResponse{
Token: token.Key,
@@ -767,7 +767,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
// 3. Check if email matches an existing user (for account linking)
email := tokenInfo.Email
if email != "" {
existingUser, err := s.userRepo.FindByEmail(email)
existingUser, err := s.userRepo.WithContext(ctx).FindByEmail(email)
if err == nil && existingUser != nil {
// S-06: Log auto-linking of social account to existing user
log.Warn().
@@ -784,26 +784,26 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
Name: tokenInfo.Name,
Picture: tokenInfo.Picture,
}
if err := s.userRepo.CreateGoogleSocialAuth(googleAuthRecord); err != nil {
if err := s.userRepo.WithContext(ctx).CreateGoogleSocialAuth(googleAuthRecord); err != nil {
return nil, apperrors.Internal(err)
}
// Mark as verified since Google verified the email
if tokenInfo.IsEmailVerified() {
_ = s.userRepo.SetProfileVerified(existingUser.ID, true)
_ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true)
}
// Get or create token
token, err := s.userRepo.GetOrCreateToken(existingUser.ID)
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
// Update last login
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(existingUser.ID)
// B-08: Check error from FindByIDWithProfile
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
existingUser, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(existingUser.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -831,19 +831,19 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
randomPassword := generateResetToken()
_ = user.SetPassword(randomPassword)
if err := s.userRepo.Create(user); err != nil {
if err := s.userRepo.WithContext(ctx).Create(user); err != nil {
return nil, apperrors.Internal(err)
}
// Create profile (already verified if Google verified email)
profile, _ := s.userRepo.GetOrCreateProfile(user.ID)
profile, _ := s.userRepo.WithContext(ctx).GetOrCreateProfile(user.ID)
if profile != nil && tokenInfo.IsEmailVerified() {
_ = s.userRepo.SetProfileVerified(user.ID, true)
_ = s.userRepo.WithContext(ctx).SetProfileVerified(user.ID, true)
}
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user")
}
}
@@ -856,18 +856,18 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
Name: tokenInfo.Name,
Picture: tokenInfo.Picture,
}
if err := s.userRepo.CreateGoogleSocialAuth(googleAuthRecord); err != nil {
if err := s.userRepo.WithContext(ctx).CreateGoogleSocialAuth(googleAuthRecord); err != nil {
return nil, apperrors.Internal(err)
}
// Create token
token, err := s.userRepo.GetOrCreateToken(user.ID)
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
// B-08: Check error from FindByIDWithProfile
user, err = s.userRepo.FindByIDWithProfile(user.ID)
user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
+54 -53
View File
@@ -1,6 +1,7 @@
package services
import (
"context"
"net/http"
"testing"
"time"
@@ -53,7 +54,7 @@ func TestAuthService_Login(t *testing.T) {
Password: "Password123",
}
resp, err := service.Login(req)
resp, err := service.Login(context.Background(), req)
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.Equal(t, "testuser", resp.User.Username)
@@ -74,7 +75,7 @@ func TestAuthService_Login_ByEmail(t *testing.T) {
Password: "Password123",
}
resp, err := service.Login(req)
resp, err := service.Login(context.Background(), req)
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
}
@@ -94,7 +95,7 @@ func TestAuthService_Login_InvalidCredentials(t *testing.T) {
Password: "WrongPassword1",
}
_, err := service.Login(req)
_, err := service.Login(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
@@ -111,7 +112,7 @@ func TestAuthService_Login_UserNotFound(t *testing.T) {
Password: "Password123",
}
_, err := service.Login(req)
_, err := service.Login(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
@@ -133,7 +134,7 @@ func TestAuthService_Login_InactiveUser(t *testing.T) {
Password: "Password123",
}
_, err := service.Login(req)
_, err := service.Login(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive")
}
@@ -148,7 +149,7 @@ func TestAuthService_Register(t *testing.T) {
Password: "Password123",
}
resp, code, err := service.Register(req)
resp, code, err := service.Register(context.Background(), req)
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.Equal(t, "newuser", resp.User.Username)
@@ -172,7 +173,7 @@ func TestAuthService_Register_DuplicateUsername(t *testing.T) {
Password: "Password123",
}
_, _, err := service.Register(req)
_, _, err := service.Register(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken")
}
@@ -193,7 +194,7 @@ func TestAuthService_Register_DuplicateEmail(t *testing.T) {
Password: "Password123",
}
_, _, err := service.Register(req)
_, _, err := service.Register(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken")
}
@@ -211,7 +212,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
// Create profile
userRepo.GetOrCreateProfile(user.ID)
resp, err := service.GetCurrentUser(user.ID)
resp, err := service.GetCurrentUser(context.Background(), user.ID)
require.NoError(t, err)
assert.Equal(t, "testuser", resp.Username)
assert.Equal(t, "test@test.com", resp.Email)
@@ -238,7 +239,7 @@ func TestAuthService_UpdateProfile(t *testing.T) {
LastName: &newLast,
}
resp, err := service.UpdateProfile(user.ID, req)
resp, err := service.UpdateProfile(context.Background(), user.ID, req)
require.NoError(t, err)
assert.Equal(t, "John", resp.FirstName)
assert.Equal(t, "Doe", resp.LastName)
@@ -261,7 +262,7 @@ func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
Email: &takenEmail,
}
_, err := service.UpdateProfile(user2.ID, req)
_, err := service.UpdateProfile(context.Background(), user2.ID, req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_already_taken")
}
@@ -282,7 +283,7 @@ func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
}
// Same email should not trigger duplicate error
resp, err := service.UpdateProfile(user.ID, req)
resp, err := service.UpdateProfile(context.Background(), user.ID, req)
require.NoError(t, err)
assert.Equal(t, "test@test.com", resp.Email)
}
@@ -298,7 +299,7 @@ func TestAuthService_VerifyEmail(t *testing.T) {
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(req)
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
// Get the user ID
@@ -306,11 +307,11 @@ func TestAuthService_VerifyEmail(t *testing.T) {
require.NoError(t, err)
// Verify with the debug code
err = service.VerifyEmail(user.ID, "123456")
err = service.VerifyEmail(context.Background(), user.ID, "123456")
require.NoError(t, err)
// Verify again — should get already verified error
err = service.VerifyEmail(user.ID, "123456")
err = service.VerifyEmail(context.Background(), user.ID, "123456")
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
}
@@ -323,7 +324,7 @@ func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(req)
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
@@ -331,7 +332,7 @@ func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
// Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup,
// but a wrong code should use the normal path
err = service.VerifyEmail(user.ID, "000000")
err = service.VerifyEmail(context.Background(), user.ID, "000000")
assert.Error(t, err)
}
@@ -346,13 +347,13 @@ func TestAuthService_ResendVerificationCode(t *testing.T) {
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(req)
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
code, err := service.ResendVerificationCode(user.ID)
code, err := service.ResendVerificationCode(context.Background(), user.ID)
require.NoError(t, err)
assert.Equal(t, "123456", code) // DebugFixedCodes
}
@@ -366,16 +367,16 @@ func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) {
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(req)
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
err = service.VerifyEmail(user.ID, "123456")
err = service.VerifyEmail(context.Background(), user.ID, "123456")
require.NoError(t, err)
_, err = service.ResendVerificationCode(user.ID)
_, err = service.ResendVerificationCode(context.Background(), user.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
}
@@ -390,10 +391,10 @@ func TestAuthService_ForgotPassword(t *testing.T) {
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
code, user, err := service.ForgotPassword("test@test.com")
code, user, err := service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
assert.Equal(t, "123456", code) // DebugFixedCodes
assert.NotNil(t, user)
@@ -404,7 +405,7 @@ func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) {
service, _ := setupAuthService(t)
// Should not reveal that email doesn't exist
code, user, err := service.ForgotPassword("nonexistent@test.com")
code, user, err := service.ForgotPassword(context.Background(), "nonexistent@test.com")
require.NoError(t, err)
assert.Empty(t, code)
assert.Nil(t, user)
@@ -421,20 +422,20 @@ func TestAuthService_ResetPassword(t *testing.T) {
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
// Forgot password
_, _, err = service.ForgotPassword("test@test.com")
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
// Verify reset code to get the token
resetToken, err := service.VerifyResetCode("test@test.com", "123456")
resetToken, err := service.VerifyResetCode(context.Background(), "test@test.com", "123456")
require.NoError(t, err)
assert.NotEmpty(t, resetToken)
// Reset password
err = service.ResetPassword(resetToken, "NewPassword123")
err = service.ResetPassword(context.Background(), resetToken, "NewPassword123")
require.NoError(t, err)
// Login with new password
@@ -442,7 +443,7 @@ func TestAuthService_ResetPassword(t *testing.T) {
Username: "testuser",
Password: "NewPassword123",
}
loginResp, err := service.Login(loginReq)
loginResp, err := service.Login(context.Background(), loginReq)
require.NoError(t, err)
assert.NotEmpty(t, loginResp.Token)
}
@@ -450,7 +451,7 @@ func TestAuthService_ResetPassword(t *testing.T) {
func TestAuthService_ResetPassword_InvalidToken(t *testing.T) {
service, _ := setupAuthService(t)
err := service.ResetPassword("invalid-token", "NewPassword123")
err := service.ResetPassword(context.Background(), "invalid-token", "NewPassword123")
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token")
}
@@ -471,15 +472,15 @@ func TestAuthService_Logout(t *testing.T) {
Username: "testuser",
Password: "Password123",
}
loginResp, err := service.Login(loginReq)
loginResp, err := service.Login(context.Background(), loginReq)
require.NoError(t, err)
// Logout
err = service.Logout(loginResp.Token)
err = service.Logout(context.Background(), loginResp.Token)
require.NoError(t, err)
// Token should be deleted — refreshing should fail
_, err = service.RefreshToken(loginResp.Token, user.ID)
_, err = service.RefreshToken(context.Background(), loginResp.Token, user.ID)
assert.Error(t, err)
}
@@ -494,14 +495,14 @@ func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) {
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
password := "Password123"
_, err = service.DeleteAccount(user.ID, &password, nil)
_, err = service.DeleteAccount(context.Background(), user.ID, &password, nil)
require.NoError(t, err)
}
@@ -513,14 +514,14 @@ func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) {
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
wrongPassword := "WrongPassword1"
_, err = service.DeleteAccount(user.ID, &wrongPassword, nil)
_, err = service.DeleteAccount(context.Background(), user.ID, &wrongPassword, nil)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
@@ -532,13 +533,13 @@ func TestAuthService_DeleteAccount_NoPassword(t *testing.T) {
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
_, err = service.DeleteAccount(user.ID, nil, nil)
_, err = service.DeleteAccount(context.Background(), user.ID, nil, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
}
@@ -546,7 +547,7 @@ func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
service, _ := setupAuthService(t)
password := "Password123"
_, err := service.DeleteAccount(99999, &password, nil)
_, err := service.DeleteAccount(context.Background(), 99999, &password, nil)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
}
@@ -658,7 +659,7 @@ func TestAuthService_Login_EmptyPassword(t *testing.T) {
Password: "",
}
_, err := service.Login(req)
_, err := service.Login(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
@@ -672,17 +673,17 @@ func TestAuthService_ForgotPassword_RateLimit(t *testing.T) {
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
// Make max allowed reset requests (3 based on setup)
for i := 0; i < 3; i++ {
_, _, err := service.ForgotPassword("test@test.com")
_, _, err := service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
}
// The 4th should be rate limited
_, _, err = service.ForgotPassword("test@test.com")
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
assert.Error(t, err)
}
@@ -696,14 +697,14 @@ func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
_, _, err = service.ForgotPassword("test@test.com")
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
// Wrong code but with debug mode, "123456" works, "000000" should fail
_, err = service.VerifyResetCode("test@test.com", "000000")
_, err = service.VerifyResetCode(context.Background(), "test@test.com", "000000")
assert.Error(t, err)
}
@@ -712,7 +713,7 @@ func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) {
service, _ := setupAuthService(t)
_, err := service.VerifyResetCode("nonexistent@test.com", "123456")
_, err := service.VerifyResetCode(context.Background(), "nonexistent@test.com", "123456")
assert.Error(t, err)
}
@@ -734,7 +735,7 @@ func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
Email: &newEmail,
}
resp, err := service.UpdateProfile(user.ID, req)
resp, err := service.UpdateProfile(context.Background(), user.ID, req)
require.NoError(t, err)
assert.Equal(t, "newemail@test.com", resp.Email)
}
@@ -749,14 +750,14 @@ func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) {
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(registerReq)
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
emptyPw := ""
_, err = service.DeleteAccount(user.ID, &emptyPw, nil)
_, err = service.DeleteAccount(context.Background(), user.ID, &emptyPw, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
}
@@ -789,7 +790,7 @@ func TestAuthService_Register_CreatesProfile(t *testing.T) {
LastName: "Doe",
}
resp, _, err := service.Register(req)
resp, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
assert.Equal(t, "profileuser", resp.User.Username)
+39 -38
View File
@@ -1,6 +1,7 @@
package services
import (
"context"
"errors"
"gorm.io/gorm"
@@ -33,8 +34,8 @@ func NewContractorService(contractorRepo *repositories.ContractorRepository, res
}
// GetContractor gets a contractor by ID with access check
func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID)
func (s *ContractorService) GetContractor(ctx context.Context, contractorID, userID uint) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.contractor_not_found")
@@ -43,7 +44,7 @@ func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses
}
// Check access
if !s.hasContractorAccess(contractor, userID) {
if !s.hasContractorAccess(ctx, contractor, userID) {
return nil, apperrors.Forbidden("error.contractor_access_denied")
}
@@ -55,14 +56,14 @@ func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses
// Access rules:
// - If contractor has no residence: only the creator has access
// - If contractor has a residence: all users with access to that residence
func (s *ContractorService) hasContractorAccess(contractor *models.Contractor, userID uint) bool {
func (s *ContractorService) hasContractorAccess(ctx context.Context, contractor *models.Contractor, userID uint) bool {
if contractor.ResidenceID == nil {
// Personal contractor - only creator has access
return contractor.CreatedByID == userID
}
// Residence contractor - check residence access
hasAccess, err := s.residenceRepo.HasAccess(*contractor.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*contractor.ResidenceID, userID)
if err != nil {
return false
}
@@ -70,15 +71,15 @@ func (s *ContractorService) hasContractorAccess(contractor *models.Contractor, u
}
// ListContractors lists all contractors accessible to a user
func (s *ContractorService) ListContractors(userID uint) ([]responses.ContractorResponse, error) {
func (s *ContractorService) ListContractors(ctx context.Context, userID uint) ([]responses.ContractorResponse, error) {
// Get residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
// FindByUser now handles both personal and residence contractors
contractors, err := s.contractorRepo.FindByUser(userID, residenceIDs)
contractors, err := s.contractorRepo.WithContext(ctx).FindByUser(userID, residenceIDs)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -87,10 +88,10 @@ func (s *ContractorService) ListContractors(userID uint) ([]responses.Contractor
}
// CreateContractor creates a new contractor
func (s *ContractorService) CreateContractor(req *requests.CreateContractorRequest, userID uint) (*responses.ContractorResponse, error) {
func (s *ContractorService) CreateContractor(ctx context.Context, req *requests.CreateContractorRequest, userID uint) (*responses.ContractorResponse, error) {
// If residence is provided, check access
if req.ResidenceID != nil {
hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*req.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -122,19 +123,19 @@ func (s *ContractorService) CreateContractor(req *requests.CreateContractorReque
IsActive: true,
}
if err := s.contractorRepo.Create(contractor); err != nil {
if err := s.contractorRepo.WithContext(ctx).Create(contractor); err != nil {
return nil, apperrors.Internal(err)
}
// Set specialties if provided
if len(req.SpecialtyIDs) > 0 {
if err := s.contractorRepo.SetSpecialties(contractor.ID, req.SpecialtyIDs); err != nil {
if err := s.contractorRepo.WithContext(ctx).SetSpecialties(contractor.ID, req.SpecialtyIDs); err != nil {
return nil, apperrors.Internal(err)
}
}
// Reload with relations
contractor, reloadErr := s.contractorRepo.FindByID(contractor.ID)
contractor, reloadErr := s.contractorRepo.WithContext(ctx).FindByID(contractor.ID)
if reloadErr != nil {
return nil, apperrors.Internal(reloadErr)
}
@@ -144,8 +145,8 @@ func (s *ContractorService) CreateContractor(req *requests.CreateContractorReque
}
// UpdateContractor updates a contractor
func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *requests.UpdateContractorRequest) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID)
func (s *ContractorService) UpdateContractor(ctx context.Context, contractorID, userID uint, req *requests.UpdateContractorRequest) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.contractor_not_found")
@@ -154,7 +155,7 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
}
// Check access
if !s.hasContractorAccess(contractor, userID) {
if !s.hasContractorAccess(ctx, contractor, userID) {
return nil, apperrors.Forbidden("error.contractor_access_denied")
}
@@ -198,7 +199,7 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
// If residence_id is provided, verify the user has access to the NEW residence.
// This prevents an attacker from reassigning a contractor to someone else's residence.
if req.ResidenceID != nil {
hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(*req.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -211,19 +212,19 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
// removed the residence association - contractor becomes personal
contractor.ResidenceID = req.ResidenceID
if err := s.contractorRepo.Update(contractor); err != nil {
if err := s.contractorRepo.WithContext(ctx).Update(contractor); err != nil {
return nil, apperrors.Internal(err)
}
// Update specialties if provided
if req.SpecialtyIDs != nil {
if err := s.contractorRepo.SetSpecialties(contractorID, req.SpecialtyIDs); err != nil {
if err := s.contractorRepo.WithContext(ctx).SetSpecialties(contractorID, req.SpecialtyIDs); err != nil {
return nil, apperrors.Internal(err)
}
}
// Reload
contractor, err = s.contractorRepo.FindByID(contractorID)
contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -233,8 +234,8 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
}
// DeleteContractor soft-deletes a contractor
func (s *ContractorService) DeleteContractor(contractorID, userID uint) error {
contractor, err := s.contractorRepo.FindByID(contractorID)
func (s *ContractorService) DeleteContractor(ctx context.Context, contractorID, userID uint) error {
contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.contractor_not_found")
@@ -243,11 +244,11 @@ func (s *ContractorService) DeleteContractor(contractorID, userID uint) error {
}
// Check access
if !s.hasContractorAccess(contractor, userID) {
if !s.hasContractorAccess(ctx, contractor, userID) {
return apperrors.Forbidden("error.contractor_access_denied")
}
if err := s.contractorRepo.Delete(contractorID); err != nil {
if err := s.contractorRepo.WithContext(ctx).Delete(contractorID); err != nil {
return apperrors.Internal(err)
}
@@ -255,8 +256,8 @@ func (s *ContractorService) DeleteContractor(contractorID, userID uint) error {
}
// ToggleFavorite toggles the favorite status of a contractor and returns the updated contractor
func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID)
func (s *ContractorService) ToggleFavorite(ctx context.Context, contractorID, userID uint) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.contractor_not_found")
@@ -265,17 +266,17 @@ func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*response
}
// Check access
if !s.hasContractorAccess(contractor, userID) {
if !s.hasContractorAccess(ctx, contractor, userID) {
return nil, apperrors.Forbidden("error.contractor_access_denied")
}
_, err = s.contractorRepo.ToggleFavorite(contractorID)
_, err = s.contractorRepo.WithContext(ctx).ToggleFavorite(contractorID)
if err != nil {
return nil, apperrors.Internal(err)
}
// Re-fetch the contractor to get the updated state with all relations
contractor, err = s.contractorRepo.FindByID(contractorID)
contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -285,8 +286,8 @@ func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*response
}
// GetContractorTasks gets all tasks for a contractor
func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]responses.TaskResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID)
func (s *ContractorService) GetContractorTasks(ctx context.Context, contractorID, userID uint) ([]responses.TaskResponse, error) {
contractor, err := s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.contractor_not_found")
@@ -295,11 +296,11 @@ func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]res
}
// Check access
if !s.hasContractorAccess(contractor, userID) {
if !s.hasContractorAccess(ctx, contractor, userID) {
return nil, apperrors.Forbidden("error.contractor_access_denied")
}
tasks, err := s.contractorRepo.GetTasksForContractor(contractorID)
tasks, err := s.contractorRepo.WithContext(ctx).GetTasksForContractor(contractorID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -308,9 +309,9 @@ func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]res
}
// ListContractorsByResidence lists all contractors for a specific residence
func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint) ([]responses.ContractorResponse, error) {
func (s *ContractorService) ListContractorsByResidence(ctx context.Context, residenceID, userID uint) ([]responses.ContractorResponse, error) {
// Check user has access to the residence
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(residenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -318,7 +319,7 @@ func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint)
return nil, apperrors.Forbidden("error.residence_access_denied")
}
contractors, err := s.contractorRepo.FindByResidence(residenceID)
contractors, err := s.contractorRepo.WithContext(ctx).FindByResidence(residenceID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -327,8 +328,8 @@ func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint)
}
// GetSpecialties returns all contractor specialties
func (s *ContractorService) GetSpecialties() ([]responses.ContractorSpecialtyResponse, error) {
specialties, err := s.contractorRepo.GetAllSpecialties()
func (s *ContractorService) GetSpecialties(ctx context.Context) ([]responses.ContractorSpecialtyResponse, error) {
specialties, err := s.contractorRepo.WithContext(ctx).GetAllSpecialties()
if err != nil {
return nil, apperrors.Internal(err)
}
+38 -37
View File
@@ -1,6 +1,7 @@
package services
import (
"context"
"net/http"
"testing"
@@ -41,7 +42,7 @@ func TestContractorService_CreateContractor(t *testing.T) {
Email: "bob@plumbing.com",
}
resp, err := service.CreateContractor(req, user.ID)
resp, err := service.CreateContractor(context.Background(), req, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "Bob's Plumbing", resp.Name)
@@ -63,7 +64,7 @@ func TestContractorService_CreateContractor_Personal(t *testing.T) {
Name: "Personal Handyman",
}
resp, err := service.CreateContractor(req, user.ID)
resp, err := service.CreateContractor(context.Background(), req, user.ID)
require.NoError(t, err)
assert.Equal(t, "Personal Handyman", resp.Name)
}
@@ -84,7 +85,7 @@ func TestContractorService_CreateContractor_AccessDenied(t *testing.T) {
Name: "Unauthorized Contractor",
}
_, err := service.CreateContractor(req, other.ID)
_, err := service.CreateContractor(context.Background(), req, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
@@ -105,7 +106,7 @@ func TestContractorService_CreateContractor_WithFavorite(t *testing.T) {
IsFavorite: &isFav,
}
resp, err := service.CreateContractor(req, user.ID)
resp, err := service.CreateContractor(context.Background(), req, user.ID)
require.NoError(t, err)
assert.True(t, resp.IsFavorite)
}
@@ -123,7 +124,7 @@ func TestContractorService_GetContractor(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
resp, err := service.GetContractor(contractor.ID, user.ID)
resp, err := service.GetContractor(context.Background(), contractor.ID, user.ID)
require.NoError(t, err)
assert.Equal(t, contractor.ID, resp.ID)
assert.Equal(t, "Test Contractor", resp.Name)
@@ -138,7 +139,7 @@ func TestContractorService_GetContractor_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.GetContractor(9999, user.ID)
_, err := service.GetContractor(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
}
@@ -154,7 +155,7 @@ func TestContractorService_GetContractor_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
_, err := service.GetContractor(contractor.ID, other.ID)
_, err := service.GetContractor(context.Background(), contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
@@ -171,7 +172,7 @@ func TestContractorService_GetContractor_SharedUserHasAccess(t *testing.T) {
residenceRepo.AddUser(residence.ID, shared.ID)
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Shared Contractor")
resp, err := service.GetContractor(contractor.ID, shared.ID)
resp, err := service.GetContractor(context.Background(), contractor.ID, shared.ID)
require.NoError(t, err)
assert.Equal(t, "Shared Contractor", resp.Name)
}
@@ -190,7 +191,7 @@ func TestContractorService_ListContractors(t *testing.T) {
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 1")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 2")
resp, err := service.ListContractors(user.ID)
resp, err := service.ListContractors(context.Background(), user.ID)
require.NoError(t, err)
assert.Len(t, resp, 2)
}
@@ -208,11 +209,11 @@ func TestContractorService_DeleteContractor(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "To Delete")
err := service.DeleteContractor(contractor.ID, user.ID)
err := service.DeleteContractor(context.Background(), contractor.ID, user.ID)
require.NoError(t, err)
// Should not be found after deletion
_, err = service.GetContractor(contractor.ID, user.ID)
_, err = service.GetContractor(context.Background(), contractor.ID, user.ID)
assert.Error(t, err)
}
@@ -225,7 +226,7 @@ func TestContractorService_DeleteContractor_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.DeleteContractor(9999, user.ID)
err := service.DeleteContractor(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
}
@@ -241,7 +242,7 @@ func TestContractorService_DeleteContractor_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
err := service.DeleteContractor(contractor.ID, other.ID)
err := service.DeleteContractor(context.Background(), contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
@@ -259,17 +260,17 @@ func TestContractorService_ToggleFavorite(t *testing.T) {
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
// Initially not favorite
resp, err := service.GetContractor(contractor.ID, user.ID)
resp, err := service.GetContractor(context.Background(), contractor.ID, user.ID)
require.NoError(t, err)
assert.False(t, resp.IsFavorite)
// Toggle to favorite
resp, err = service.ToggleFavorite(contractor.ID, user.ID)
resp, err = service.ToggleFavorite(context.Background(), contractor.ID, user.ID)
require.NoError(t, err)
assert.True(t, resp.IsFavorite)
// Toggle back
resp, err = service.ToggleFavorite(contractor.ID, user.ID)
resp, err = service.ToggleFavorite(context.Background(), contractor.ID, user.ID)
require.NoError(t, err)
assert.False(t, resp.IsFavorite)
}
@@ -283,7 +284,7 @@ func TestContractorService_ToggleFavorite_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.ToggleFavorite(9999, user.ID)
_, err := service.ToggleFavorite(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
}
@@ -299,7 +300,7 @@ func TestContractorService_ToggleFavorite_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
_, err := service.ToggleFavorite(contractor.ID, other.ID)
_, err := service.ToggleFavorite(context.Background(), contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
@@ -317,7 +318,7 @@ func TestContractorService_ListContractorsByResidence(t *testing.T) {
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor A")
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor B")
resp, err := service.ListContractorsByResidence(residence.ID, user.ID)
resp, err := service.ListContractorsByResidence(context.Background(), residence.ID, user.ID)
require.NoError(t, err)
assert.Len(t, resp, 2)
}
@@ -333,7 +334,7 @@ func TestContractorService_ListContractorsByResidence_AccessDenied(t *testing.T)
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
_, err := service.ListContractorsByResidence(residence.ID, other.ID)
_, err := service.ListContractorsByResidence(context.Background(), residence.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
@@ -348,7 +349,7 @@ func TestContractorService_GetContractorTasks_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.GetContractorTasks(9999, user.ID)
_, err := service.GetContractorTasks(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
}
@@ -364,7 +365,7 @@ func TestContractorService_GetContractorTasks_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
_, err := service.GetContractorTasks(contractor.ID, other.ID)
_, err := service.GetContractorTasks(context.Background(), contractor.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
@@ -379,7 +380,7 @@ func TestContractorService_GetContractorTasks_Empty(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
resp, err := service.GetContractorTasks(contractor.ID, user.ID)
resp, err := service.GetContractorTasks(context.Background(), contractor.ID, user.ID)
require.NoError(t, err)
assert.Empty(t, resp)
}
@@ -393,7 +394,7 @@ func TestContractorService_GetSpecialties(t *testing.T) {
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
resp, err := service.GetSpecialties()
resp, err := service.GetSpecialties(context.Background())
require.NoError(t, err)
// SeedLookupData creates 4 specialties
assert.Len(t, resp, 4)
@@ -413,7 +414,7 @@ func TestContractorService_UpdateContractor_NotFound(t *testing.T) {
newName := "Won't Work"
req := &requests.UpdateContractorRequest{Name: &newName}
_, err := service.UpdateContractor(9999, user.ID, req)
_, err := service.UpdateContractor(context.Background(), 9999, user.ID, req)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
}
@@ -432,7 +433,7 @@ func TestContractorService_UpdateContractor_AccessDenied(t *testing.T) {
newName := "Hacked"
req := &requests.UpdateContractorRequest{Name: &newName}
_, err := service.UpdateContractor(contractor.ID, other.ID, req)
_, err := service.UpdateContractor(context.Background(), contractor.ID, other.ID, req)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
@@ -461,7 +462,7 @@ func TestUpdateContractor_CrossResidence_Returns403(t *testing.T) {
ResidenceID: &newResidenceID,
}
_, err := service.UpdateContractor(contractor.ID, attacker.ID, req)
_, err := service.UpdateContractor(context.Background(), contractor.ID, attacker.ID, req)
require.Error(t, err, "should not allow reassigning contractor to a residence the user has no access to")
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
}
@@ -486,7 +487,7 @@ func TestUpdateContractor_SameResidence_Succeeds(t *testing.T) {
ResidenceID: &newResidenceID,
}
resp, err := service.UpdateContractor(contractor.ID, owner.ID, req)
resp, err := service.UpdateContractor(context.Background(), contractor.ID, owner.ID, req)
require.NoError(t, err, "should allow reassigning contractor to a residence the user owns")
require.NotNil(t, resp)
require.Equal(t, "Updated Contractor", resp.Name)
@@ -508,7 +509,7 @@ func TestUpdateContractor_RemoveResidence_Succeeds(t *testing.T) {
ResidenceID: nil,
}
resp, err := service.UpdateContractor(contractor.ID, owner.ID, req)
resp, err := service.UpdateContractor(context.Background(), contractor.ID, owner.ID, req)
require.NoError(t, err, "should allow removing residence association")
require.NotNil(t, resp)
}
@@ -555,7 +556,7 @@ func TestContractorService_UpdateContractor_PartialUpdate(t *testing.T) {
ResidenceID: &residence.ID,
}
resp, err := service.UpdateContractor(contractor.ID, user.ID, req)
resp, err := service.UpdateContractor(context.Background(), contractor.ID, user.ID, req)
require.NoError(t, err)
assert.Equal(t, "Updated Plumber", resp.Name)
assert.Equal(t, "555-9999", resp.Phone)
@@ -588,7 +589,7 @@ func TestContractorService_UpdateContractor_WithSpecialties(t *testing.T) {
ResidenceID: &residence.ID,
}
resp, err := service.UpdateContractor(contractor.ID, user.ID, req)
resp, err := service.UpdateContractor(context.Background(), contractor.ID, user.ID, req)
require.NoError(t, err)
assert.NotNil(t, resp)
}
@@ -615,7 +616,7 @@ func TestContractorService_CreateContractor_WithSpecialties(t *testing.T) {
SpecialtyIDs: []uint{specialties[0].ID},
}
resp, err := service.CreateContractor(req, user.ID)
resp, err := service.CreateContractor(context.Background(), req, user.ID)
require.NoError(t, err)
assert.Equal(t, "Specialized Plumber", resp.Name)
}
@@ -631,7 +632,7 @@ func TestContractorService_ListContractors_Empty(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// No residence, no contractors
resp, err := service.ListContractors(user.ID)
resp, err := service.ListContractors(context.Background(), user.ID)
require.NoError(t, err)
assert.Empty(t, resp)
}
@@ -648,7 +649,7 @@ func TestContractorService_ListContractorsByResidence_Empty(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
residence := testutil.CreateTestResidence(t, db, user.ID, "Empty House")
resp, err := service.ListContractorsByResidence(residence.ID, user.ID)
resp, err := service.ListContractorsByResidence(context.Background(), residence.ID, user.ID)
require.NoError(t, err)
assert.Empty(t, resp)
}
@@ -669,14 +670,14 @@ func TestContractorService_PersonalContractor_OnlyCreatorAccess(t *testing.T) {
req := &requests.CreateContractorRequest{
Name: "Personal Plumber",
}
resp, err := service.CreateContractor(req, creator.ID)
resp, err := service.CreateContractor(context.Background(), req, creator.ID)
require.NoError(t, err)
// Creator can access
_, err = service.GetContractor(resp.ID, creator.ID)
_, err = service.GetContractor(context.Background(), resp.ID, creator.ID)
require.NoError(t, err)
// Other user cannot
_, err = service.GetContractor(resp.ID, other.ID)
_, err = service.GetContractor(context.Background(), resp.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
}
+44 -43
View File
@@ -1,6 +1,7 @@
package services
import (
"context"
"errors"
"gorm.io/gorm"
@@ -34,8 +35,8 @@ func NewDocumentService(documentRepo *repositories.DocumentRepository, residence
}
// GetDocument gets a document by ID with access check
func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID)
func (s *DocumentService) GetDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_not_found")
@@ -44,7 +45,7 @@ func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.Docum
}
// Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -57,9 +58,9 @@ func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.Docum
}
// ListDocuments lists all documents accessible to a user, with optional filters.
func (s *DocumentService) ListDocuments(userID uint, filter *repositories.DocumentFilter) ([]responses.DocumentResponse, error) {
func (s *DocumentService) ListDocuments(ctx context.Context, userID uint, filter *repositories.DocumentFilter) ([]responses.DocumentResponse, error) {
// Get residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -83,7 +84,7 @@ func (s *DocumentService) ListDocuments(userID uint, filter *repositories.Docume
residenceIDs = []uint{*filter.ResidenceID}
}
documents, err := s.documentRepo.FindByUserFiltered(residenceIDs, filter)
documents, err := s.documentRepo.WithContext(ctx).FindByUserFiltered(residenceIDs, filter)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -92,9 +93,9 @@ func (s *DocumentService) ListDocuments(userID uint, filter *repositories.Docume
}
// ListWarranties lists all warranty documents
func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentResponse, error) {
func (s *DocumentService) ListWarranties(ctx context.Context, userID uint) ([]responses.DocumentResponse, error) {
// Get residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByUser(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -103,7 +104,7 @@ func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentRespo
return []responses.DocumentResponse{}, nil
}
documents, err := s.documentRepo.FindWarranties(residenceIDs)
documents, err := s.documentRepo.WithContext(ctx).FindWarranties(residenceIDs)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -112,9 +113,9 @@ func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentRespo
}
// CreateDocument creates a new document
func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, userID uint) (*responses.DocumentResponse, error) {
func (s *DocumentService) CreateDocument(ctx context.Context, req *requests.CreateDocumentRequest, userID uint) (*responses.DocumentResponse, error) {
// Check residence access
hasAccess, err := s.residenceRepo.HasAccess(req.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(req.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -147,7 +148,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
IsActive: true,
}
if err := s.documentRepo.Create(document); err != nil {
if err := s.documentRepo.WithContext(ctx).Create(document); err != nil {
return nil, apperrors.Internal(err)
}
@@ -158,7 +159,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
DocumentID: document.ID,
ImageURL: imageURL,
}
if err := s.documentRepo.CreateDocumentImage(img); err != nil {
if err := s.documentRepo.WithContext(ctx).CreateDocumentImage(img); err != nil {
// Log but don't fail the whole operation
continue
}
@@ -166,7 +167,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
}
// Reload with relations
document, err = s.documentRepo.FindByID(document.ID)
document, err = s.documentRepo.WithContext(ctx).FindByID(document.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -176,8 +177,8 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
}
// UpdateDocument updates a document
func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.UpdateDocumentRequest) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID)
func (s *DocumentService) UpdateDocument(ctx context.Context, documentID, userID uint, req *requests.UpdateDocumentRequest) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_not_found")
@@ -186,7 +187,7 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -238,12 +239,12 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
document.TaskID = req.TaskID
}
if err := s.documentRepo.Update(document); err != nil {
if err := s.documentRepo.WithContext(ctx).Update(document); err != nil {
return nil, apperrors.Internal(err)
}
// Reload
document, err = s.documentRepo.FindByID(documentID)
document, err = s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -253,8 +254,8 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
}
// DeleteDocument soft-deletes a document
func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
document, err := s.documentRepo.FindByID(documentID)
func (s *DocumentService) DeleteDocument(ctx context.Context, documentID, userID uint) error {
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.document_not_found")
@@ -263,7 +264,7 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil {
return apperrors.Internal(err)
}
@@ -271,7 +272,7 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
return apperrors.Forbidden("error.document_access_denied")
}
if err := s.documentRepo.Delete(documentID); err != nil {
if err := s.documentRepo.WithContext(ctx).Delete(documentID); err != nil {
return apperrors.Internal(err)
}
@@ -279,15 +280,15 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
}
// ActivateDocument activates a document
func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.DocumentResponse, error) {
func (s *DocumentService) ActivateDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) {
// First check if document exists (even if inactive)
var document models.Document
if err := s.documentRepo.FindByIDIncludingInactive(documentID, &document); err != nil {
if err := s.documentRepo.WithContext(ctx).FindByIDIncludingInactive(documentID, &document); err != nil {
return nil, apperrors.NotFound("error.document_not_found")
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -295,12 +296,12 @@ func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.
return nil, apperrors.Forbidden("error.document_access_denied")
}
if err := s.documentRepo.Activate(documentID); err != nil {
if err := s.documentRepo.WithContext(ctx).Activate(documentID); err != nil {
return nil, apperrors.Internal(err)
}
// Reload
doc, err := s.documentRepo.FindByID(documentID)
doc, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -310,8 +311,8 @@ func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.
}
// DeactivateDocument deactivates a document
func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID)
func (s *DocumentService) DeactivateDocument(ctx context.Context, documentID, userID uint) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_not_found")
@@ -320,7 +321,7 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -328,7 +329,7 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response
return nil, apperrors.Forbidden("error.document_access_denied")
}
if err := s.documentRepo.Deactivate(documentID); err != nil {
if err := s.documentRepo.WithContext(ctx).Deactivate(documentID); err != nil {
return nil, apperrors.Internal(err)
}
@@ -338,8 +339,8 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response
}
// UploadDocumentImage adds an image to an existing document
func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL, caption string) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID)
func (s *DocumentService) UploadDocumentImage(ctx context.Context, documentID, userID uint, imageURL, caption string) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_not_found")
@@ -348,7 +349,7 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL,
}
// Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -361,12 +362,12 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL,
ImageURL: imageURL,
Caption: caption,
}
if err := s.documentRepo.CreateDocumentImage(img); err != nil {
if err := s.documentRepo.WithContext(ctx).CreateDocumentImage(img); err != nil {
return nil, apperrors.Internal(err)
}
// Reload with relations
document, err = s.documentRepo.FindByID(documentID)
document, err = s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -376,9 +377,9 @@ func (s *DocumentService) UploadDocumentImage(documentID, userID uint, imageURL,
}
// DeleteDocumentImage removes an image from a document
func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint) (*responses.DocumentResponse, error) {
func (s *DocumentService) DeleteDocumentImage(ctx context.Context, documentID, imageID, userID uint) (*responses.DocumentResponse, error) {
// Find the image first
image, err := s.documentRepo.FindImageByID(imageID)
image, err := s.documentRepo.WithContext(ctx).FindImageByID(imageID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_image_not_found")
@@ -392,7 +393,7 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint)
}
// Find parent document to check access
document, err := s.documentRepo.FindByID(documentID)
document, err := s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.document_not_found")
@@ -401,7 +402,7 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint)
}
// Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -409,12 +410,12 @@ func (s *DocumentService) DeleteDocumentImage(documentID, imageID, userID uint)
return nil, apperrors.Forbidden("error.document_access_denied")
}
if err := s.documentRepo.DeleteDocumentImage(imageID); err != nil {
if err := s.documentRepo.WithContext(ctx).DeleteDocumentImage(imageID); err != nil {
return nil, apperrors.Internal(err)
}
// Reload with relations
document, err = s.documentRepo.FindByID(documentID)
document, err = s.documentRepo.WithContext(ctx).FindByID(documentID)
if err != nil {
return nil, apperrors.Internal(err)
}
+42 -41
View File
@@ -1,6 +1,7 @@
package services
import (
"context"
"net/http"
"testing"
@@ -42,7 +43,7 @@ func TestDocumentService_CreateDocument(t *testing.T) {
FileName: "manual.pdf",
}
resp, err := service.CreateDocument(req, user.ID)
resp, err := service.CreateDocument(context.Background(), req, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "Furnace Manual", resp.Title)
@@ -64,7 +65,7 @@ func TestDocumentService_CreateDocument_DefaultType(t *testing.T) {
// DocumentType not set — should default to "general"
}
resp, err := service.CreateDocument(req, user.ID)
resp, err := service.CreateDocument(context.Background(), req, user.ID)
require.NoError(t, err)
assert.Equal(t, models.DocumentTypeGeneral, resp.DocumentType)
}
@@ -84,7 +85,7 @@ func TestDocumentService_CreateDocument_WithImages(t *testing.T) {
ImageURLs: []string{"https://example.com/img1.jpg", "https://example.com/img2.jpg"},
}
resp, err := service.CreateDocument(req, user.ID)
resp, err := service.CreateDocument(context.Background(), req, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "Receipt with photos", resp.Title)
@@ -105,7 +106,7 @@ func TestDocumentService_CreateDocument_AccessDenied(t *testing.T) {
Title: "Unauthorized Doc",
}
_, err := service.CreateDocument(req, other.ID)
_, err := service.CreateDocument(context.Background(), req, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
@@ -121,7 +122,7 @@ func TestDocumentService_GetDocument(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
resp, err := service.GetDocument(doc.ID, user.ID)
resp, err := service.GetDocument(context.Background(), doc.ID, user.ID)
require.NoError(t, err)
assert.Equal(t, doc.ID, resp.ID)
assert.Equal(t, "Test Doc", resp.Title)
@@ -135,7 +136,7 @@ func TestDocumentService_GetDocument_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.GetDocument(9999, user.ID)
_, err := service.GetDocument(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
@@ -150,7 +151,7 @@ func TestDocumentService_GetDocument_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.GetDocument(doc.ID, other.ID)
_, err := service.GetDocument(context.Background(), doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
@@ -173,7 +174,7 @@ func TestDocumentService_UpdateDocument(t *testing.T) {
Description: &newDesc,
}
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req)
require.NoError(t, err)
assert.Equal(t, "Updated Title", resp.Title)
assert.Equal(t, "Updated description", resp.Description)
@@ -190,7 +191,7 @@ func TestDocumentService_UpdateDocument_NotFound(t *testing.T) {
newTitle := "Won't Work"
req := &requests.UpdateDocumentRequest{Title: &newTitle}
_, err := service.UpdateDocument(9999, user.ID, req)
_, err := service.UpdateDocument(context.Background(), 9999, user.ID, req)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
@@ -208,7 +209,7 @@ func TestDocumentService_UpdateDocument_AccessDenied(t *testing.T) {
newTitle := "Hacked"
req := &requests.UpdateDocumentRequest{Title: &newTitle}
_, err := service.UpdateDocument(doc.ID, other.ID, req)
_, err := service.UpdateDocument(context.Background(), doc.ID, other.ID, req)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
@@ -225,7 +226,7 @@ func TestDocumentService_UpdateDocument_ChangeType(t *testing.T) {
newType := models.DocumentTypeWarranty
req := &requests.UpdateDocumentRequest{DocumentType: &newType}
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req)
require.NoError(t, err)
assert.Equal(t, models.DocumentTypeWarranty, resp.DocumentType)
}
@@ -242,11 +243,11 @@ func TestDocumentService_DeleteDocument(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Delete")
err := service.DeleteDocument(doc.ID, user.ID)
err := service.DeleteDocument(context.Background(), doc.ID, user.ID)
require.NoError(t, err)
// Should not be found after deletion
_, err = service.GetDocument(doc.ID, user.ID)
_, err = service.GetDocument(context.Background(), doc.ID, user.ID)
assert.Error(t, err)
}
@@ -258,7 +259,7 @@ func TestDocumentService_DeleteDocument_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.DeleteDocument(9999, user.ID)
err := service.DeleteDocument(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
@@ -273,7 +274,7 @@ func TestDocumentService_DeleteDocument_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
err := service.DeleteDocument(doc.ID, other.ID)
err := service.DeleteDocument(context.Background(), doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
@@ -290,7 +291,7 @@ func TestDocumentService_ListDocuments(t *testing.T) {
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
resp, err := service.ListDocuments(user.ID, nil)
resp, err := service.ListDocuments(context.Background(), user.ID, nil)
require.NoError(t, err)
assert.Len(t, resp, 2)
}
@@ -303,7 +304,7 @@ func TestDocumentService_ListDocuments_NoResidences(t *testing.T) {
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
resp, err := service.ListDocuments(user.ID, nil)
resp, err := service.ListDocuments(context.Background(), user.ID, nil)
require.NoError(t, err)
assert.Empty(t, resp)
}
@@ -321,7 +322,7 @@ func TestDocumentService_ListDocuments_FilterByResidence(t *testing.T) {
testutil.CreateTestDocument(t, db, residence2.ID, user.ID, "Doc B")
filter := &repositories.DocumentFilter{ResidenceID: &residence1.ID}
resp, err := service.ListDocuments(user.ID, filter)
resp, err := service.ListDocuments(context.Background(), user.ID, filter)
require.NoError(t, err)
assert.Len(t, resp, 1)
assert.Equal(t, "Doc A", resp[0].Title)
@@ -340,7 +341,7 @@ func TestDocumentService_ListDocuments_FilterByResidence_AccessDenied(t *testing
testutil.CreateTestResidence(t, db, other.ID, "Other House")
filter := &repositories.DocumentFilter{ResidenceID: &residence.ID}
_, err := service.ListDocuments(other.ID, filter)
_, err := service.ListDocuments(context.Background(), other.ID, filter)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
@@ -369,7 +370,7 @@ func TestDocumentService_ListWarranties(t *testing.T) {
// Create a general doc
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
resp, err := service.ListWarranties(user.ID)
resp, err := service.ListWarranties(context.Background(), user.ID)
require.NoError(t, err)
assert.Len(t, resp, 1)
assert.Equal(t, "HVAC Warranty", resp[0].Title)
@@ -383,7 +384,7 @@ func TestDocumentService_ListWarranties_NoResidences(t *testing.T) {
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
resp, err := service.ListWarranties(user.ID)
resp, err := service.ListWarranties(context.Background(), user.ID)
require.NoError(t, err)
assert.Empty(t, resp)
}
@@ -400,7 +401,7 @@ func TestDocumentService_DeactivateDocument(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Deactivate")
resp, err := service.DeactivateDocument(doc.ID, user.ID)
resp, err := service.DeactivateDocument(context.Background(), doc.ID, user.ID)
require.NoError(t, err)
assert.False(t, resp.IsActive)
}
@@ -413,7 +414,7 @@ func TestDocumentService_DeactivateDocument_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.DeactivateDocument(9999, user.ID)
_, err := service.DeactivateDocument(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
@@ -428,7 +429,7 @@ func TestDocumentService_DeactivateDocument_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.DeactivateDocument(doc.ID, other.ID)
_, err := service.DeactivateDocument(context.Background(), doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
@@ -444,7 +445,7 @@ func TestDocumentService_UploadDocumentImage(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
resp, err := service.UploadDocumentImage(doc.ID, user.ID, "https://example.com/photo.jpg", "Front view")
resp, err := service.UploadDocumentImage(context.Background(), doc.ID, user.ID, "https://example.com/photo.jpg", "Front view")
require.NoError(t, err)
assert.NotNil(t, resp)
}
@@ -457,7 +458,7 @@ func TestDocumentService_UploadDocumentImage_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.UploadDocumentImage(9999, user.ID, "https://example.com/photo.jpg", "")
_, err := service.UploadDocumentImage(context.Background(), 9999, user.ID, "https://example.com/photo.jpg", "")
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
@@ -472,7 +473,7 @@ func TestDocumentService_UploadDocumentImage_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.UploadDocumentImage(doc.ID, other.ID, "https://example.com/photo.jpg", "")
_, err := service.UploadDocumentImage(context.Background(), doc.ID, other.ID, "https://example.com/photo.jpg", "")
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
@@ -496,7 +497,7 @@ func TestDocumentService_DeleteDocumentImage(t *testing.T) {
err := db.Create(img).Error
require.NoError(t, err)
resp, err := service.DeleteDocumentImage(doc.ID, img.ID, user.ID)
resp, err := service.DeleteDocumentImage(context.Background(), doc.ID, img.ID, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
}
@@ -511,7 +512,7 @@ func TestDocumentService_DeleteDocumentImage_NotFound(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
_, err := service.DeleteDocumentImage(doc.ID, 9999, user.ID)
_, err := service.DeleteDocumentImage(context.Background(), doc.ID, 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
}
@@ -535,7 +536,7 @@ func TestDocumentService_DeleteDocumentImage_WrongDocument(t *testing.T) {
require.NoError(t, err)
// Try to delete the image specifying doc2
_, err = service.DeleteDocumentImage(doc2.ID, img.ID, user.ID)
_, err = service.DeleteDocumentImage(context.Background(), doc2.ID, img.ID, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
}
@@ -557,7 +558,7 @@ func TestDocumentService_DeleteDocumentImage_AccessDenied(t *testing.T) {
err := db.Create(img).Error
require.NoError(t, err)
_, err = service.DeleteDocumentImage(doc.ID, img.ID, other.ID)
_, err = service.DeleteDocumentImage(context.Background(), doc.ID, img.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
@@ -575,7 +576,7 @@ func TestDocumentService_GetDocument_SharedUserHasAccess(t *testing.T) {
residenceRepo.AddUser(residence.ID, shared.ID)
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
resp, err := service.GetDocument(doc.ID, shared.ID)
resp, err := service.GetDocument(context.Background(), doc.ID, shared.ID)
require.NoError(t, err)
assert.Equal(t, "Shared Doc", resp.Title)
}
@@ -593,11 +594,11 @@ func TestDocumentService_ActivateDocument(t *testing.T) {
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Activate")
// Deactivate first
_, err := service.DeactivateDocument(doc.ID, user.ID)
_, err := service.DeactivateDocument(context.Background(), doc.ID, user.ID)
require.NoError(t, err)
// Now activate
resp, err := service.ActivateDocument(doc.ID, user.ID)
resp, err := service.ActivateDocument(context.Background(), doc.ID, user.ID)
require.NoError(t, err)
assert.True(t, resp.IsActive)
}
@@ -610,7 +611,7 @@ func TestDocumentService_ActivateDocument_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
_, err := service.ActivateDocument(9999, user.ID)
_, err := service.ActivateDocument(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
}
@@ -625,7 +626,7 @@ func TestDocumentService_ActivateDocument_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
_, err := service.ActivateDocument(doc.ID, other.ID)
_, err := service.ActivateDocument(context.Background(), doc.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
}
@@ -646,7 +647,7 @@ func TestDocumentService_CreateDocument_WithEmptyImageURL(t *testing.T) {
ImageURLs: []string{"", "https://example.com/img.jpg", ""},
}
resp, err := service.CreateDocument(req, user.ID)
resp, err := service.CreateDocument(context.Background(), req, user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
}
@@ -687,7 +688,7 @@ func TestDocumentService_UpdateDocument_AllFields(t *testing.T) {
ModelNumber: &newModel,
}
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
resp, err := service.UpdateDocument(context.Background(), doc.ID, user.ID, req)
require.NoError(t, err)
assert.Equal(t, "Updated", resp.Title)
assert.Equal(t, "New description", resp.Description)
@@ -720,7 +721,7 @@ func TestDocumentService_ListDocuments_FilterByType(t *testing.T) {
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
filter := &repositories.DocumentFilter{DocumentType: string(models.DocumentTypeWarranty)}
resp, err := service.ListDocuments(user.ID, filter)
resp, err := service.ListDocuments(context.Background(), user.ID, filter)
require.NoError(t, err)
assert.Len(t, resp, 1)
assert.Equal(t, "Warranty Doc", resp[0].Title)
@@ -742,7 +743,7 @@ func TestDocumentService_SharedUser_CanUpdate(t *testing.T) {
newTitle := "Updated by shared user"
req := &requests.UpdateDocumentRequest{Title: &newTitle}
resp, err := service.UpdateDocument(doc.ID, shared.ID, req)
resp, err := service.UpdateDocument(context.Background(), doc.ID, shared.ID, req)
require.NoError(t, err)
assert.Equal(t, "Updated by shared user", resp.Title)
}
@@ -759,6 +760,6 @@ func TestDocumentService_SharedUser_CanDelete(t *testing.T) {
residenceRepo.AddUser(residence.ID, shared.ID)
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
err := service.DeleteDocument(doc.ID, shared.ID)
err := service.DeleteDocument(context.Background(), doc.ID, shared.ID)
require.NoError(t, err)
}
+51 -51
View File
@@ -44,8 +44,8 @@ func NewNotificationService(notificationRepo *repositories.NotificationRepositor
// === Notifications ===
// GetNotifications gets notifications for a user
func (s *NotificationService) GetNotifications(userID uint, limit, offset int) ([]NotificationResponse, error) {
notifications, err := s.notificationRepo.FindByUser(userID, limit, offset)
func (s *NotificationService) GetNotifications(ctx context.Context, userID uint, limit, offset int) ([]NotificationResponse, error) {
notifications, err := s.notificationRepo.WithContext(ctx).FindByUser(userID, limit, offset)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -58,8 +58,8 @@ func (s *NotificationService) GetNotifications(userID uint, limit, offset int) (
}
// GetUnreadCount gets the count of unread notifications
func (s *NotificationService) GetUnreadCount(userID uint) (int64, error) {
count, err := s.notificationRepo.CountUnread(userID)
func (s *NotificationService) GetUnreadCount(ctx context.Context, userID uint) (int64, error) {
count, err := s.notificationRepo.WithContext(ctx).CountUnread(userID)
if err != nil {
return 0, apperrors.Internal(err)
}
@@ -67,8 +67,8 @@ func (s *NotificationService) GetUnreadCount(userID uint) (int64, error) {
}
// MarkAsRead marks a notification as read
func (s *NotificationService) MarkAsRead(notificationID, userID uint) error {
notification, err := s.notificationRepo.FindByID(notificationID)
func (s *NotificationService) MarkAsRead(ctx context.Context, notificationID, userID uint) error {
notification, err := s.notificationRepo.WithContext(ctx).FindByID(notificationID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.notification_not_found")
@@ -80,15 +80,15 @@ func (s *NotificationService) MarkAsRead(notificationID, userID uint) error {
return apperrors.NotFound("error.notification_not_found")
}
if err := s.notificationRepo.MarkAsRead(notificationID); err != nil {
if err := s.notificationRepo.WithContext(ctx).MarkAsRead(notificationID); err != nil {
return apperrors.Internal(err)
}
return nil
}
// MarkAllAsRead marks all notifications as read
func (s *NotificationService) MarkAllAsRead(userID uint) error {
if err := s.notificationRepo.MarkAllAsRead(userID); err != nil {
func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID uint) error {
if err := s.notificationRepo.WithContext(ctx).MarkAllAsRead(userID); err != nil {
return apperrors.Internal(err)
}
return nil
@@ -97,7 +97,7 @@ func (s *NotificationService) MarkAllAsRead(userID uint) error {
// CreateAndSendNotification creates a notification and sends it via push
func (s *NotificationService) CreateAndSendNotification(ctx context.Context, userID uint, notificationType models.NotificationType, title, body string, data map[string]interface{}) error {
// Check user preferences
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
if err != nil {
return apperrors.Internal(err)
}
@@ -117,12 +117,12 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use
Data: string(dataJSON),
}
if err := s.notificationRepo.Create(notification); err != nil {
if err := s.notificationRepo.WithContext(ctx).Create(notification); err != nil {
return apperrors.Internal(err)
}
// Get device tokens
iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID)
iosTokens, androidTokens, err := s.notificationRepo.WithContext(ctx).GetActiveTokensForUser(userID)
if err != nil {
return apperrors.Internal(err)
}
@@ -144,12 +144,12 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use
if s.pushClient != nil {
err = s.pushClient.SendToAll(ctx, iosTokens, androidTokens, title, body, pushData)
if err != nil {
s.notificationRepo.SetError(notification.ID, err.Error())
s.notificationRepo.WithContext(ctx).SetError(notification.ID, err.Error())
return apperrors.Internal(err)
}
}
if err := s.notificationRepo.MarkAsSent(notification.ID); err != nil {
if err := s.notificationRepo.WithContext(ctx).MarkAsSent(notification.ID); err != nil {
return apperrors.Internal(err)
}
return nil
@@ -178,8 +178,8 @@ func (s *NotificationService) isNotificationEnabled(prefs *models.NotificationPr
// === Notification Preferences ===
// GetPreferences gets notification preferences for a user
func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferencesResponse, error) {
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
func (s *NotificationService) GetPreferences(ctx context.Context, userID uint) (*NotificationPreferencesResponse, error) {
prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -196,7 +196,7 @@ func validateHourField(val *int, fieldName string) error {
}
// UpdatePreferences updates notification preferences
func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
func (s *NotificationService) UpdatePreferences(ctx context.Context, userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
// B-12: Validate hour fields are in range 0-23
hourFields := []struct {
value *int
@@ -213,7 +213,7 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen
}
}
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -258,7 +258,7 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen
prefs.DailyDigestHour = req.DailyDigestHour
}
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil {
if err := s.notificationRepo.WithContext(ctx).UpdatePreferences(prefs); err != nil {
return nil, apperrors.Internal(err)
}
@@ -268,14 +268,14 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen
// UpdateUserTimezone updates the user's timezone for background job calculations.
// This is called automatically when the user makes API calls (e.g., fetching tasks).
// The timezone should be an IANA timezone name (e.g., "America/Los_Angeles").
func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
func (s *NotificationService) UpdateUserTimezone(ctx context.Context, userID uint, timezone string) {
// Validate timezone is a valid IANA name
if _, err := time.LoadLocation(timezone); err != nil {
return // Invalid timezone, skip silently
}
// Get or create preferences and update timezone
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
if err != nil {
return // Skip silently on error
}
@@ -283,7 +283,7 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
// Only update if timezone changed (avoid unnecessary DB writes)
if prefs.Timezone == nil || *prefs.Timezone != timezone {
prefs.Timezone = &timezone
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil {
if err := s.notificationRepo.WithContext(ctx).UpdatePreferences(prefs); err != nil {
log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone).
Msg("Failed to update user timezone in notification preferences")
}
@@ -293,27 +293,27 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
// === Device Registration ===
// RegisterDevice registers a device for push notifications
func (s *NotificationService) RegisterDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
func (s *NotificationService) RegisterDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
switch req.Platform {
case push.PlatformIOS:
return s.registerAPNSDevice(userID, req)
return s.registerAPNSDevice(ctx, userID, req)
case push.PlatformAndroid:
return s.registerGCMDevice(userID, req)
return s.registerGCMDevice(ctx, userID, req)
default:
return nil, apperrors.BadRequest("error.invalid_platform")
}
}
func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
func (s *NotificationService) registerAPNSDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
// Check if device exists
existing, err := s.notificationRepo.FindAPNSDeviceByToken(req.RegistrationID)
existing, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(req.RegistrationID)
if err == nil {
// Update existing device
existing.UserID = &userID
existing.Active = true
existing.Name = req.Name
existing.DeviceID = req.DeviceID
if err := s.notificationRepo.UpdateAPNSDevice(existing); err != nil {
if err := s.notificationRepo.WithContext(ctx).UpdateAPNSDevice(existing); err != nil {
return nil, apperrors.Internal(err)
}
return NewAPNSDeviceResponse(existing), nil
@@ -327,22 +327,22 @@ func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDevic
RegistrationID: req.RegistrationID,
Active: true,
}
if err := s.notificationRepo.CreateAPNSDevice(device); err != nil {
if err := s.notificationRepo.WithContext(ctx).CreateAPNSDevice(device); err != nil {
return nil, apperrors.Internal(err)
}
return NewAPNSDeviceResponse(device), nil
}
func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
func (s *NotificationService) registerGCMDevice(ctx context.Context, userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
// Check if device exists
existing, err := s.notificationRepo.FindGCMDeviceByToken(req.RegistrationID)
existing, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(req.RegistrationID)
if err == nil {
// Update existing device
existing.UserID = &userID
existing.Active = true
existing.Name = req.Name
existing.DeviceID = req.DeviceID
if err := s.notificationRepo.UpdateGCMDevice(existing); err != nil {
if err := s.notificationRepo.WithContext(ctx).UpdateGCMDevice(existing); err != nil {
return nil, apperrors.Internal(err)
}
return NewGCMDeviceResponse(existing), nil
@@ -357,20 +357,20 @@ func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDevice
CloudMessageType: "FCM",
Active: true,
}
if err := s.notificationRepo.CreateGCMDevice(device); err != nil {
if err := s.notificationRepo.WithContext(ctx).CreateGCMDevice(device); err != nil {
return nil, apperrors.Internal(err)
}
return NewGCMDeviceResponse(device), nil
}
// ListDevices lists all devices for a user
func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error) {
iosDevices, err := s.notificationRepo.FindAPNSDevicesByUser(userID)
func (s *NotificationService) ListDevices(ctx context.Context, userID uint) ([]DeviceResponse, error) {
iosDevices, err := s.notificationRepo.WithContext(ctx).FindAPNSDevicesByUser(userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.Internal(err)
}
androidDevices, err := s.notificationRepo.FindGCMDevicesByUser(userID)
androidDevices, err := s.notificationRepo.WithContext(ctx).FindGCMDevicesByUser(userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.Internal(err)
}
@@ -387,10 +387,10 @@ func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error)
// DeleteDevice deactivates a device after verifying it belongs to the requesting user.
// Without ownership verification, an attacker could deactivate push notifications for other users.
func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userID uint) error {
func (s *NotificationService) DeleteDevice(ctx context.Context, deviceID uint, platform string, userID uint) error {
switch platform {
case push.PlatformIOS:
device, err := s.notificationRepo.FindAPNSDeviceByID(deviceID)
device, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByID(deviceID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.device_not_found")
@@ -401,11 +401,11 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI
if device.UserID == nil || *device.UserID != userID {
return apperrors.Forbidden("error.device_access_denied")
}
if err := s.notificationRepo.DeactivateAPNSDevice(deviceID); err != nil {
if err := s.notificationRepo.WithContext(ctx).DeactivateAPNSDevice(deviceID); err != nil {
return apperrors.Internal(err)
}
case push.PlatformAndroid:
device, err := s.notificationRepo.FindGCMDeviceByID(deviceID)
device, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByID(deviceID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.device_not_found")
@@ -416,7 +416,7 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI
if device.UserID == nil || *device.UserID != userID {
return apperrors.Forbidden("error.device_access_denied")
}
if err := s.notificationRepo.DeactivateGCMDevice(deviceID); err != nil {
if err := s.notificationRepo.WithContext(ctx).DeactivateGCMDevice(deviceID); err != nil {
return apperrors.Internal(err)
}
default:
@@ -426,10 +426,10 @@ func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userI
}
// UnregisterDevice deactivates a device by its registration token
func (s *NotificationService) UnregisterDevice(registrationID, platform string, userID uint) error {
func (s *NotificationService) UnregisterDevice(ctx context.Context, registrationID, platform string, userID uint) error {
switch platform {
case push.PlatformIOS:
device, err := s.notificationRepo.FindAPNSDeviceByToken(registrationID)
device, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(registrationID)
if err != nil {
return apperrors.NotFound("error.device_not_found")
}
@@ -437,11 +437,11 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string,
if device.UserID == nil || *device.UserID != userID {
return apperrors.NotFound("error.device_not_found")
}
if err := s.notificationRepo.DeactivateAPNSDevice(device.ID); err != nil {
if err := s.notificationRepo.WithContext(ctx).DeactivateAPNSDevice(device.ID); err != nil {
return apperrors.Internal(err)
}
case push.PlatformAndroid:
device, err := s.notificationRepo.FindGCMDeviceByToken(registrationID)
device, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(registrationID)
if err != nil {
return apperrors.NotFound("error.device_not_found")
}
@@ -449,7 +449,7 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string,
if device.UserID == nil || *device.UserID != userID {
return apperrors.NotFound("error.device_not_found")
}
if err := s.notificationRepo.DeactivateGCMDevice(device.ID); err != nil {
if err := s.notificationRepo.WithContext(ctx).DeactivateGCMDevice(device.ID); err != nil {
return apperrors.Internal(err)
}
default:
@@ -624,7 +624,7 @@ func (s *NotificationService) CreateAndSendTaskNotification(
task *models.Task,
) error {
// Check user notification preferences
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
prefs, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(userID)
if err != nil {
return apperrors.Internal(err)
}
@@ -662,12 +662,12 @@ func (s *NotificationService) CreateAndSendTaskNotification(
TaskID: &task.ID,
}
if err := s.notificationRepo.Create(notification); err != nil {
if err := s.notificationRepo.WithContext(ctx).Create(notification); err != nil {
return apperrors.Internal(err)
}
// Get device tokens
iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID)
iosTokens, androidTokens, err := s.notificationRepo.WithContext(ctx).GetActiveTokensForUser(userID)
if err != nil {
return apperrors.Internal(err)
}
@@ -691,12 +691,12 @@ func (s *NotificationService) CreateAndSendTaskNotification(
if s.pushClient != nil {
err = s.pushClient.SendActionableNotification(ctx, iosTokens, androidTokens, title, body, pushData, iosCategoryID)
if err != nil {
s.notificationRepo.SetError(notification.ID, err.Error())
s.notificationRepo.WithContext(ctx).SetError(notification.ID, err.Error())
return apperrors.Internal(err)
}
}
if err := s.notificationRepo.MarkAsSent(notification.ID); err != nil {
if err := s.notificationRepo.WithContext(ctx).MarkAsSent(notification.ID); err != nil {
return apperrors.Internal(err)
}
return nil
+58 -58
View File
@@ -44,7 +44,7 @@ func TestNotificationService_GetNotifications(t *testing.T) {
require.NoError(t, err)
}
resp, err := service.GetNotifications(user.ID, 10, 0)
resp, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Len(t, resp, 3)
}
@@ -56,7 +56,7 @@ func TestNotificationService_GetNotifications_Empty(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
resp, err := service.GetNotifications(user.ID, 10, 0)
resp, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Empty(t, resp)
}
@@ -80,7 +80,7 @@ func TestNotificationService_GetNotifications_Pagination(t *testing.T) {
}
// Get first 2
resp, err := service.GetNotifications(user.ID, 2, 0)
resp, err := service.GetNotifications(context.Background(), user.ID, 2, 0)
require.NoError(t, err)
assert.Len(t, resp, 2)
}
@@ -107,7 +107,7 @@ func TestNotificationService_GetUnreadCount(t *testing.T) {
require.NoError(t, err)
}
count, err := service.GetUnreadCount(user.ID)
count, err := service.GetUnreadCount(context.Background(), user.ID)
require.NoError(t, err)
assert.Equal(t, int64(3), count)
}
@@ -119,7 +119,7 @@ func TestNotificationService_GetUnreadCount_Zero(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
count, err := service.GetUnreadCount(user.ID)
count, err := service.GetUnreadCount(context.Background(), user.ID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}
@@ -142,11 +142,11 @@ func TestNotificationService_MarkAsRead(t *testing.T) {
err := db.Create(notif).Error
require.NoError(t, err)
err = service.MarkAsRead(notif.ID, user.ID)
err = service.MarkAsRead(context.Background(), notif.ID, user.ID)
require.NoError(t, err)
// Verify unread count is 0
count, err := service.GetUnreadCount(user.ID)
count, err := service.GetUnreadCount(context.Background(), user.ID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}
@@ -168,7 +168,7 @@ func TestNotificationService_MarkAsRead_WrongUser(t *testing.T) {
err := db.Create(notif).Error
require.NoError(t, err)
err = service.MarkAsRead(notif.ID, other.ID)
err = service.MarkAsRead(context.Background(), notif.ID, other.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found")
}
@@ -179,7 +179,7 @@ func TestNotificationService_MarkAsRead_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.MarkAsRead(9999, user.ID)
err := service.MarkAsRead(context.Background(), 9999, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.notification_not_found")
}
@@ -204,10 +204,10 @@ func TestNotificationService_MarkAllAsRead(t *testing.T) {
require.NoError(t, err)
}
err := service.MarkAllAsRead(user.ID)
err := service.MarkAllAsRead(context.Background(), user.ID)
require.NoError(t, err)
count, err := service.GetUnreadCount(user.ID)
count, err := service.GetUnreadCount(context.Background(), user.ID)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
}
@@ -229,7 +229,7 @@ func TestNotificationService_CreateAndSendNotification(t *testing.T) {
require.NoError(t, err)
// Verify notification was created
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Len(t, notifs, 1)
assert.Equal(t, "Due Soon", notifs[0].Title)
@@ -254,7 +254,7 @@ func TestNotificationService_CreateAndSendNotification_DisabledPreference(t *tes
require.NoError(t, err)
// Verify no notification was created (silently skipped)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Empty(t, notifs)
}
@@ -268,7 +268,7 @@ func TestNotificationService_GetPreferences(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
resp, err := service.GetPreferences(user.ID)
resp, err := service.GetPreferences(context.Background(), user.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
// Defaults should all be true
@@ -289,7 +289,7 @@ func TestNotificationService_UpdatePreferences(t *testing.T) {
TaskDueSoon: &falseVal,
}
resp, err := service.UpdatePreferences(user.ID, req)
resp, err := service.UpdatePreferences(context.Background(), user.ID, req)
require.NoError(t, err)
assert.False(t, resp.TaskDueSoon)
assert.True(t, resp.TaskOverdue) // unchanged
@@ -307,7 +307,7 @@ func TestNotificationService_UpdatePreferences_InvalidHour(t *testing.T) {
TaskDueSoonHour: &invalidHour,
}
_, err := service.UpdatePreferences(user.ID, req)
_, err := service.UpdatePreferences(context.Background(), user.ID, req)
testutil.AssertAppErrorCode(t, err, http.StatusBadRequest)
}
@@ -323,7 +323,7 @@ func TestNotificationService_UpdatePreferences_ValidHour(t *testing.T) {
TaskDueSoonHour: &hour,
}
resp, err := service.UpdatePreferences(user.ID, req)
resp, err := service.UpdatePreferences(context.Background(), user.ID, req)
require.NoError(t, err)
assert.Equal(t, 9, *resp.TaskDueSoonHour)
}
@@ -344,7 +344,7 @@ func TestNotificationService_RegisterDevice_iOS(t *testing.T) {
Platform: push.PlatformIOS,
}
resp, err := service.RegisterDevice(user.ID, req)
resp, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err)
assert.Equal(t, "iPhone 15", resp.Name)
assert.Equal(t, push.PlatformIOS, resp.Platform)
@@ -365,7 +365,7 @@ func TestNotificationService_RegisterDevice_Android(t *testing.T) {
Platform: push.PlatformAndroid,
}
resp, err := service.RegisterDevice(user.ID, req)
resp, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err)
assert.Equal(t, "Pixel 8", resp.Name)
assert.Equal(t, push.PlatformAndroid, resp.Platform)
@@ -386,7 +386,7 @@ func TestNotificationService_RegisterDevice_InvalidPlatform(t *testing.T) {
Platform: "windows",
}
_, err := service.RegisterDevice(user.ID, req)
_, err := service.RegisterDevice(context.Background(), user.ID, req)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
}
@@ -404,12 +404,12 @@ func TestNotificationService_RegisterDevice_UpdateExisting(t *testing.T) {
RegistrationID: "token-xyz",
Platform: push.PlatformIOS,
}
_, err := service.RegisterDevice(user.ID, req)
_, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err)
// Re-register with same token (should update, not duplicate)
req.Name = "iPhone 15 Pro"
resp, err := service.RegisterDevice(user.ID, req)
resp, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err)
assert.Equal(t, "iPhone 15 Pro", resp.Name)
}
@@ -445,7 +445,7 @@ func TestNotificationService_ListDevices(t *testing.T) {
err = db.Create(androidDevice).Error
require.NoError(t, err)
resp, err := service.ListDevices(user.ID)
resp, err := service.ListDevices(context.Background(), user.ID)
require.NoError(t, err)
assert.Len(t, resp, 2)
}
@@ -457,7 +457,7 @@ func TestNotificationService_ListDevices_Empty(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
resp, err := service.ListDevices(user.ID)
resp, err := service.ListDevices(context.Background(), user.ID)
require.NoError(t, err)
assert.Empty(t, resp)
}
@@ -471,7 +471,7 @@ func TestDeleteDevice_InvalidPlatform_Returns400(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.DeleteDevice(1, "windows", user.ID)
err := service.DeleteDevice(context.Background(), 1, "windows", user.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
}
@@ -494,7 +494,7 @@ func TestNotificationService_UnregisterDevice_iOS(t *testing.T) {
err := db.Create(device).Error
require.NoError(t, err)
err = service.UnregisterDevice("reg-token-ios", push.PlatformIOS, user.ID)
err = service.UnregisterDevice(context.Background(), "reg-token-ios", push.PlatformIOS, user.ID)
require.NoError(t, err)
// Verify device is deactivated
@@ -522,7 +522,7 @@ func TestNotificationService_UnregisterDevice_Android(t *testing.T) {
err := db.Create(device).Error
require.NoError(t, err)
err = service.UnregisterDevice("reg-token-android", push.PlatformAndroid, user.ID)
err = service.UnregisterDevice(context.Background(), "reg-token-android", push.PlatformAndroid, user.ID)
require.NoError(t, err)
var found models.GCMDevice
@@ -538,7 +538,7 @@ func TestNotificationService_UnregisterDevice_NotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.UnregisterDevice("nonexistent-token", push.PlatformIOS, user.ID)
err := service.UnregisterDevice(context.Background(), "nonexistent-token", push.PlatformIOS, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
}
@@ -560,7 +560,7 @@ func TestNotificationService_UnregisterDevice_WrongUser(t *testing.T) {
err := db.Create(device).Error
require.NoError(t, err)
err = service.UnregisterDevice("owner-token", push.PlatformIOS, attacker.ID)
err = service.UnregisterDevice(context.Background(), "owner-token", push.PlatformIOS, attacker.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
}
@@ -571,7 +571,7 @@ func TestNotificationService_UnregisterDevice_InvalidPlatform(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.UnregisterDevice("some-token", "windows", user.ID)
err := service.UnregisterDevice(context.Background(), "some-token", "windows", user.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_platform")
}
@@ -585,7 +585,7 @@ func TestNotificationService_UpdateUserTimezone(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Should not panic, just silently update
service.UpdateUserTimezone(user.ID, "America/Los_Angeles")
service.UpdateUserTimezone(context.Background(), user.ID, "America/Los_Angeles")
// Verify timezone was stored
prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
@@ -602,7 +602,7 @@ func TestNotificationService_UpdateUserTimezone_Invalid(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Invalid timezone should be silently ignored
service.UpdateUserTimezone(user.ID, "Invalid/Timezone")
service.UpdateUserTimezone(context.Background(), user.ID, "Invalid/Timezone")
prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
@@ -617,10 +617,10 @@ func TestNotificationService_UpdateUserTimezone_NoChangeSkipsWrite(t *testing.T)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
// Set timezone
service.UpdateUserTimezone(user.ID, "America/New_York")
service.UpdateUserTimezone(context.Background(), user.ID, "America/New_York")
// Set same timezone again — should be a no-op
service.UpdateUserTimezone(user.ID, "America/New_York")
service.UpdateUserTimezone(context.Background(), user.ID, "America/New_York")
prefs, err := notifRepo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
@@ -648,7 +648,7 @@ func TestDeleteDevice_WrongUser_Returns403(t *testing.T) {
require.NoError(t, err)
// Attacker tries to deactivate the owner's device
err = service.DeleteDevice(device.ID, push.PlatformIOS, attacker.ID)
err = service.DeleteDevice(context.Background(), device.ID, push.PlatformIOS, attacker.ID)
require.Error(t, err, "should not allow deleting another user's device")
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
@@ -678,7 +678,7 @@ func TestDeleteDevice_CorrectUser_Succeeds(t *testing.T) {
require.NoError(t, err)
// Owner deactivates their own device
err = service.DeleteDevice(device.ID, push.PlatformIOS, owner.ID)
err = service.DeleteDevice(context.Background(), device.ID, push.PlatformIOS, owner.ID)
require.NoError(t, err, "owner should be able to deactivate their own device")
// Verify the device is now inactive
@@ -709,7 +709,7 @@ func TestDeleteDevice_WrongUser_Android_Returns403(t *testing.T) {
require.NoError(t, err)
// Attacker tries to deactivate the owner's Android device
err = service.DeleteDevice(device.ID, push.PlatformAndroid, attacker.ID)
err = service.DeleteDevice(context.Background(), device.ID, push.PlatformAndroid, attacker.ID)
require.Error(t, err, "should not allow deleting another user's Android device")
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
@@ -727,7 +727,7 @@ func TestDeleteDevice_NonExistent_Returns404(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
err := service.DeleteDevice(99999, push.PlatformIOS, user.ID)
err := service.DeleteDevice(context.Background(), 99999, push.PlatformIOS, user.ID)
require.Error(t, err, "should return error for non-existent device")
testutil.AssertAppErrorCode(t, err, http.StatusNotFound)
}
@@ -744,7 +744,7 @@ func TestNotificationService_CreateAndSend_TaskOverdue(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Task is overdue", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Len(t, notifs, 1)
assert.Equal(t, "Overdue", notifs[0].Title)
@@ -760,7 +760,7 @@ func TestNotificationService_CreateAndSend_TaskCompleted(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Len(t, notifs, 1)
}
@@ -775,7 +775,7 @@ func TestNotificationService_CreateAndSend_TaskAssigned(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned to you", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Len(t, notifs, 1)
}
@@ -790,7 +790,7 @@ func TestNotificationService_CreateAndSend_ResidenceShared(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Someone shared a home", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Len(t, notifs, 1)
}
@@ -805,7 +805,7 @@ func TestNotificationService_CreateAndSend_WarrantyExpiring(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Expiring", "Warranty expiring soon", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Len(t, notifs, 1)
}
@@ -828,7 +828,7 @@ func TestNotificationService_DisabledPrefs_TaskOverdue(t *testing.T) {
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskOverdue, "Overdue", "Overdue task", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Empty(t, notifs)
}
@@ -849,7 +849,7 @@ func TestNotificationService_DisabledPrefs_TaskCompleted(t *testing.T) {
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskCompleted, "Completed", "Task done", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Empty(t, notifs)
}
@@ -870,7 +870,7 @@ func TestNotificationService_DisabledPrefs_TaskAssigned(t *testing.T) {
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskAssigned, "Assigned", "Task assigned", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Empty(t, notifs)
}
@@ -891,7 +891,7 @@ func TestNotificationService_DisabledPrefs_ResidenceShared(t *testing.T) {
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationResidenceShared, "Shared", "Home shared", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Empty(t, notifs)
}
@@ -912,7 +912,7 @@ func TestNotificationService_DisabledPrefs_WarrantyExpiring(t *testing.T) {
err = service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationWarrantyExpiring, "Warranty", "Expiring", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Empty(t, notifs)
}
@@ -929,7 +929,7 @@ func TestNotificationService_CreateAndSend_UnknownTypeDefaultsEnabled(t *testing
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationType("unknown_type"), "Unknown", "Unknown notification", nil)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Len(t, notifs, 1)
}
@@ -953,7 +953,7 @@ func TestNotificationService_CreateAndSend_WithMixedDataTypes(t *testing.T) {
err := service.CreateAndSendNotification(context.Background(), user.ID, models.NotificationTaskDueSoon, "Due Soon", "Fix faucet", data)
require.NoError(t, err)
notifs, err := service.GetNotifications(user.ID, 10, 0)
notifs, err := service.GetNotifications(context.Background(), user.ID, 10, 0)
require.NoError(t, err)
assert.Len(t, notifs, 1)
assert.NotNil(t, notifs[0].Data)
@@ -985,7 +985,7 @@ func TestNotificationService_UpdatePreferences_MultipleFields(t *testing.T) {
TaskOverdueHour: &hour14,
}
resp, err := service.UpdatePreferences(user.ID, req)
resp, err := service.UpdatePreferences(context.Background(), user.ID, req)
require.NoError(t, err)
assert.False(t, resp.TaskDueSoon)
assert.False(t, resp.TaskOverdue)
@@ -1013,7 +1013,7 @@ func TestNotificationService_UpdatePreferences_NegativeHour(t *testing.T) {
TaskOverdueHour: &negHour,
}
_, err := service.UpdatePreferences(user.ID, req)
_, err := service.UpdatePreferences(context.Background(), user.ID, req)
testutil.AssertAppErrorCode(t, err, http.StatusBadRequest)
}
@@ -1032,12 +1032,12 @@ func TestNotificationService_RegisterDevice_UpdateExistingAndroid(t *testing.T)
RegistrationID: "token-android-1",
Platform: push.PlatformAndroid,
}
_, err := service.RegisterDevice(user.ID, req)
_, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err)
// Re-register with same token but new name
req.Name = "Pixel 8 Pro"
resp, err := service.RegisterDevice(user.ID, req)
resp, err := service.RegisterDevice(context.Background(), user.ID, req)
require.NoError(t, err)
assert.Equal(t, "Pixel 8 Pro", resp.Name)
assert.Equal(t, push.PlatformAndroid, resp.Platform)
@@ -1052,7 +1052,7 @@ func TestDeleteDevice_AndroidNotFound_Returns404(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.DeleteDevice(99999, push.PlatformAndroid, user.ID)
err := service.DeleteDevice(context.Background(), 99999, push.PlatformAndroid, user.ID)
testutil.AssertAppErrorCode(t, err, http.StatusNotFound)
}
@@ -1076,7 +1076,7 @@ func TestDeleteDevice_CorrectUser_Android_Succeeds(t *testing.T) {
err := db.Create(device).Error
require.NoError(t, err)
err = service.DeleteDevice(device.ID, push.PlatformAndroid, owner.ID)
err = service.DeleteDevice(context.Background(), device.ID, push.PlatformAndroid, owner.ID)
require.NoError(t, err)
var found models.GCMDevice
@@ -1106,7 +1106,7 @@ func TestNotificationService_UnregisterDevice_WrongUser_Android(t *testing.T) {
err := db.Create(device).Error
require.NoError(t, err)
err = service.UnregisterDevice("owner-android-token", push.PlatformAndroid, attacker.ID)
err = service.UnregisterDevice(context.Background(), "owner-android-token", push.PlatformAndroid, attacker.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
}
@@ -1119,7 +1119,7 @@ func TestNotificationService_UnregisterDevice_AndroidNotFound(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
err := service.UnregisterDevice("nonexistent-android", push.PlatformAndroid, user.ID)
err := service.UnregisterDevice(context.Background(), "nonexistent-android", push.PlatformAndroid, user.ID)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.device_not_found")
}
+1 -1
View File
@@ -186,7 +186,7 @@ func (s *ResidenceService) getSummaryForUser(_ uint) responses.TotalSummary {
func (s *ResidenceService) CreateResidence(ctx context.Context, req *requests.CreateResidenceRequest, ownerID uint) (*responses.ResidenceWithSummaryResponse, error) {
// Check subscription tier limits (if subscription service is wired up)
if s.subscriptionService != nil {
if err := s.subscriptionService.CheckLimit(ownerID, "properties"); err != nil {
if err := s.subscriptionService.CheckLimit(ctx, ownerID, "properties"); err != nil {
return nil, err
}
}
+42 -42
View File
@@ -98,8 +98,8 @@ func NewSubscriptionService(
}
// GetSubscription gets the subscription for a user
func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
func (s *SubscriptionService) GetSubscription(ctx context.Context, userID uint) (*SubscriptionResponse, error) {
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -107,13 +107,13 @@ func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionRespons
}
// GetSubscriptionStatus gets detailed subscription status including limits
func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionStatusResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
func (s *SubscriptionService) GetSubscriptionStatus(ctx context.Context, userID uint) (*SubscriptionStatusResponse, error) {
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
settings, err := s.subscriptionRepo.GetSettings()
settings, err := s.subscriptionRepo.WithContext(ctx).GetSettings()
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -122,18 +122,18 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled {
now := time.Now().UTC()
trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour)
if err := s.subscriptionRepo.SetTrialDates(userID, now, trialEnd); err != nil {
if err := s.subscriptionRepo.WithContext(ctx).SetTrialDates(userID, now, trialEnd); err != nil {
return nil, apperrors.Internal(err)
}
// Re-fetch after starting trial so response reflects the new state
sub, err = s.subscriptionRepo.FindByUserID(userID)
sub, err = s.subscriptionRepo.WithContext(ctx).FindByUserID(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
}
// Get all tier limits and build a map
allLimits, err := s.subscriptionRepo.GetAllTierLimits()
allLimits, err := s.subscriptionRepo.WithContext(ctx).GetAllTierLimits()
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -154,7 +154,7 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
}
// Get current usage
usage, err := s.getUserUsage(userID)
usage, err := s.getUserUsage(ctx, userID)
if err != nil {
return nil, err
}
@@ -204,31 +204,31 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
// getUserUsage calculates current usage for a user.
// P-10: Uses CountByOwner for properties count instead of loading all owned residences.
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
func (s *SubscriptionService) getUserUsage(ctx context.Context, userID uint) (*UsageResponse, error) {
// P-10: Use CountByOwner for an efficient COUNT query instead of loading all records
propertiesCount, err := s.residenceRepo.CountByOwner(userID)
propertiesCount, err := s.residenceRepo.WithContext(ctx).CountByOwner(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
// Still need residence IDs for batch counting tasks/contractors/documents
residenceIDs, err := s.residenceRepo.FindResidenceIDsByOwner(userID)
residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByOwner(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
// Count tasks, contractors, and documents across all residences with single queries each
tasksCount, err := s.taskRepo.CountByResidenceIDs(residenceIDs)
tasksCount, err := s.taskRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs)
if err != nil {
return nil, apperrors.Internal(err)
}
contractorsCount, err := s.contractorRepo.CountByResidenceIDs(residenceIDs)
contractorsCount, err := s.contractorRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs)
if err != nil {
return nil, apperrors.Internal(err)
}
documentsCount, err := s.documentRepo.CountByResidenceIDs(residenceIDs)
documentsCount, err := s.documentRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -242,8 +242,8 @@ func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error)
}
// CheckLimit checks if a user has exceeded a specific limit
func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
settings, err := s.subscriptionRepo.GetSettings()
func (s *SubscriptionService) CheckLimit(ctx context.Context, userID uint, limitType string) error {
settings, err := s.subscriptionRepo.WithContext(ctx).GetSettings()
if err != nil {
return apperrors.Internal(err)
}
@@ -253,7 +253,7 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
return nil
}
sub, err := s.subscriptionRepo.GetOrCreate(userID)
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil {
return apperrors.Internal(err)
}
@@ -268,12 +268,12 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
return nil
}
limits, err := s.subscriptionRepo.GetTierLimits(sub.Tier)
limits, err := s.subscriptionRepo.WithContext(ctx).GetTierLimits(sub.Tier)
if err != nil {
return apperrors.Internal(err)
}
usage, err := s.getUserUsage(userID)
usage, err := s.getUserUsage(ctx, userID)
if err != nil {
return err
}
@@ -301,8 +301,8 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
}
// GetUpgradeTrigger gets an upgrade trigger by key
func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResponse, error) {
trigger, err := s.subscriptionRepo.GetUpgradeTrigger(key)
func (s *SubscriptionService) GetUpgradeTrigger(ctx context.Context, key string) (*UpgradeTriggerResponse, error) {
trigger, err := s.subscriptionRepo.WithContext(ctx).GetUpgradeTrigger(key)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.upgrade_trigger_not_found")
@@ -314,8 +314,8 @@ func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResp
// GetAllUpgradeTriggers gets all active upgrade triggers as a map keyed by trigger_key
// KMM client expects Map<String, UpgradeTriggerData>
func (s *SubscriptionService) GetAllUpgradeTriggers() (map[string]*UpgradeTriggerDataResponse, error) {
triggers, err := s.subscriptionRepo.GetAllUpgradeTriggers()
func (s *SubscriptionService) GetAllUpgradeTriggers(ctx context.Context) (map[string]*UpgradeTriggerDataResponse, error) {
triggers, err := s.subscriptionRepo.WithContext(ctx).GetAllUpgradeTriggers()
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -328,8 +328,8 @@ func (s *SubscriptionService) GetAllUpgradeTriggers() (map[string]*UpgradeTrigge
}
// GetFeatureBenefits gets all feature benefits
func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, error) {
benefits, err := s.subscriptionRepo.GetFeatureBenefits()
func (s *SubscriptionService) GetFeatureBenefits(ctx context.Context) ([]FeatureBenefitResponse, error) {
benefits, err := s.subscriptionRepo.WithContext(ctx).GetFeatureBenefits()
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -342,13 +342,13 @@ func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, er
}
// GetActivePromotions gets active promotions for a user
func (s *SubscriptionService) GetActivePromotions(userID uint) ([]PromotionResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
func (s *SubscriptionService) GetActivePromotions(ctx context.Context, userID uint) ([]PromotionResponse, error) {
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
promotions, err := s.subscriptionRepo.GetActivePromotions(sub.Tier)
promotions, err := s.subscriptionRepo.WithContext(ctx).GetActivePromotions(sub.Tier)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -362,13 +362,13 @@ func (s *SubscriptionService) GetActivePromotions(userID uint) ([]PromotionRespo
// ProcessApplePurchase processes an Apple IAP purchase
// Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID)
func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) {
func (s *SubscriptionService) ProcessApplePurchase(ctx context.Context, userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) {
// Store receipt/transaction data
dataToStore := receiptData
if dataToStore == "" {
dataToStore = transactionID
}
if err := s.subscriptionRepo.UpdateReceiptData(userID, dataToStore); err != nil {
if err := s.subscriptionRepo.WithContext(ctx).UpdateReceiptData(userID, dataToStore); err != nil {
return nil, apperrors.Internal(err)
}
@@ -406,18 +406,18 @@ func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData stri
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated")
// Upgrade to Pro with the validated expiration
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil {
if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "ios"); err != nil {
return nil, apperrors.Internal(err)
}
return s.GetSubscription(userID)
return s.GetSubscription(ctx, userID)
}
// ProcessGooglePurchase processes a Google Play purchase
// productID is optional but helps validate the specific subscription
func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) {
func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) {
// Store purchase token first
if err := s.subscriptionRepo.UpdatePurchaseToken(userID, purchaseToken); err != nil {
if err := s.subscriptionRepo.WithContext(ctx).UpdatePurchaseToken(userID, purchaseToken); err != nil {
return nil, apperrors.Internal(err)
}
@@ -463,25 +463,25 @@ func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken s
}
// Upgrade to Pro with the validated expiration
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "android"); err != nil {
if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "android"); err != nil {
return nil, apperrors.Internal(err)
}
return s.GetSubscription(userID)
return s.GetSubscription(ctx, userID)
}
// CancelSubscription cancels a subscription (downgrades to free at end of period)
func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResponse, error) {
if err := s.subscriptionRepo.SetAutoRenew(userID, false); err != nil {
func (s *SubscriptionService) CancelSubscription(ctx context.Context, userID uint) (*SubscriptionResponse, error) {
if err := s.subscriptionRepo.WithContext(ctx).SetAutoRenew(userID, false); err != nil {
return nil, apperrors.Internal(err)
}
return s.GetSubscription(userID)
return s.GetSubscription(ctx, userID)
}
// IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription
// from a different platform than the one being requested. Returns (conflict, existingPlatform, error).
func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(userID uint, requestedPlatform string) (bool, string, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(ctx context.Context, userID uint, requestedPlatform string) (bool, string, error) {
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil {
return false, "", apperrors.Internal(err)
}
@@ -1,6 +1,7 @@
package services
import (
"context"
"testing"
"time"
@@ -70,7 +71,7 @@ func TestProcessApplePurchase_ClientNil_ReturnsError(t *testing.T) {
googleClient: nil,
}
_, err := svc.ProcessApplePurchase(user.ID, "fake-receipt", "")
_, err := svc.ProcessApplePurchase(context.Background(), user.ID, "fake-receipt", "")
assert.Error(t, err, "ProcessApplePurchase should return error when Apple IAP client is nil")
// Verify user was NOT upgraded to Pro
@@ -109,7 +110,7 @@ func TestProcessApplePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
}
// Neither receipt data nor transaction ID - should still not grant Pro
_, err := svc.ProcessApplePurchase(user.ID, "", "")
_, err := svc.ProcessApplePurchase(context.Background(), user.ID, "", "")
assert.Error(t, err, "ProcessApplePurchase should return error when client is nil, even with empty data")
// Verify no upgrade happened
@@ -140,7 +141,7 @@ func TestProcessGooglePurchase_ClientNil_ReturnsError(t *testing.T) {
googleClient: nil, // Not configured
}
_, err := svc.ProcessGooglePurchase(user.ID, "fake-token", "com.tt.honeyDue.pro.monthly")
_, err := svc.ProcessGooglePurchase(context.Background(), user.ID, "fake-token", "com.tt.honeyDue.pro.monthly")
assert.Error(t, err, "ProcessGooglePurchase should return error when Google IAP client is nil")
// Verify user was NOT upgraded to Pro
@@ -172,7 +173,7 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
}
// With empty token
_, err := svc.ProcessGooglePurchase(user.ID, "", "")
_, err := svc.ProcessGooglePurchase(context.Background(), user.ID, "", "")
assert.Error(t, err, "ProcessGooglePurchase should return error when client is nil")
// Verify no upgrade happened
@@ -202,7 +203,7 @@ func TestSubscriptionService_GetSubscription(t *testing.T) {
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
resp, err := svc.GetSubscription(user.ID)
resp, err := svc.GetSubscription(context.Background(), user.ID)
require.NoError(t, err)
assert.Equal(t, "free", resp.Tier)
assert.False(t, resp.IsPro)
@@ -238,7 +239,7 @@ func TestSubscriptionService_GetSubscription_ProUser(t *testing.T) {
err := db.Create(sub).Error
require.NoError(t, err)
resp, err := svc.GetSubscription(user.ID)
resp, err := svc.GetSubscription(context.Background(), user.ID)
require.NoError(t, err)
assert.Equal(t, "pro", resp.Tier)
assert.True(t, resp.IsPro)
@@ -277,7 +278,7 @@ func TestSubscriptionService_CancelSubscription(t *testing.T) {
err := db.Create(sub).Error
require.NoError(t, err)
resp, err := svc.CancelSubscription(user.ID)
resp, err := svc.CancelSubscription(context.Background(), user.ID)
require.NoError(t, err)
assert.False(t, resp.AutoRenew)
}
@@ -365,7 +366,7 @@ func TestIsAlreadyProFromOtherPlatform(t *testing.T) {
err := db.Create(sub).Error
require.NoError(t, err)
conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(user.ID, tt.requestedPlatform)
conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(context.Background(), user.ID, tt.requestedPlatform)
require.NoError(t, err)
assert.Equal(t, tt.wantConflict, conflict)
assert.Equal(t, tt.wantPlatform, existingPlatform)
+3 -3
View File
@@ -898,7 +898,7 @@ func (s *TaskService) sendTaskCompletedNotification(ctx context.Context, task *m
// Send email notification (to everyone INCLUDING the person who completed it)
// Check user's email notification preferences first
if s.emailService != nil && user.Email != "" && s.notificationService != nil {
prefs, prefsErr := s.notificationService.GetPreferences(user.ID)
prefs, prefsErr := s.notificationService.GetPreferences(ctx, user.ID)
// LE-06: Log fail-open behavior when preferences cannot be loaded
if prefsErr != nil {
log.Warn().
@@ -1264,8 +1264,8 @@ func (s *TaskService) GetFrequencies(ctx context.Context) ([]responses.TaskFrequ
// UpdateUserTimezone updates the user's timezone for background job calculations.
// This is called from handlers when the X-Timezone header is present.
// Delegates to NotificationService since timezone is stored in notification preferences.
func (s *TaskService) UpdateUserTimezone(userID uint, timezone string) {
func (s *TaskService) UpdateUserTimezone(ctx context.Context, userID uint, timezone string) {
if s.notificationService != nil {
s.notificationService.UpdateUserTimezone(userID, timezone)
s.notificationService.UpdateUserTimezone(ctx, userID, timezone)
}
}