From 7690f07a2bf7c1f9607307820ea99bf2a1144468 Mon Sep 17 00:00:00 2001 From: Trey t Date: Mon, 2 Mar 2026 09:48:01 -0600 Subject: [PATCH] Harden API security: input validation, safe auth extraction, new tests, and deploy config Comprehensive security hardening from audit findings: - Add validation tags to all DTO request structs (max lengths, ranges, enums) - Replace unsafe type assertions with MustGetAuthUser helper across all handlers - Remove query-param token auth from admin middleware (prevents URL token leakage) - Add request validation calls in handlers that were missing c.Validate() - Remove goroutines in handlers (timezone update now synchronous) - Add sanitize middleware and path traversal protection (path_utils) - Stop resetting admin passwords on migration restart - Warn on well-known default SECRET_KEY - Add ~30 new test files covering security regressions, auth safety, repos, and services - Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation Co-Authored-By: Claude Opus 4.6 --- .deploy_prod | 5 + audit-digest-1.md | 38 + audit-digest-2.md | 51 + audit-digest-3.md | 30 + audit-digest-4.md | 48 + audit-digest-5.md | 45 + audit-digest-6.md | 40 + audit-digest-7.md | 57 + audit-digest-8.md | 35 + audit-digest-9.md | 49 + deploy/.gitignore | 18 + deploy/README.md | 134 ++ deploy/cluster.env.example | 22 + deploy/prod.env.example | 73 + deploy/registry.env.example | 11 + deploy/scripts/deploy_prod.sh | 397 +++++ deploy/secrets/README.md | 11 + deploy/secrets/apns_auth_key.p8.example | 3 + .../secrets/email_host_password.txt.example | 1 + deploy/secrets/fcm_server_key.txt.example | 1 + deploy/secrets/postgres_password.txt.example | 1 + deploy/secrets/secret_key.txt.example | 1 + deploy/shit_deploy_cant_do.md | 67 + deploy/swarm-stack.prod.yml | 288 ++++ docs/AUDIT_FINDINGS.md | 1527 +++++++++++++++++ hardening-report.md | 37 + internal/admin/dto/requests.go | 8 + .../admin/handlers/admin_security_test.go | 199 +++ 
internal/admin/handlers/admin_user_handler.go | 10 +- .../handlers/apple_social_auth_handler.go | 10 +- internal/admin/handlers/auth_token_handler.go | 9 +- internal/admin/handlers/completion_handler.go | 10 +- .../handlers/completion_image_handler.go | 9 +- .../handlers/confirmation_code_handler.go | 9 +- internal/admin/handlers/contractor_handler.go | 10 +- internal/admin/handlers/device_handler.go | 16 +- internal/admin/handlers/document_handler.go | 10 +- .../admin/handlers/document_image_handler.go | 9 +- .../admin/handlers/feature_benefit_handler.go | 8 +- internal/admin/handlers/lookup_handler.go | 12 + .../admin/handlers/notification_handler.go | 21 +- .../handlers/notification_prefs_handler.go | 9 +- .../handlers/password_reset_code_handler.go | 9 +- internal/admin/handlers/promotion_handler.go | 9 +- internal/admin/handlers/residence_handler.go | 10 +- internal/admin/handlers/share_code_handler.go | 18 +- .../admin/handlers/subscription_handler.go | 10 +- internal/admin/handlers/task_handler.go | 11 +- internal/admin/handlers/user_handler.go | 10 +- .../admin/handlers/user_profile_handler.go | 9 +- internal/config/config.go | 4 + internal/database/database.go | 21 +- internal/dto/requests/contractor.go | 12 +- internal/dto/requests/document.go | 14 +- internal/dto/requests/residence.go | 14 +- internal/dto/requests/task.go | 18 +- internal/handlers/auth_handler.go | 30 + internal/handlers/contractor_handler.go | 51 +- internal/handlers/contractor_handler_test.go | 182 ++ internal/handlers/document_handler.go | 59 +- internal/handlers/media_handler.go | 37 +- internal/handlers/media_handler_test.go | 74 + internal/handlers/noauth_test.go | 334 ++++ internal/handlers/notification_handler.go | 64 +- .../handlers/notification_handler_test.go | 88 + internal/handlers/residence_handler.go | 74 +- internal/handlers/residence_handler_test.go | 42 + internal/handlers/subscription_handler.go | 39 +- .../handlers/subscription_webhook_handler.go | 223 ++- 
.../subscription_webhook_handler_test.go | 56 + internal/handlers/task_handler.go | 108 +- internal/handlers/task_handler_test.go | 111 ++ internal/handlers/tracking_handler.go | 10 +- internal/handlers/upload_handler.go | 32 +- internal/handlers/upload_handler_test.go | 43 + internal/handlers/user_handler.go | 16 +- .../integration/security_regression_test.go | 633 +++++++ internal/middleware/admin_auth.go | 14 +- internal/middleware/auth.go | 24 +- internal/middleware/auth_safety_test.go | 119 ++ internal/middleware/request_id.go | 11 +- internal/middleware/request_id_test.go | 125 ++ internal/middleware/sanitize.go | 19 + internal/middleware/sanitize_test.go | 59 + internal/middleware/timezone.go | 24 +- internal/monitoring/collector.go | 25 +- internal/monitoring/handler.go | 11 +- internal/monitoring/service.go | 4 + internal/monitoring/writer.go | 55 +- internal/push/fcm.go | 3 + internal/push/fcm_test.go | 186 ++ internal/repositories/contractor_repo.go | 48 +- internal/repositories/contractor_repo_test.go | 96 ++ internal/repositories/document_repo.go | 20 +- internal/repositories/document_repo_test.go | 38 + internal/repositories/notification_repo.go | 59 +- .../repositories/notification_repo_test.go | 96 ++ internal/repositories/reminder_repo.go | 78 + internal/repositories/residence_repo.go | 11 +- internal/repositories/subscription_repo.go | 48 +- .../repositories/subscription_repo_test.go | 79 + internal/repositories/task_repo.go | 199 ++- internal/repositories/task_repo_test.go | 167 ++ internal/repositories/task_template_repo.go | 5 +- internal/repositories/util.go | 11 + internal/router/router.go | 1 + internal/services/contractor_service.go | 12 + internal/services/contractor_service_test.go | 98 ++ internal/services/iap_validation.go | 25 +- internal/services/notification_service.go | 43 +- .../services/notification_service_test.go | 126 ++ internal/services/onboarding_email_service.go | 35 +- internal/services/path_utils.go | 51 + 
internal/services/path_utils_test.go | 55 + internal/services/pdf_service.go | 7 +- internal/services/residence_service.go | 28 +- internal/services/residence_service_test.go | 123 ++ internal/services/storage_service.go | 20 +- internal/services/subscription_service.go | 212 ++- .../services/subscription_service_test.go | 181 ++ internal/services/task_service.go | 169 +- internal/services/task_service_test.go | 328 ++++ internal/worker/jobs/handler.go | 99 +- 123 files changed, 8321 insertions(+), 750 deletions(-) create mode 100755 .deploy_prod create mode 100644 audit-digest-1.md create mode 100644 audit-digest-2.md create mode 100644 audit-digest-3.md create mode 100644 audit-digest-4.md create mode 100644 audit-digest-5.md create mode 100644 audit-digest-6.md create mode 100644 audit-digest-7.md create mode 100644 audit-digest-8.md create mode 100644 audit-digest-9.md create mode 100644 deploy/.gitignore create mode 100644 deploy/README.md create mode 100644 deploy/cluster.env.example create mode 100644 deploy/prod.env.example create mode 100644 deploy/registry.env.example create mode 100755 deploy/scripts/deploy_prod.sh create mode 100644 deploy/secrets/README.md create mode 100644 deploy/secrets/apns_auth_key.p8.example create mode 100644 deploy/secrets/email_host_password.txt.example create mode 100644 deploy/secrets/fcm_server_key.txt.example create mode 100644 deploy/secrets/postgres_password.txt.example create mode 100644 deploy/secrets/secret_key.txt.example create mode 100644 deploy/shit_deploy_cant_do.md create mode 100644 deploy/swarm-stack.prod.yml create mode 100644 docs/AUDIT_FINDINGS.md create mode 100644 hardening-report.md create mode 100644 internal/admin/handlers/admin_security_test.go create mode 100644 internal/handlers/contractor_handler_test.go create mode 100644 internal/handlers/media_handler_test.go create mode 100644 internal/handlers/noauth_test.go create mode 100644 internal/handlers/notification_handler_test.go create mode 100644 
internal/handlers/subscription_webhook_handler_test.go create mode 100644 internal/handlers/upload_handler_test.go create mode 100644 internal/integration/security_regression_test.go create mode 100644 internal/middleware/auth_safety_test.go create mode 100644 internal/middleware/request_id_test.go create mode 100644 internal/middleware/sanitize.go create mode 100644 internal/middleware/sanitize_test.go create mode 100644 internal/push/fcm_test.go create mode 100644 internal/repositories/contractor_repo_test.go create mode 100644 internal/repositories/document_repo_test.go create mode 100644 internal/repositories/notification_repo_test.go create mode 100644 internal/repositories/subscription_repo_test.go create mode 100644 internal/repositories/util.go create mode 100644 internal/services/contractor_service_test.go create mode 100644 internal/services/notification_service_test.go create mode 100644 internal/services/path_utils.go create mode 100644 internal/services/path_utils_test.go create mode 100644 internal/services/subscription_service_test.go diff --git a/.deploy_prod b/.deploy_prod new file mode 100755 index 0000000..c6d4ad2 --- /dev/null +++ b/.deploy_prod @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +exec "${SCRIPT_DIR}/deploy/scripts/deploy_prod.sh" "$@" diff --git a/audit-digest-1.md b/audit-digest-1.md new file mode 100644 index 0000000..7a859df --- /dev/null +++ b/audit-digest-1.md @@ -0,0 +1,38 @@ +# Digest 1: cmd/, admin/dto, admin/handlers (first 15 files) + +## Systemic Issues (across all admin handlers) +- **SQL Injection via SortBy**: Every admin list handler concatenates `filters.SortBy` directly into GORM `Order()` without allowlist validation +- **Unchecked Count() errors**: Every paginated handler ignores GORM Count error returns +- **Unchecked post-mutation Preload errors**: After Save/Create, handlers reload with Preload but ignore errors +- **`binding` vs `validate` 
tag mismatch**: Some request DTOs use `binding` (Gin) instead of `validate` (Echo) +- **Direct DB access**: All admin handlers bypass Service layer, accessing `*gorm.DB` directly +- **Unsafe type assertions**: `c.Get(key).(*models.AdminUser)` without comma-ok + +## Per-File Highlights + +### cmd/api/main.go - App entry point, wires dependencies +### cmd/worker/main.go - Background worker entry point + +### admin/handlers/admin_user_handler.go (347 lines) +- N+1 query: `toUserResponse` does 2 extra DB queries per user (residence count, task count) +- Line 64: SortBy SQL injection +- Line 173: Unchecked profile creation error (user created without profile) + +### admin/handlers/apple_social_auth_handler.go - CRUD for Apple social auth records +- Same systemic SQL injection and unchecked errors + +### admin/handlers/auth_handler.go - Admin login/session management +### admin/handlers/auth_token_handler.go - Auth token CRUD +### admin/handlers/completion_handler.go - Task completion CRUD +### admin/handlers/completion_image_handler.go - Completion image CRUD +### admin/handlers/confirmation_code_handler.go - Email confirmation code CRUD +### admin/handlers/contractor_handler.go - Contractor CRUD +### admin/handlers/dashboard_handler.go - Admin dashboard stats + +### admin/handlers/device_handler.go (317 lines) +- Exposes device push tokens (RegistrationID) in API responses +- Lines 293-296: Unchecked Count errors in GetStats + +### admin/handlers/document_handler.go (394 lines) +- Lines 176-183: Date parsing errors silently ignored +- Line 379: Precision loss from decimal.Float64() discarded diff --git a/audit-digest-2.md b/audit-digest-2.md new file mode 100644 index 0000000..fa95fe3 --- /dev/null +++ b/audit-digest-2.md @@ -0,0 +1,51 @@ +# Digest 2: admin/handlers (remaining 15 files) + +### admin/handlers/document_image_handler.go (245 lines) +- N+1: toResponse queries DB per image in List +- Same SortBy SQL injection + +### admin/handlers/feature_benefit_handler.go 
(231 lines) +- `binding` tags instead of `validate` - required fields never enforced + +### admin/handlers/limitations_handler.go (451 lines) +- Line 37: Unchecked Create error for default settings +- Line 191-197: UpdateTierLimits overwrites ALL fields even for partial updates + +### admin/handlers/lookup_handler.go (877 lines) +- **CRITICAL**: Lines 30-32, 50-52, etc.: refreshXxxCache checks `if cache == nil {}` with EMPTY body, then calls cache.CacheXxx() — nil pointer panic when cache is nil +- Line 792: Hardcoded join table name "task_contractor_specialties" + +### admin/handlers/notification_handler.go (419 lines) +- Line 351-363: HTML template built by string concatenation with user-provided subject/body — XSS in admin emails + +### admin/handlers/notification_prefs_handler.go (347 lines) +- Line 154: Unchecked user lookup — deleted user produces zero-value username/email + +### admin/handlers/onboarding_handler.go (343 lines) +- Line 304: Internal error details leaked to client + +### admin/handlers/password_reset_code_handler.go (161 lines) +- **BUG**: Line 85: `code.ResetToken[:8] + "..." 
+ code.ResetToken[len-4:]` panics if token < 8 chars + +### admin/handlers/promotion_handler.go (304 lines) +- `binding` tags: required fields never enforced + +### admin/handlers/residence_handler.go (371 lines) +- Lines 121-122: Unchecked Count errors for task/document counts + +### admin/handlers/settings_handler.go (794 lines) +- Line 378: Raw SQL execution from seed files (no parameterization) +- Line 529-793: ClearAllData is destructive with no double-auth check +- Line 536-539: Panic in ClearAllData silently swallowed + +### admin/handlers/share_code_handler.go (225 lines) +- Line 155-162: `IsActive` as non-pointer bool — absent field defaults to false, deactivating codes + +### admin/handlers/subscription_handler.go (237 lines) +- **BUG**: Line 40-41: JOIN uses "users" but actual table is "auth_user" — query fails on PostgreSQL + +### admin/handlers/task_handler.go (401 lines) +- Line 247-296: Admin Create bypasses service layer — no business logic applied + +### admin/handlers/task_template_handler.go (347 lines) +- Lines 29-31: Same nil cache panic as lookup_handler.go diff --git a/audit-digest-3.md b/audit-digest-3.md new file mode 100644 index 0000000..bd18704 --- /dev/null +++ b/audit-digest-3.md @@ -0,0 +1,30 @@ +# Digest 3: admin routes, apperrors, config, database, dto/requests, dto/responses (first half) + +### admin/routes.go (483 lines) +- Route ordering: users DELETE "/:id" before "/bulk" — "/bulk" matches as id param +- Line 454: Uses os.Getenv instead of Viper config +- Line 460-462: url.Parse failure returns silently, no logging +- Line 467-469: Proxy errors not surfaced (always returns nil) + +### apperrors/errors.go (98 lines) - Clean. Error types with Wrap/Unwrap. 
+### apperrors/handler.go (67 lines) - c.JSON error returns discarded (minor) + +### config/config.go (427 lines) +- Line 339: Hardcoded debug secret key "change-me-in-production-secret-key-12345" +- Lines 311-312: Comments say wrong UTC times (says 8PM/9AM, actually 2PM/3PM) + +### database/database.go (468 lines) +- **SECURITY**: Line 372-382: Hardcoded bcrypt hash for GoAdmin with password "admin" — migration RESETS password every run +- **SECURITY**: Line 447-463: Hardcoded admin@mycrib.com / admin123 — password reset on every migration +- Line 182+: Multiple db.Exec errors unchecked for index creation +- Line 100-102: WithTransaction coupled to global db variable + +### dto/requests/auth.go (66 lines) - LoginRequest min=1 password (intentional for login) +### dto/requests/contractor.go (37 lines) - Rating has no min/max validation bounds +### dto/requests/document.go (47 lines) - ImageURLs no length limit, Description no max length +### dto/requests/residence.go (59 lines) - Bedrooms/SquareFootage accept negative values, ExpiresInHours no validation +### dto/requests/task.go (110 lines) - Rating no bounds, ImageURLs no length limit, CustomIntervalDays no min + +### dto/responses/auth.go (190 lines) - Clean. Proper nil checks on Profile. +### dto/responses/contractor.go (131 lines) - TaskCount depends on preloaded Tasks association +### dto/responses/document.go (126 lines) - Line 101: CreatedBy accessed as value type (fragile if changed to pointer) diff --git a/audit-digest-4.md b/audit-digest-4.md new file mode 100644 index 0000000..9bbea9a --- /dev/null +++ b/audit-digest-4.md @@ -0,0 +1,48 @@ +# Digest 4: dto/responses (remaining), echohelpers, handlers (first half) + +### dto/responses/residence.go (215 lines) - NewResidenceResponse no nil check on param. Owner zero-value if not preloaded. 
+### dto/responses/task_template.go (135 lines) - No nil check on template param +### dto/responses/task.go (399 lines) - No nil checks on params in factory functions +### dto/responses/user.go (20 lines) - Clean data types +### echohelpers/helpers.go (46 lines) - Clean utilities +### echohelpers/pagination.go (33 lines) - Clean, properly bounded + +### handlers/auth_handler.go (379 lines) +- **ARCHITECTURE**: Lines 83, 178, 207, 241, 329, 370: SIX goroutine spawns for email — violates "no goroutines in handlers" rule +- Line 308-312: AppError constructed directly instead of factory function + +### handlers/contractor_handler.go (154 lines) +- Line 28+: Unchecked type assertions throughout (7 instances) +- Line 31: Raw err.Error() returned to client +- Line 55: CreateContractor missing c.Validate() call + +### handlers/document_handler.go (336 lines) +- Line 37+: Unchecked type assertions (10 instances) +- Line 92-93: Raw error leaked to client +- Line 137: No DocumentType validation — any string accepted +- Lines 187, 217: Missing c.Validate() calls + +### handlers/media_handler.go (172 lines) +- **SECURITY**: Line 156-171: resolveFilePath uses filepath.Join with user-influenced data — PATH TRAVERSAL vulnerability. 
TrimPrefix doesn't sanitize ../ +- Line 19-22: Handler accesses repositories directly, bypasses service layer + +### handlers/notification_handler.go (200 lines) +- Line 29-40: No upper bound on limit — unbounded query with limit=999999999 +- Line 168: Silent default to "ios" platform + +### handlers/residence_handler.go (365 lines) +- Line 38+: Unchecked type assertions (14 instances) +- Lines 187, 209, 303: Bind errors silently discarded +- Line 224: JoinWithCode missing c.Validate() + +### handlers/static_data_handler.go (152 lines) - Uses interface{} instead of concrete types +### handlers/subscription_handler.go (176 lines) - Lines 97, 150: Missing c.Validate() calls +- Line 159-163: RestoreSubscription doesn't validate receipt/transaction ID presence + +### handlers/subscription_webhook_handler.go (821 lines) +- **SECURITY**: Line 190-192: Apple JWS payload decoded WITHOUT signature verification +- **SECURITY**: Line 787-793: VerifyGooglePubSubToken ALWAYS returns true — webhook unauthenticated +- Line 639-643: Subscription duration guessed by string matching product ID +- Line 657, 694: Hardcoded 1-month extension regardless of actual plan +- Line 759, 772: Unchecked type assertions in VerifyAppleSignature +- Line 162: Apple renewal info error silently discarded diff --git a/audit-digest-5.md b/audit-digest-5.md new file mode 100644 index 0000000..ae83de9 --- /dev/null +++ b/audit-digest-5.md @@ -0,0 +1,45 @@ +# Digest 5: handlers (remaining), i18n, middleware, models (first half) + +### handlers/task_handler.go (440 lines) +- Line 35+: Unchecked type assertions (18 locations) +- Line 42: Fire-and-forget goroutine for UpdateUserTimezone — no error handling, no context +- Lines 112-115, 134-137: Missing c.Validate() calls +- Line 317: 32MB multipart limit with no per-file size check + +### handlers/task_template_handler.go (98 lines) +- Line 59: No max length on search query — slow LIKE queries possible + +### handlers/tracking_handler.go (46 lines) +- Line 
25: Package-level base64 decode error discarded +- Lines 34-36: Fire-and-forget goroutine — violates no-goroutines rule + +### handlers/upload_handler.go (93 lines) +- Line 31: User-controlled `category` param passed to storage — potential path traversal +- Line 80: `binding` tag instead of `validate` +- No file type or size validation at handler level + +### handlers/user_handler.go (76 lines) - Unchecked type assertions + +### i18n/i18n.go (87 lines) +- Line 16: Global Bundle is nil until Init() — NewLocalizer dereferences without nil check +- Line 37: MustParseMessageFileBytes panics on malformed translation files +- Line 83: MustT panics on missing translations + +### i18n/middleware.go (127 lines) - Clean + +### middleware/admin_auth.go (133 lines) +- **SECURITY**: Line 50: Admin JWT accepted via query param — tokens leak into server/proxy logs +- Line 124: Unchecked type assertion + +### middleware/auth.go (229 lines) +- **BUG**: Line 66: `token[:8]` panics if token is fewer than 8 characters +- Line 104: cacheUserID error silently discarded +- Line 209: Unchecked type assertion + +### middleware/logger.go (54 lines) - Clean +### middleware/request_id.go (44 lines) - Line 21: Client-supplied X-Request-ID accepted without validation (log injection) +### middleware/timezone.go (101 lines) - Lines 88, 99: Unchecked type assertions + +### models/admin.go (64 lines) - Line 38: No max password length check; bcrypt truncates at 72 bytes +### models/base.go (39 lines) - Clean GORM hooks +### models/contractor.go (54 lines) - *float64 mapped to decimal(2,1) — minor precision mismatch diff --git a/audit-digest-6.md b/audit-digest-6.md new file mode 100644 index 0000000..2c3d145 --- /dev/null +++ b/audit-digest-6.md @@ -0,0 +1,40 @@ +# Digest 6: models (remaining), monitoring + +### models/document.go (100 lines) - Clean +### models/notification.go (141 lines) - Clean +### models/onboarding_email.go (35 lines) - Clean +### models/reminder_log.go (92 lines) - Clean + 
+### models/residence.go (106 lines) +- Lines 65-70: GetAllUsers/HasAccess assumes Owner and Users are preloaded — returns wrong results if not + +### models/subscription.go (169 lines) +- Lines 57-65: IsActive()/IsPro() don't account for IsFree admin override field — misleading method names + +### models/task_template.go (23 lines) - Clean +### models/task.go (317 lines) +- Lines 189-264: GetKanbanColumnWithTimezone duplicates categorization chain logic — maintenance drift risk +- Lines 158-182: IsDueSoon uses time.Now() internally — non-deterministic, harder to test + +### models/user.go (268 lines) +- Line 101: crypto/rand.Read error unchecked (safe in practice since Go 1.20) +- Line 164-172: GenerateConfirmationCode has slight distribution bias (negligible) + +### monitoring/buffer.go (166 lines) - Line 75: Corrupted Redis data silently dropped +### monitoring/collector.go (201 lines) +- Line 82: cpu.Percent blocks 1 second per collection +- Line 96-110: ReadMemStats called TWICE per cycle (also in collectRuntime) + +### monitoring/handler.go (195 lines) +- **SECURITY**: Line 19-22: WebSocket CheckOrigin always returns true +- **BUG**: Line 117-119: After upgrader.Upgrade fails, execution continues to conn.Close() — nil pointer panic +- **BUG**: Line 177: Missing return after ctx.Done() — goroutine spins +- Lines 183-184: GetAllStats error silently ignored +- Line 192: WriteJSON error unchecked + +### monitoring/middleware.go (220 lines) - Clean +### monitoring/models.go (129 lines) - Clean + +### monitoring/service.go (196 lines) +- Line 121: Hardcoded primary key 1 for singleton settings row +- Line 191-194: MetricsMiddleware returns nil when httpCollector is nil — caller must nil-check diff --git a/audit-digest-7.md b/audit-digest-7.md new file mode 100644 index 0000000..606fa3e --- /dev/null +++ b/audit-digest-7.md @@ -0,0 +1,57 @@ +# Digest 7: monitoring/writer, notifications, push, repositories + +### monitoring/writer.go (96 lines) +- Line 90-92: 
Unbounded fire-and-forget goroutines for Redis push — no rate limiting + +### notifications/reminder_config.go (64 lines) - Clean config data +### notifications/reminder_schedule.go (199 lines) +- Line 112: Integer truncation of float division — DST off-by-one possible +- Line 148-161: Custom itoa reimplements strconv.Itoa + +### push/apns.go (209 lines) +- Line 44: Double-negative logic — both Production=false and Sandbox=false defaults to production + +### push/client.go (158 lines) +- Line 89-105: SendToAll last-error-wins — cannot tell which platform failed +- Line 150-157: HealthCheck always returns nil — useless health check + +### push/fcm.go (140 lines) +- Line 16: Legacy FCM HTTP API (deprecated by Google) +- Line 119-126: If FCM returns fewer results than tokens, index out of bounds panic + +### repositories/admin_repo.go (108 lines) +- Line 92: Negative page produces negative offset + +### repositories/contractor_repo.go (166 lines) +- **RACE**: Line 89-101: ToggleFavorite read-then-write without transaction +- Line 91: ToggleFavorite doesn't filter is_active — can toggle deleted contractors + +### repositories/document_repo.go (201 lines) +- Line 92: LIKE wildcards in user input not escaped +- Line 12: DocumentFilter.ResidenceID field defined but never used + +### repositories/notification_repo.go (267 lines) +- **RACE**: Line 137-161: GetOrCreatePreferences race — concurrent calls both create, duplicate key error +- Line 143: Uses == instead of errors.Is for ErrRecordNotFound + +### repositories/reminder_repo.go (126 lines) +- Line 115-122: rows.Err() not checked after iteration loop + +### repositories/residence_repo.go (344 lines) +- Line 272: DeactivateShareCode error silently ignored +- Line 298-301: Count error unchecked in generateUniqueCode — potential duplicate codes +- Line 125-128: PostgreSQL-specific ON CONFLICT — fails on SQLite in tests + +### repositories/subscription_repo.go (257 lines) +- **RACE**: Line 40: GetOrCreate race condition 
(same as notification_repo) +- Line 66: GORM v1 pattern `gorm:query_option` for FOR UPDATE — may not work in GORM v2 +- Line 129: LIKE search on receipt data blobs — inefficient, no index +- Lines 40, 168, 196: Uses == instead of errors.Is + +### repositories/task_repo.go (765 lines) +- Line 707-709: DeleteCompletion ignores image deletion error +- Line 62-101: applyFilterOptions applies NO scope when no filter set — could query all tasks + +### repositories/task_template_repo.go (124 lines) +- Line 48: LIKE wildcard escape issue +- Line 79-81: Save without Omit could corrupt Category/Frequency lookup data diff --git a/audit-digest-8.md b/audit-digest-8.md new file mode 100644 index 0000000..5beb0d7 --- /dev/null +++ b/audit-digest-8.md @@ -0,0 +1,35 @@ +# Digest 8: repositories (remaining), router, services (first half) + +### repositories/user_repo.go - Standard GORM CRUD +### repositories/webhook_event_repo.go - Webhook event storage + +### router/router.go - Route registration wiring + +### services/apple_auth.go - Apple Sign In JWT validation +### services/auth_service.go - Token management, password hashing, email verification + +### services/cache_service.go - Redis caching for lookups +### services/contractor_service.go - Contractor CRUD via repository + +### services/document_service.go - Document management +### services/email_service.go - SMTP email sending + +### services/google_auth.go - Google OAuth token validation +### services/iap_validation.go - Apple/Google receipt validation + +### services/notification_service.go - Push notifications, preferences + +### services/onboarding_email_service.go (371 lines) +- **ARCHITECTURE**: Direct *gorm.DB access — bypasses repository layer entirely +- Line 43-46: HasSentEmail ignores Count error — could send duplicate emails +- Line 128-133: GetEmailStats ignores 4 Count errors +- Line 170: Raw SQL references "auth_user" table +- Line 354: Delete error silently ignored + +### services/pdf_service.go (179 lines) 
+- **BUG**: Line 131-133: Byte-level truncation of title — breaks multi-byte UTF-8 (CJK, emoji) + +### services/residence_service.go (648 lines) +- Line 155: TODO comment — subscription tier limit check commented out (free tier bypass) +- Line 447-450: Empty if block — DeactivateShareCode error completely ignored +- Line 625: Status only set for in-progress tasks; all others have empty string diff --git a/audit-digest-9.md b/audit-digest-9.md new file mode 100644 index 0000000..389c1bf --- /dev/null +++ b/audit-digest-9.md @@ -0,0 +1,49 @@ +# Digest 9: services (remaining), task package, testutil, validator, worker, pkg + +### services/storage_service.go (184 lines) +- Line 75: UUID truncated to 8 chars — increased collision risk +- **SECURITY**: Line 137-138: filepath.Abs errors ignored — path traversal check could be bypassed + +### services/subscription_service.go (659 lines) +- **PERFORMANCE**: Line 186-204: N+1 queries in getUserUsage — 3 queries per residence +- **SECURITY/BUSINESS**: Line 371: Apple validation failure grants 1-month free Pro +- **SECURITY/BUSINESS**: Line 381: Apple validation not configured grants 1-year free Pro +- **SECURITY/BUSINESS**: Line 429, 449: Same for Google — errors/misconfiguration grant free Pro +- Line 7: Uses stdlib "log" instead of zerolog + +### services/task_button_types.go (85 lines) - Clean, uses predicates correctly +### services/task_service.go (1092 lines) +- **DATA INTEGRITY**: Line 601: If task update fails after completion creation, error only logged not returned — stale NextDueDate/InProgress +- Line 735: Goroutine in QuickComplete (service method) — inconsistent with synchronous CreateCompletion +- Line 773: Unbounded goroutine creation per user for notifications +- Line 790: Fail-open email notification on error (intentional but risky) +- **SECURITY**: Line 857-862: resolveImageFilePath has NO path traversal validation + +### services/task_template_service.go (70 lines) - Errors returned raw (not wrapped with 
apperrors) +### services/user_service.go (88 lines) - Returns nil instead of empty slice (JSON null vs []) + +### task/categorization/chain.go (359 lines) - Clean chain-of-responsibility +### task/predicates/predicates.go (217 lines) +- Line 210: IsRecurring requires Frequency preloaded — returns false without it + +### task/scopes/scopes.go (270 lines) +- Line 118: ScopeOverdue doesn't exclude InProgress — differs from categorization chain + +### task/task.go (261 lines) - Clean facade re-exports + +### testutil/testutil.go (359 lines) +- Line 86: json.Marshal error ignored in MakeRequest +- Line 92: http.NewRequest error ignored + +### validator/validator.go (103 lines) - Clean + +### worker/jobs/email_jobs.go (118 lines) - Clean +### worker/jobs/handler.go (810 lines) +- Lines 95-106, 193-204: Direct DB access bypasses repository layer +- Line 627-635: Raw SQL with fmt.Sprintf (not currently user-supplied but fragile) +- Line 154, 251: O(N*M) lookup instead of map + +### worker/scheduler.go (240 lines) +- Line 200-212: Cron schedules at fixed UTC times may conflict with smart reminder system — potential duplicate notifications + +### pkg/utils/logger.go (132 lines) - Panic recovery bypasses apperrors.HTTPErrorHandler diff --git a/deploy/.gitignore b/deploy/.gitignore new file mode 100644 index 0000000..ee6276a --- /dev/null +++ b/deploy/.gitignore @@ -0,0 +1,18 @@ +# Local deploy inputs (copy from *.example files) +cluster.env +registry.env +prod.env + +# Local secret material +secrets/*.txt +secrets/*.p8 + +# Keep templates and docs tracked +!*.example +!README.md +!shit_deploy_cant_do.md +!swarm-stack.prod.yml +!scripts/ +!scripts/** +!secrets/*.example +!secrets/README.md diff --git a/deploy/README.md b/deploy/README.md new file mode 100644 index 0000000..3cdcd54 --- /dev/null +++ b/deploy/README.md @@ -0,0 +1,134 @@ +# Deploy Folder + +This folder is the full production deploy toolkit for `myCribAPI-go`. 
+ +Run deploy with: + +```bash +./.deploy_prod +``` + +The script will refuse to run until all required values are set. + +## First-Time Prerequisite: Create The Swarm Cluster + +You must do this once before `./.deploy_prod` can work. + +1. SSH to manager #1 and initialize Swarm: + +```bash +docker swarm init --advertise-addr <MANAGER_PRIVATE_IP> +``` + +2. On manager #1, get join commands: + +```bash +docker swarm join-token manager +docker swarm join-token worker +``` + +3. SSH to each additional node and run the appropriate `docker swarm join ...` command. + +4. Verify from manager #1: + +```bash +docker node ls +``` + +## Security Requirements Before Public Launch + +Use this as a mandatory checklist before you route production traffic. + +### 1) Firewall Rules (Node-Level) + +Apply firewall rules to all Swarm nodes: + +- SSH port (for example `2222/tcp`): your IP only +- `80/tcp`, `443/tcp`: Hetzner LB only (or Cloudflare IP ranges only if no LB) +- `2377/tcp`: Swarm nodes only +- `7946/tcp,udp`: Swarm nodes only +- `4789/udp`: Swarm nodes only +- Everything else: blocked + +### 2) SSH Hardening + +On each node, harden `/etc/ssh/sshd_config`: + +```text +Port 2222 +PermitRootLogin no +PasswordAuthentication no +PubkeyAuthentication yes +AllowUsers deploy +``` + +### 3) Cloudflare Origin Lockdown + +- Keep public DNS records proxied (orange cloud on). +- Point Cloudflare to LB, not node IPs. +- Do not publish Swarm node IPs in DNS. +- Enforce firewall source restrictions so public traffic cannot bypass Cloudflare/LB. + +### 4) Secrets Policy + +- Keep runtime secrets in Docker Swarm secrets only. +- Do not put production secrets in git or plain `.env` files. +- `./.deploy_prod` already creates versioned Swarm secrets from files in `deploy/secrets/`. +- Rotate secrets after incidents or credential exposure. + +### 5) Data Path Security + +- Neon/Postgres: `DB_SSLMODE=require`, strong DB password, Neon IP allowlist limited to node IPs. 
+- Backblaze B2: HTTPS only, scoped app keys (not master key), least-privilege bucket access. +- Swarm overlay: encrypted network enabled in stack (`driver_opts.encrypted: "true"`). + +### 6) Dozzle Hardening + +- Keep Dozzle private (no public DNS/ingress). +- Put auth/SSO in front (Cloudflare Access or equivalent). +- Prefer a Docker socket proxy with restricted read-only scope. + +### 7) Backup + Restore Readiness + +- Postgres PITR path tested in staging. +- Redis persistence enabled and restore path tested. +- Written runbook for restore and secret rotation. +- Named owner for incident response. + +## Files You Fill In + +Paste your values into these files: + +- `deploy/cluster.env` +- `deploy/registry.env` +- `deploy/prod.env` +- `deploy/secrets/postgres_password.txt` +- `deploy/secrets/secret_key.txt` +- `deploy/secrets/email_host_password.txt` +- `deploy/secrets/fcm_server_key.txt` +- `deploy/secrets/apns_auth_key.p8` + +If one is missing, the deploy script auto-copies it from its `.example` template and exits so you can fill it. + +## What `./.deploy_prod` Does + +1. Validates all required config files and credentials. +2. Builds and pushes `api`, `worker`, and `admin` images. +3. Uploads deploy bundle to your Swarm manager over SSH. +4. Creates versioned Docker secrets on the manager. +5. Deploys the stack with `docker stack deploy --with-registry-auth`. +6. Waits until service replicas converge. +7. Runs an HTTP health check (if `DEPLOY_HEALTHCHECK_URL` is set). + +## Useful Flags + +Environment flags: + +- `SKIP_BUILD=1 ./.deploy_prod` to deploy already-pushed images. +- `SKIP_HEALTHCHECK=1 ./.deploy_prod` to skip final URL check. +- `DEPLOY_TAG=<tag> ./.deploy_prod` to deploy a specific image tag. + +## Important + +- `deploy/shit_deploy_cant_do.md` lists the manual tasks this script cannot automate. +- Keep real credentials and secret files out of git.
diff --git a/deploy/cluster.env.example b/deploy/cluster.env.example new file mode 100644 index 0000000..a0202ef --- /dev/null +++ b/deploy/cluster.env.example @@ -0,0 +1,22 @@ +# Swarm manager connection +DEPLOY_MANAGER_HOST=CHANGEME_MANAGER_IP_OR_HOSTNAME +DEPLOY_MANAGER_USER=deploy +DEPLOY_MANAGER_SSH_PORT=22 +DEPLOY_SSH_KEY_PATH=~/.ssh/id_ed25519 + +# Stack settings +DEPLOY_STACK_NAME=casera +DEPLOY_REMOTE_DIR=/opt/casera/deploy +DEPLOY_WAIT_SECONDS=420 +DEPLOY_HEALTHCHECK_URL=https://api.casera.app/api/health/ + +# Replicas and published ports +API_REPLICAS=3 +WORKER_REPLICAS=2 +ADMIN_REPLICAS=1 +API_PORT=8000 +ADMIN_PORT=3000 +DOZZLE_PORT=9999 + +# Build behavior +PUSH_LATEST_TAG=true diff --git a/deploy/prod.env.example b/deploy/prod.env.example new file mode 100644 index 0000000..68ab21b --- /dev/null +++ b/deploy/prod.env.example @@ -0,0 +1,73 @@ +# API service settings +DEBUG=false +ALLOWED_HOSTS=api.casera.app,casera.app +CORS_ALLOWED_ORIGINS=https://casera.app,https://admin.casera.app +TIMEZONE=UTC +BASE_URL=https://casera.app +PORT=8000 + +# Admin service settings +NEXT_PUBLIC_API_URL=https://api.casera.app +ADMIN_PANEL_URL=https://admin.casera.app + +# Database (Neon recommended) +DB_HOST=CHANGEME_NEON_HOST +DB_PORT=5432 +POSTGRES_USER=CHANGEME_DB_USER +POSTGRES_DB=casera +DB_SSLMODE=require +DB_MAX_OPEN_CONNS=25 +DB_MAX_IDLE_CONNS=10 +DB_MAX_LIFETIME=600s + +# Redis (in stack defaults to redis://redis:6379/0) +REDIS_URL=redis://redis:6379/0 +REDIS_DB=0 + +# Email (password goes in deploy/secrets/email_host_password.txt) +EMAIL_HOST=smtp.gmail.com +EMAIL_PORT=587 +EMAIL_USE_TLS=true +EMAIL_HOST_USER=CHANGEME_EMAIL_USER +DEFAULT_FROM_EMAIL=Casera + +# Push notifications +# APNS private key goes in deploy/secrets/apns_auth_key.p8 +APNS_AUTH_KEY_ID=CHANGEME_APNS_KEY_ID +APNS_TEAM_ID=CHANGEME_APNS_TEAM_ID +APNS_TOPIC=com.tt.casera +APNS_USE_SANDBOX=false +APNS_PRODUCTION=true + +# Worker schedules (UTC) +TASK_REMINDER_HOUR=14 +OVERDUE_REMINDER_HOUR=15 
+DAILY_DIGEST_HOUR=3 + +# Storage +STORAGE_UPLOAD_DIR=/app/uploads +STORAGE_BASE_URL=/uploads +STORAGE_MAX_FILE_SIZE=10485760 +STORAGE_ALLOWED_TYPES=image/jpeg,image/png,image/gif,image/webp,application/pdf + +# Feature flags +FEATURE_PUSH_ENABLED=true +FEATURE_EMAIL_ENABLED=true +FEATURE_WEBHOOKS_ENABLED=true +FEATURE_ONBOARDING_EMAILS_ENABLED=true +FEATURE_PDF_REPORTS_ENABLED=true +FEATURE_WORKER_ENABLED=true + +# Optional auth/iap values +APPLE_CLIENT_ID= +APPLE_TEAM_ID= +GOOGLE_CLIENT_ID= +GOOGLE_ANDROID_CLIENT_ID= +GOOGLE_IOS_CLIENT_ID= +APPLE_IAP_KEY_ID= +APPLE_IAP_ISSUER_ID= +APPLE_IAP_BUNDLE_ID= +APPLE_IAP_KEY_PATH= +APPLE_IAP_SANDBOX=false +GOOGLE_IAP_PACKAGE_NAME= +GOOGLE_IAP_SERVICE_ACCOUNT_PATH= diff --git a/deploy/registry.env.example b/deploy/registry.env.example new file mode 100644 index 0000000..7347533 --- /dev/null +++ b/deploy/registry.env.example @@ -0,0 +1,11 @@ +# Container registry used for deploy images. +# For GHCR: +# REGISTRY=ghcr.io +# REGISTRY_NAMESPACE= +# REGISTRY_USERNAME= +# REGISTRY_TOKEN= + +REGISTRY=ghcr.io +REGISTRY_NAMESPACE=CHANGEME_NAMESPACE +REGISTRY_USERNAME=CHANGEME_USERNAME +REGISTRY_TOKEN=CHANGEME_TOKEN diff --git a/deploy/scripts/deploy_prod.sh b/deploy/scripts/deploy_prod.sh new file mode 100755 index 0000000..f862f85 --- /dev/null +++ b/deploy/scripts/deploy_prod.sh @@ -0,0 +1,397 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DEPLOY_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)" +REPO_DIR="$(cd "${DEPLOY_DIR}/.." 
&& pwd)" + +STACK_TEMPLATE="${DEPLOY_DIR}/swarm-stack.prod.yml" +CLUSTER_ENV="${DEPLOY_DIR}/cluster.env" +REGISTRY_ENV="${DEPLOY_DIR}/registry.env" +PROD_ENV="${DEPLOY_DIR}/prod.env" + +SECRET_POSTGRES="${DEPLOY_DIR}/secrets/postgres_password.txt" +SECRET_APP_KEY="${DEPLOY_DIR}/secrets/secret_key.txt" +SECRET_EMAIL_PASS="${DEPLOY_DIR}/secrets/email_host_password.txt" +SECRET_FCM_KEY="${DEPLOY_DIR}/secrets/fcm_server_key.txt" +SECRET_APNS_KEY="${DEPLOY_DIR}/secrets/apns_auth_key.p8" + +SKIP_BUILD="${SKIP_BUILD:-0}" +SKIP_HEALTHCHECK="${SKIP_HEALTHCHECK:-0}" + +log() { + printf '[deploy] %s\n' "$*" +} + +warn() { + printf '[deploy][warn] %s\n' "$*" >&2 +} + +die() { + printf '[deploy][error] %s\n' "$*" >&2 + exit 1 +} + +require_cmd() { + command -v "$1" >/dev/null 2>&1 || die "Missing required command: $1" +} + +contains_placeholder() { + local value="$1" + [[ -z "${value}" ]] && return 0 + local lowered + lowered="$(printf '%s' "${value}" | tr '[:upper:]' '[:lower:]')" + case "${lowered}" in + *changeme*|*replace_me*|*example.com*|*your-*|*todo*|*fill_me*|*paste_here*) + return 0 + ;; + *) + return 1 + ;; + esac +} + +ensure_file_from_example() { + local path="$1" + local example="${path}.example" + if [[ -f "${path}" ]]; then + return + fi + if [[ -f "${example}" ]]; then + cp "${example}" "${path}" + die "Created ${path} from template. Fill it in and rerun." 
+ fi + die "Missing required file: ${path}" +} + +require_var() { + local name="$1" + local value="${!name:-}" + [[ -n "${value}" ]] || die "Missing required value: ${name}" + if contains_placeholder "${value}"; then + die "Value still uses placeholder text: ${name}=${value}" + fi +} + +require_secret_file() { + local path="$1" + local label="$2" + ensure_file_from_example "${path}" + local contents + contents="$(tr -d '\r' < "${path}" | sed 's/[[:space:]]*$//')" + [[ -n "${contents}" ]] || die "Secret file is empty: ${path}" + if contains_placeholder "${contents}"; then + die "Secret file still has placeholder text (${label}): ${path}" + fi +} + +print_usage() { + cat <<'EOF' +Usage: + ./.deploy_prod + +Optional environment flags: + SKIP_BUILD=1 Deploy existing image tags without rebuilding/pushing. + SKIP_HEALTHCHECK=1 Skip final HTTP health check. + DEPLOY_TAG= Override image tag (default: git short sha). +EOF +} + +while (($# > 0)); do + case "$1" in + -h|--help) + print_usage + exit 0 + ;; + *) + die "Unknown argument: $1" + ;; + esac +done + +require_cmd docker +require_cmd ssh +require_cmd scp +require_cmd git +require_cmd awk +require_cmd sed +require_cmd grep +require_cmd mktemp +require_cmd date +require_cmd curl + +ensure_file_from_example "${CLUSTER_ENV}" +ensure_file_from_example "${REGISTRY_ENV}" +ensure_file_from_example "${PROD_ENV}" + +require_secret_file "${SECRET_POSTGRES}" "Postgres password" +require_secret_file "${SECRET_APP_KEY}" "SECRET_KEY" +require_secret_file "${SECRET_EMAIL_PASS}" "SMTP password" +require_secret_file "${SECRET_FCM_KEY}" "FCM server key" +require_secret_file "${SECRET_APNS_KEY}" "APNS private key" + +set -a +# shellcheck disable=SC1090 +source "${CLUSTER_ENV}" +# shellcheck disable=SC1090 +source "${REGISTRY_ENV}" +# shellcheck disable=SC1090 +source "${PROD_ENV}" +set +a + +DEPLOY_MANAGER_SSH_PORT="${DEPLOY_MANAGER_SSH_PORT:-22}" +DEPLOY_STACK_NAME="${DEPLOY_STACK_NAME:-casera}" 
+DEPLOY_REMOTE_DIR="${DEPLOY_REMOTE_DIR:-/opt/casera/deploy}" +DEPLOY_WAIT_SECONDS="${DEPLOY_WAIT_SECONDS:-420}" +DEPLOY_TAG="${DEPLOY_TAG:-$(git -C "${REPO_DIR}" rev-parse --short HEAD)}" +PUSH_LATEST_TAG="${PUSH_LATEST_TAG:-true}" + +require_var DEPLOY_MANAGER_HOST +require_var DEPLOY_MANAGER_USER +require_var DEPLOY_STACK_NAME +require_var DEPLOY_REMOTE_DIR +require_var REGISTRY +require_var REGISTRY_NAMESPACE +require_var REGISTRY_USERNAME +require_var REGISTRY_TOKEN + +require_var ALLOWED_HOSTS +require_var CORS_ALLOWED_ORIGINS +require_var BASE_URL +require_var NEXT_PUBLIC_API_URL +require_var DB_HOST +require_var DB_PORT +require_var POSTGRES_USER +require_var POSTGRES_DB +require_var DB_SSLMODE +require_var REDIS_URL +require_var EMAIL_HOST +require_var EMAIL_PORT +require_var EMAIL_HOST_USER +require_var DEFAULT_FROM_EMAIL +require_var APNS_AUTH_KEY_ID +require_var APNS_TEAM_ID +require_var APNS_TOPIC + +if [[ ! "$(tr -d '\r\n' < "${SECRET_APNS_KEY}")" =~ BEGIN[[:space:]]+PRIVATE[[:space:]]+KEY ]]; then + die "APNS key file does not look like a private key: ${SECRET_APNS_KEY}" +fi + +app_secret_len="$(tr -d '\r\n' < "${SECRET_APP_KEY}" | wc -c | tr -d ' ')" +if (( app_secret_len < 32 )); then + die "deploy/secrets/secret_key.txt must be at least 32 characters." +fi + +REGISTRY_PREFIX="${REGISTRY%/}/${REGISTRY_NAMESPACE#/}" +API_IMAGE="${REGISTRY_PREFIX}/casera-api:${DEPLOY_TAG}" +WORKER_IMAGE="${REGISTRY_PREFIX}/casera-worker:${DEPLOY_TAG}" +ADMIN_IMAGE="${REGISTRY_PREFIX}/casera-admin:${DEPLOY_TAG}" + +SSH_KEY_PATH="${DEPLOY_SSH_KEY_PATH:-}" +if [[ -n "${SSH_KEY_PATH}" ]]; then + SSH_KEY_PATH="${SSH_KEY_PATH/#\~/${HOME}}" +fi + +SSH_TARGET="${DEPLOY_MANAGER_USER}@${DEPLOY_MANAGER_HOST}" +SSH_OPTS=(-p "${DEPLOY_MANAGER_SSH_PORT}") +SCP_OPTS=(-P "${DEPLOY_MANAGER_SSH_PORT}") +if [[ -n "${SSH_KEY_PATH}" ]]; then + SSH_OPTS+=(-i "${SSH_KEY_PATH}") + SCP_OPTS+=(-i "${SSH_KEY_PATH}") +fi + +log "Validating SSH access to ${SSH_TARGET}" +if ! 
ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "echo ok" >/dev/null 2>&1; then + die "SSH connection failed to ${SSH_TARGET}" +fi + +remote_swarm_state="$(ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "docker info --format '{{.Swarm.LocalNodeState}} {{.Swarm.ControlAvailable}}'" 2>/dev/null || true)" +if [[ -z "${remote_swarm_state}" ]]; then + die "Could not read Docker Swarm state on manager. Is Docker installed/running?" +fi + +if [[ "${remote_swarm_state}" != "active true" ]]; then + die "Remote node must be a Swarm manager. Got: ${remote_swarm_state}" +fi + +if [[ "${SKIP_BUILD}" != "1" ]]; then + log "Logging in to ${REGISTRY}" + printf '%s' "${REGISTRY_TOKEN}" | docker login "${REGISTRY}" -u "${REGISTRY_USERNAME}" --password-stdin >/dev/null + + log "Building API image ${API_IMAGE}" + docker build --target api -t "${API_IMAGE}" "${REPO_DIR}" + log "Building Worker image ${WORKER_IMAGE}" + docker build --target worker -t "${WORKER_IMAGE}" "${REPO_DIR}" + log "Building Admin image ${ADMIN_IMAGE}" + docker build --target admin -t "${ADMIN_IMAGE}" "${REPO_DIR}" + + log "Pushing deploy images" + docker push "${API_IMAGE}" + docker push "${WORKER_IMAGE}" + docker push "${ADMIN_IMAGE}" + + if [[ "${PUSH_LATEST_TAG}" == "true" ]]; then + log "Updating :latest tags" + docker tag "${API_IMAGE}" "${REGISTRY_PREFIX}/casera-api:latest" + docker tag "${WORKER_IMAGE}" "${REGISTRY_PREFIX}/casera-worker:latest" + docker tag "${ADMIN_IMAGE}" "${REGISTRY_PREFIX}/casera-admin:latest" + docker push "${REGISTRY_PREFIX}/casera-api:latest" + docker push "${REGISTRY_PREFIX}/casera-worker:latest" + docker push "${REGISTRY_PREFIX}/casera-admin:latest" + fi +else + warn "SKIP_BUILD=1 set. 
Using prebuilt images for tag: ${DEPLOY_TAG}" +fi + +DEPLOY_ID_RAW="${DEPLOY_TAG}-$(date +%Y%m%d%H%M%S)" +DEPLOY_ID="$(printf '%s' "${DEPLOY_ID_RAW}" | tr -c 'a-zA-Z0-9_-' '_')" + +POSTGRES_PASSWORD_SECRET="${DEPLOY_STACK_NAME}_postgres_password_${DEPLOY_ID}" +SECRET_KEY_SECRET="${DEPLOY_STACK_NAME}_secret_key_${DEPLOY_ID}" +EMAIL_HOST_PASSWORD_SECRET="${DEPLOY_STACK_NAME}_email_host_password_${DEPLOY_ID}" +FCM_SERVER_KEY_SECRET="${DEPLOY_STACK_NAME}_fcm_server_key_${DEPLOY_ID}" +APNS_AUTH_KEY_SECRET="${DEPLOY_STACK_NAME}_apns_auth_key_${DEPLOY_ID}" + +TMP_DIR="$(mktemp -d)" +cleanup() { + rm -rf "${TMP_DIR}" +} +trap cleanup EXIT + +cp "${STACK_TEMPLATE}" "${TMP_DIR}/swarm-stack.prod.yml" +cp "${PROD_ENV}" "${TMP_DIR}/prod.env" +cp "${REGISTRY_ENV}" "${TMP_DIR}/registry.env" +mkdir -p "${TMP_DIR}/secrets" +cp "${SECRET_POSTGRES}" "${TMP_DIR}/secrets/postgres_password.txt" +cp "${SECRET_APP_KEY}" "${TMP_DIR}/secrets/secret_key.txt" +cp "${SECRET_EMAIL_PASS}" "${TMP_DIR}/secrets/email_host_password.txt" +cp "${SECRET_FCM_KEY}" "${TMP_DIR}/secrets/fcm_server_key.txt" +cp "${SECRET_APNS_KEY}" "${TMP_DIR}/secrets/apns_auth_key.p8" + +cat > "${TMP_DIR}/runtime.env" </dev/null 2>&1; then + echo "[remote] secret exists: ${name}" + else + docker secret create "${name}" "${src}" >/dev/null + echo "[remote] created secret: ${name}" + fi +} + +printf '%s' "${REGISTRY_TOKEN}" | docker login "${REGISTRY}" -u "${REGISTRY_USERNAME}" --password-stdin >/dev/null +rm -f "${REMOTE_DIR}/registry.env" + +create_secret "${POSTGRES_PASSWORD_SECRET}" "${REMOTE_DIR}/secrets/postgres_password.txt" +create_secret "${SECRET_KEY_SECRET}" "${REMOTE_DIR}/secrets/secret_key.txt" +create_secret "${EMAIL_HOST_PASSWORD_SECRET}" "${REMOTE_DIR}/secrets/email_host_password.txt" +create_secret "${FCM_SERVER_KEY_SECRET}" "${REMOTE_DIR}/secrets/fcm_server_key.txt" +create_secret "${APNS_AUTH_KEY_SECRET}" "${REMOTE_DIR}/secrets/apns_auth_key.p8" + +rm -f "${REMOTE_DIR}/secrets/postgres_password.txt" +rm -f 
"${REMOTE_DIR}/secrets/secret_key.txt" +rm -f "${REMOTE_DIR}/secrets/email_host_password.txt" +rm -f "${REMOTE_DIR}/secrets/fcm_server_key.txt" +rm -f "${REMOTE_DIR}/secrets/apns_auth_key.p8" + +set -a +# shellcheck disable=SC1090 +source "${REMOTE_DIR}/prod.env" +# shellcheck disable=SC1090 +source "${REMOTE_DIR}/runtime.env" +set +a + +docker stack deploy --with-registry-auth -c "${REMOTE_DIR}/swarm-stack.prod.yml" "${STACK_NAME}" +EOF + +log "Waiting for stack convergence (${DEPLOY_WAIT_SECONDS}s max)" +start_epoch="$(date +%s)" +while true; do + services="$(ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "docker stack services '${DEPLOY_STACK_NAME}' --format '{{.Name}} {{.Replicas}}'" 2>/dev/null || true)" + + if [[ -n "${services}" ]]; then + all_ready=1 + while IFS=' ' read -r svc replicas; do + [[ -z "${svc}" ]] && continue + current="${replicas%%/*}" + desired="${replicas##*/}" + if [[ "${desired}" == "0" ]]; then + continue + fi + if [[ "${current}" != "${desired}" ]]; then + all_ready=0 + fi + done <<< "${services}" + + if [[ "${all_ready}" -eq 1 ]]; then + break + fi + fi + + now_epoch="$(date +%s)" + elapsed=$((now_epoch - start_epoch)) + if (( elapsed >= DEPLOY_WAIT_SECONDS )); then + die "Timed out waiting for stack to converge. Check: ssh ${SSH_TARGET} docker stack services ${DEPLOY_STACK_NAME}" + fi + + sleep 10 +done + +if [[ "${SKIP_HEALTHCHECK}" != "1" && -n "${DEPLOY_HEALTHCHECK_URL:-}" ]]; then + log "Running health check: ${DEPLOY_HEALTHCHECK_URL}" + curl -fsS --max-time 20 "${DEPLOY_HEALTHCHECK_URL}" >/dev/null +fi + +log "Deploy completed successfully." 
+log "Stack: ${DEPLOY_STACK_NAME}" +log "Images:" +log " ${API_IMAGE}" +log " ${WORKER_IMAGE}" +log " ${ADMIN_IMAGE}" diff --git a/deploy/secrets/README.md b/deploy/secrets/README.md new file mode 100644 index 0000000..1cafef0 --- /dev/null +++ b/deploy/secrets/README.md @@ -0,0 +1,11 @@ +# Secrets Directory + +Create these files (copy from `.example` files): + +- `deploy/secrets/postgres_password.txt` +- `deploy/secrets/secret_key.txt` +- `deploy/secrets/email_host_password.txt` +- `deploy/secrets/fcm_server_key.txt` +- `deploy/secrets/apns_auth_key.p8` + +These are consumed by `./.deploy_prod` and converted into Docker Swarm secrets. diff --git a/deploy/secrets/apns_auth_key.p8.example b/deploy/secrets/apns_auth_key.p8.example new file mode 100644 index 0000000..de98515 --- /dev/null +++ b/deploy/secrets/apns_auth_key.p8.example @@ -0,0 +1,3 @@ +-----BEGIN PRIVATE KEY----- +CHANGEME_APNS_PRIVATE_KEY +-----END PRIVATE KEY----- diff --git a/deploy/secrets/email_host_password.txt.example b/deploy/secrets/email_host_password.txt.example new file mode 100644 index 0000000..23472ef --- /dev/null +++ b/deploy/secrets/email_host_password.txt.example @@ -0,0 +1 @@ +CHANGEME_SMTP_PASSWORD diff --git a/deploy/secrets/fcm_server_key.txt.example b/deploy/secrets/fcm_server_key.txt.example new file mode 100644 index 0000000..246e068 --- /dev/null +++ b/deploy/secrets/fcm_server_key.txt.example @@ -0,0 +1 @@ +CHANGEME_FCM_SERVER_KEY diff --git a/deploy/secrets/postgres_password.txt.example b/deploy/secrets/postgres_password.txt.example new file mode 100644 index 0000000..a16a10a --- /dev/null +++ b/deploy/secrets/postgres_password.txt.example @@ -0,0 +1 @@ +CHANGEME_DATABASE_PASSWORD diff --git a/deploy/secrets/secret_key.txt.example b/deploy/secrets/secret_key.txt.example new file mode 100644 index 0000000..88efe2c --- /dev/null +++ b/deploy/secrets/secret_key.txt.example @@ -0,0 +1 @@ +CHANGEME_SECRET_KEY_MIN_32_CHARS diff --git a/deploy/shit_deploy_cant_do.md 
b/deploy/shit_deploy_cant_do.md new file mode 100644 index 0000000..f2de212 --- /dev/null +++ b/deploy/shit_deploy_cant_do.md @@ -0,0 +1,67 @@ +# Shit Deploy Can't Do + +This is everything `./.deploy_prod` cannot safely automate for you. + +## 1. Create Infrastructure + +Step: +Create Hetzner servers, networking, and load balancer. + +Reason: +The script only deploys app workloads. It cannot create paid cloud resources without cloud API credentials and IaC wiring. + +## 2. Join Nodes To Swarm + +Step: +Run `docker swarm init` on the first manager and `docker swarm join` on other nodes. + +Reason: +Joining nodes requires one-time bootstrap tokens and host-level control. + +## 3. Configure Firewall And Origin Restrictions + +Step: +Set firewall rules so only expected ingress paths can reach your nodes. + +Reason: +Firewall policies live in provider networking controls, outside this repo. + +## 4. Configure DNS / Cloudflare + +Step: +Point DNS at LB, enable proxying, set SSL mode, and lock down origin access. + +Reason: +DNS and CDN settings are account-level operations in Cloudflare, not deploy-time app actions. + +## 5. Configure External Services + +Step: +Create and configure Neon, B2, email provider, APNS, and FCM credentials. + +Reason: +These credentials are issued in vendor dashboards and must be manually generated/rotated. + +## 6. Seed SSH Trust + +Step: +Ensure your local machine can SSH to the manager with the key in `deploy/cluster.env`. + +Reason: +The script assumes SSH already works; it cannot grant itself SSH access. + +## 7. First-Time Smoke Testing Beyond `/api/health/` + +Step: +Manually test login, push, background jobs, and admin panel flows after first deploy. + +Reason: +Automated health checks prove container readiness, not end-to-end business behavior. + +## 8. Safe Secret Garbage Collection + +Step: +Periodically remove old versioned Docker secrets that are no longer referenced. 
+ +Reason: +This deploy script creates versioned secrets for safe rollouts and does not auto-delete old ones to avoid breaking running services. diff --git a/deploy/swarm-stack.prod.yml b/deploy/swarm-stack.prod.yml new file mode 100644 index 0000000..b4b2a06 --- /dev/null +++ b/deploy/swarm-stack.prod.yml @@ -0,0 +1,288 @@ +version: "3.8" + +services: + redis: + image: redis:7-alpine + command: redis-server --appendonly yes --appendfsync everysec + volumes: + - redis_data:/data + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + deploy: + replicas: 1 + restart_policy: + condition: any + delay: 5s + placement: + max_replicas_per_node: 1 + networks: + - casera-network + + api: + image: ${API_IMAGE} + ports: + - target: 8000 + published: ${API_PORT} + protocol: tcp + mode: ingress + environment: + PORT: "8000" + DEBUG: "${DEBUG}" + ALLOWED_HOSTS: "${ALLOWED_HOSTS}" + CORS_ALLOWED_ORIGINS: "${CORS_ALLOWED_ORIGINS}" + TIMEZONE: "${TIMEZONE}" + BASE_URL: "${BASE_URL}" + ADMIN_PANEL_URL: "${ADMIN_PANEL_URL}" + + DB_HOST: "${DB_HOST}" + DB_PORT: "${DB_PORT}" + POSTGRES_USER: "${POSTGRES_USER}" + POSTGRES_DB: "${POSTGRES_DB}" + DB_SSLMODE: "${DB_SSLMODE}" + DB_MAX_OPEN_CONNS: "${DB_MAX_OPEN_CONNS}" + DB_MAX_IDLE_CONNS: "${DB_MAX_IDLE_CONNS}" + DB_MAX_LIFETIME: "${DB_MAX_LIFETIME}" + + REDIS_URL: "${REDIS_URL}" + REDIS_DB: "${REDIS_DB}" + + EMAIL_HOST: "${EMAIL_HOST}" + EMAIL_PORT: "${EMAIL_PORT}" + EMAIL_HOST_USER: "${EMAIL_HOST_USER}" + DEFAULT_FROM_EMAIL: "${DEFAULT_FROM_EMAIL}" + EMAIL_USE_TLS: "${EMAIL_USE_TLS}" + + APNS_AUTH_KEY_PATH: "/run/secrets/apns_auth_key" + APNS_AUTH_KEY_ID: "${APNS_AUTH_KEY_ID}" + APNS_TEAM_ID: "${APNS_TEAM_ID}" + APNS_TOPIC: "${APNS_TOPIC}" + APNS_USE_SANDBOX: "${APNS_USE_SANDBOX}" + APNS_PRODUCTION: "${APNS_PRODUCTION}" + + STORAGE_UPLOAD_DIR: "${STORAGE_UPLOAD_DIR}" + STORAGE_BASE_URL: "${STORAGE_BASE_URL}" + STORAGE_MAX_FILE_SIZE: "${STORAGE_MAX_FILE_SIZE}" + STORAGE_ALLOWED_TYPES: 
"${STORAGE_ALLOWED_TYPES}" + + FEATURE_PUSH_ENABLED: "${FEATURE_PUSH_ENABLED}" + FEATURE_EMAIL_ENABLED: "${FEATURE_EMAIL_ENABLED}" + FEATURE_WEBHOOKS_ENABLED: "${FEATURE_WEBHOOKS_ENABLED}" + FEATURE_ONBOARDING_EMAILS_ENABLED: "${FEATURE_ONBOARDING_EMAILS_ENABLED}" + FEATURE_PDF_REPORTS_ENABLED: "${FEATURE_PDF_REPORTS_ENABLED}" + FEATURE_WORKER_ENABLED: "${FEATURE_WORKER_ENABLED}" + + APPLE_CLIENT_ID: "${APPLE_CLIENT_ID}" + APPLE_TEAM_ID: "${APPLE_TEAM_ID}" + GOOGLE_CLIENT_ID: "${GOOGLE_CLIENT_ID}" + GOOGLE_ANDROID_CLIENT_ID: "${GOOGLE_ANDROID_CLIENT_ID}" + GOOGLE_IOS_CLIENT_ID: "${GOOGLE_IOS_CLIENT_ID}" + APPLE_IAP_KEY_PATH: "${APPLE_IAP_KEY_PATH}" + APPLE_IAP_KEY_ID: "${APPLE_IAP_KEY_ID}" + APPLE_IAP_ISSUER_ID: "${APPLE_IAP_ISSUER_ID}" + APPLE_IAP_BUNDLE_ID: "${APPLE_IAP_BUNDLE_ID}" + APPLE_IAP_SANDBOX: "${APPLE_IAP_SANDBOX}" + GOOGLE_IAP_SERVICE_ACCOUNT_PATH: "${GOOGLE_IAP_SERVICE_ACCOUNT_PATH}" + GOOGLE_IAP_PACKAGE_NAME: "${GOOGLE_IAP_PACKAGE_NAME}" + command: + - /bin/sh + - -lc + - | + set -eu + export POSTGRES_PASSWORD="$$(cat /run/secrets/postgres_password)" + export SECRET_KEY="$$(cat /run/secrets/secret_key)" + export EMAIL_HOST_PASSWORD="$$(cat /run/secrets/email_host_password)" + export FCM_SERVER_KEY="$$(cat /run/secrets/fcm_server_key)" + exec /app/api + secrets: + - source: ${POSTGRES_PASSWORD_SECRET} + target: postgres_password + - source: ${SECRET_KEY_SECRET} + target: secret_key + - source: ${EMAIL_HOST_PASSWORD_SECRET} + target: email_host_password + - source: ${FCM_SERVER_KEY_SECRET} + target: fcm_server_key + - source: ${APNS_AUTH_KEY_SECRET} + target: apns_auth_key + volumes: + - uploads:/app/uploads + healthcheck: + test: ["CMD", "curl", "-f", "http://127.0.0.1:8000/api/health/"] + interval: 30s + timeout: 10s + start_period: 15s + retries: 3 + deploy: + replicas: ${API_REPLICAS} + restart_policy: + condition: any + delay: 5s + update_config: + parallelism: 1 + delay: 10s + order: start-first + rollback_config: + parallelism: 1 + delay: 5s + 
order: stop-first + networks: + - casera-network + + admin: + image: ${ADMIN_IMAGE} + ports: + - target: 3000 + published: ${ADMIN_PORT} + protocol: tcp + mode: ingress + environment: + PORT: "3000" + HOSTNAME: "0.0.0.0" + NEXT_PUBLIC_API_URL: "${NEXT_PUBLIC_API_URL}" + healthcheck: + test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://127.0.0.1:3000/admin/"] + interval: 30s + timeout: 10s + retries: 3 + deploy: + replicas: ${ADMIN_REPLICAS} + restart_policy: + condition: any + delay: 5s + update_config: + parallelism: 1 + delay: 10s + order: start-first + rollback_config: + parallelism: 1 + delay: 5s + order: stop-first + networks: + - casera-network + + worker: + image: ${WORKER_IMAGE} + environment: + DB_HOST: "${DB_HOST}" + DB_PORT: "${DB_PORT}" + POSTGRES_USER: "${POSTGRES_USER}" + POSTGRES_DB: "${POSTGRES_DB}" + DB_SSLMODE: "${DB_SSLMODE}" + DB_MAX_OPEN_CONNS: "${DB_MAX_OPEN_CONNS}" + DB_MAX_IDLE_CONNS: "${DB_MAX_IDLE_CONNS}" + DB_MAX_LIFETIME: "${DB_MAX_LIFETIME}" + + REDIS_URL: "${REDIS_URL}" + REDIS_DB: "${REDIS_DB}" + + EMAIL_HOST: "${EMAIL_HOST}" + EMAIL_PORT: "${EMAIL_PORT}" + EMAIL_HOST_USER: "${EMAIL_HOST_USER}" + DEFAULT_FROM_EMAIL: "${DEFAULT_FROM_EMAIL}" + EMAIL_USE_TLS: "${EMAIL_USE_TLS}" + + APNS_AUTH_KEY_PATH: "/run/secrets/apns_auth_key" + APNS_AUTH_KEY_ID: "${APNS_AUTH_KEY_ID}" + APNS_TEAM_ID: "${APNS_TEAM_ID}" + APNS_TOPIC: "${APNS_TOPIC}" + APNS_USE_SANDBOX: "${APNS_USE_SANDBOX}" + APNS_PRODUCTION: "${APNS_PRODUCTION}" + + TASK_REMINDER_HOUR: "${TASK_REMINDER_HOUR}" + OVERDUE_REMINDER_HOUR: "${OVERDUE_REMINDER_HOUR}" + DAILY_DIGEST_HOUR: "${DAILY_DIGEST_HOUR}" + + FEATURE_PUSH_ENABLED: "${FEATURE_PUSH_ENABLED}" + FEATURE_EMAIL_ENABLED: "${FEATURE_EMAIL_ENABLED}" + FEATURE_WEBHOOKS_ENABLED: "${FEATURE_WEBHOOKS_ENABLED}" + FEATURE_ONBOARDING_EMAILS_ENABLED: "${FEATURE_ONBOARDING_EMAILS_ENABLED}" + FEATURE_PDF_REPORTS_ENABLED: "${FEATURE_PDF_REPORTS_ENABLED}" + FEATURE_WORKER_ENABLED: "${FEATURE_WORKER_ENABLED}" + command: + 
- /bin/sh + - -lc + - | + set -eu + export POSTGRES_PASSWORD="$$(cat /run/secrets/postgres_password)" + export SECRET_KEY="$$(cat /run/secrets/secret_key)" + export EMAIL_HOST_PASSWORD="$$(cat /run/secrets/email_host_password)" + export FCM_SERVER_KEY="$$(cat /run/secrets/fcm_server_key)" + exec /app/worker + secrets: + - source: ${POSTGRES_PASSWORD_SECRET} + target: postgres_password + - source: ${SECRET_KEY_SECRET} + target: secret_key + - source: ${EMAIL_HOST_PASSWORD_SECRET} + target: email_host_password + - source: ${FCM_SERVER_KEY_SECRET} + target: fcm_server_key + - source: ${APNS_AUTH_KEY_SECRET} + target: apns_auth_key + deploy: + replicas: ${WORKER_REPLICAS} + restart_policy: + condition: any + delay: 5s + update_config: + parallelism: 1 + delay: 10s + order: start-first + rollback_config: + parallelism: 1 + delay: 5s + order: stop-first + networks: + - casera-network + + dozzle: + image: amir20/dozzle:latest + ports: + - target: 8080 + published: ${DOZZLE_PORT} + protocol: tcp + mode: ingress + environment: + DOZZLE_NO_ANALYTICS: "true" + volumes: + - /var/run/docker.sock:/var/run/docker.sock:ro + deploy: + replicas: 1 + restart_policy: + condition: any + delay: 5s + placement: + constraints: + - node.role == manager + networks: + - casera-network + +volumes: + redis_data: + uploads: + +networks: + casera-network: + driver: overlay + driver_opts: + encrypted: "true" + +secrets: + postgres_password: + external: true + name: ${POSTGRES_PASSWORD_SECRET} + secret_key: + external: true + name: ${SECRET_KEY_SECRET} + email_host_password: + external: true + name: ${EMAIL_HOST_PASSWORD_SECRET} + fcm_server_key: + external: true + name: ${FCM_SERVER_KEY_SECRET} + apns_auth_key: + external: true + name: ${APNS_AUTH_KEY_SECRET} diff --git a/docs/AUDIT_FINDINGS.md b/docs/AUDIT_FINDINGS.md new file mode 100644 index 0000000..ddb9746 --- /dev/null +++ b/docs/AUDIT_FINDINGS.md @@ -0,0 +1,1527 @@ +# MyCrib Go Backend — Deep Audit Findings + +**Date**: 2026-03-01 
+**Scope**: All non-test `.go` files under `myCribAPI-go/` +**Agents**: 9 parallel audit agents covering security, authorization, data integrity, concurrency, performance, error handling, architecture compliance, API contracts, and cross-cutting logic + +--- + +## Summary + +| Audit Area | Critical | Bug | Race Cond. | Logic Error | Silent Failure | Warning | Fragile | Performance | Total | +|---|---|---|---|---|---|---|---|---|---| +| Security & Input Validation | 7 | 12 | — | — | — | 17 | — | — | 36 | +| Authorization & Access Control | 6 | 6 | — | — | — | 9 | — | — | 21 | +| Data Integrity (GORM) | 7 | 7 | 3 | — | — | 11 | — | — | 28 | +| Concurrency & Race Conditions | 1 | 4 | 3 | — | — | 10 | — | — | 18 | +| Performance & Query Efficiency | — | — | — | — | — | 1 | — | 19 | 20 | +| Error Handling & Panic Safety | 17 | 10 | — | — | 9 | 12 | — | — | 48 | +| Architecture Compliance | — | 12 | — | — | — | 11 | — | — | 23 | +| API Contract & Validation | — | 19 | — | — | — | 30 | — | — | 49 | +| Cross-Cutting Logic | 5 | 5 | 3 | 3 | 1 | — | 4 | — | 21 | +| **Total (raw)** | **43** | **75** | **9** | **3** | **10** | **101** | **4** | **19** | **264** | + +--- + +## Audit 1: Security & Input Validation + +### SEC-01 | CRITICAL | Apple JWS payload decoded without signature verification +- **File**: `internal/handlers/subscription_webhook_handler.go:190-192` +- **What**: `decodeAppleSignedPayload` splits the JWS token and base64-decodes the payload directly. The comment on line 190 explicitly says "we're trusting Apple's signature for now." `VerifyAppleSignature` exists but is never called from the handler flow. +- **Impact**: An attacker can craft a fake Apple webhook with arbitrary notification data (subscribe/renew/refund), granting or revoking Pro subscriptions for any user who has ever made a purchase. 
+ +### SEC-02 | CRITICAL | Google Pub/Sub webhook always returns true (unauthenticated) +- **File**: `internal/handlers/subscription_webhook_handler.go:787-793` +- **What**: `VerifyGooglePubSubToken` unconditionally returns `true`. The Google webhook endpoint `HandleGoogleWebhook` never calls this function at all. Any HTTP client can POST arbitrary subscription events. +- **Impact**: An attacker can forge Google subscription events to grant themselves Pro access, cancel other users' subscriptions, or trigger arbitrary downgrades. + +### SEC-03 | CRITICAL | GoAdmin password reset to "admin" on every migration +- **File**: `internal/database/database.go:372-382` +- **What**: Line 373 does `INSERT ON CONFLICT DO NOTHING`, but lines 379-382 unconditionally `UPDATE goadmin_users SET password = <hash> WHERE username = 'admin'`. Every time migrations run, the GoAdmin password is reset to "admin". +- **Impact**: The admin panel is permanently accessible with `admin/admin` credentials. Even if the password is changed, the next deploy resets it. + +### SEC-04 | CRITICAL | Next.js admin password reset to "admin123" on every migration +- **File**: `internal/database/database.go:447-463` +- **What**: Lines 458-463 unconditionally update the admin@mycrib.com password to the bcrypt hash of "admin123" on every migration. The log message on line 463 even says "Updated admin@mycrib.com password to admin123." +- **Impact**: The admin API is permanently accessible with hardcoded credentials. Any attacker who discovers the endpoint can access full admin functionality. + +### SEC-05 | CRITICAL | SQL injection via SortBy in all admin list endpoints +- **File**: `internal/admin/handlers/admin_user_handler.go:86-88` +- **What**: `sortBy = filters.SortBy` is concatenated directly into `query.Order(sortBy + " " + filters.GetSortDir())` without any allowlist validation.
This pattern is repeated across every admin list handler (admin_user, auth_token, completion, contractor, document, device, notification, etc.). +- **Impact**: An authenticated admin can inject arbitrary SQL via the `sort_by` query parameter, e.g., `sort_by=created_at; DROP TABLE auth_user; --`, achieving full database read/write. + +### SEC-06 | CRITICAL | Apple validation failure grants 1 month free Pro +- **File**: `internal/services/subscription_service.go:371` +- **What**: When Apple receipt validation returns a non-fatal error (network timeout, transient failure), the code falls through to `expiresAt = time.Now().UTC().AddDate(0, 1, 0)` -- granting 1 month of Pro. Line 381 grants 1 year when Apple IAP is not configured. +- **Impact**: An attacker can send any invalid receipt data, trigger a validation error, and receive free Pro access. Repeating monthly yields indefinite free Pro. + +### SEC-07 | CRITICAL | Google validation failure grants 1 month free Pro; not configured grants 1 year +- **File**: `internal/services/subscription_service.go:429-449` +- **What**: Same pattern as Apple. Line 430: non-fatal Google validation error gives 1-month fallback. Line 449: unconfigured Google client gives 1 year. An attacker sending a garbage `purchaseToken` with `platform=android` triggers the fallback. +- **Impact**: Free Pro subscription for any user by sending invalid purchase data. + +### SEC-08 | BUG | Token slice panic on short tokens +- **File**: `internal/middleware/auth.go:66` +- **What**: `token[:8]+"..."` in the debug log message will panic with an index-out-of-range if the token string is fewer than 8 characters. There is no length check before slicing. +- **Impact**: A malformed Authorization header with a valid scheme but very short token causes a server panic (500) and potential DoS. 
+ +### SEC-09 | BUG | Path traversal in resolveFilePath +- **File**: `internal/handlers/media_handler.go:156-171` +- **What**: `resolveFilePath` uses `strings.TrimPrefix` followed by `filepath.Join(uploadDir, relativePath)`. If the stored URL contains `../` sequences (e.g., `/uploads/../../../etc/passwd`), `filepath.Join` resolves them and `c.File()` serves the resulting path. There is no `filepath.Abs` containment check (unlike `storage_service.Delete`). +- **Impact**: If an attacker can control a stored URL (e.g., via SQL injection or a compromised document record), they can read arbitrary files from the server filesystem. + +### SEC-10 | BUG | Path traversal in resolveImageFilePath (task service) +- **File**: `internal/services/task_service.go:850-862` +- **What**: Identical path traversal vulnerability to media_handler's `resolveFilePath`. No validation that the resolved path stays within the upload directory. +- **Impact**: Arbitrary file read if stored URLs are manipulated. + +### SEC-11 | BUG | Path traversal check bypassed when filepath.Abs errors +- **File**: `internal/services/storage_service.go:137-138` +- **What**: `absUploadDir, _ := filepath.Abs(s.cfg.UploadDir)` and `absFilePath, _ := filepath.Abs(fullPath)` both silently discard errors. If `filepath.Abs` fails, both return empty strings, `strings.HasPrefix("", "")` is true, and the path traversal check passes. +- **Impact**: Under unusual filesystem conditions, the path containment check becomes ineffective, allowing deletion of arbitrary files. + +### SEC-12 | BUG | Nil pointer panic after WebSocket upgrade failure +- **File**: `internal/monitoring/handler.go:116-119` +- **What**: When `upgrader.Upgrade` fails, `conn` is nil but execution continues to `defer conn.Close()`, causing a nil pointer panic. +- **Impact**: Server panic on any failed WebSocket upgrade attempt. 
+ +### SEC-13 | BUG | Missing return after context cancellation causes goroutine spin +- **File**: `internal/monitoring/handler.go:177` +- **What**: The `case <-ctx.Done():` block has no `return` statement, so after the context is cancelled, the `for` loop immediately re-enters the `select` and spins indefinitely. +- **Impact**: 100% CPU usage on one goroutine for every WebSocket connection that disconnects. The goroutine never exits, leaking resources. + +### SEC-14 | BUG | Nil pointer dereference when cache is nil +- **File**: `internal/admin/handlers/lookup_handler.go:30-32` +- **What**: `cache := services.GetCache(); if cache == nil { }` has an empty body, then immediately calls `cache.CacheCategories()`. If cache is nil, this panics. Same pattern at lines 50-52 for priorities. +- **Impact**: Server panic in admin lookup handlers when Redis is unavailable. + +### SEC-15 | BUG | Panic on short reset tokens +- **File**: `internal/admin/handlers/password_reset_code_handler.go:85` +- **What**: `code.ResetToken[:8] + "..." + code.ResetToken[len-4:]` panics with index out of range if the reset token is fewer than 8 characters. +- **Impact**: Admin panel crash when viewing short reset codes. + +### SEC-16 | BUG | Race condition in Apple legacy receipt validation +- **File**: `internal/services/iap_validation.go:381-386` +- **What**: `c.sandbox = true` mutates the struct field, calls `validateLegacyReceipt`, then `c.sandbox = false`. If two concurrent requests hit this path, one may read a sandbox flag set by the other, causing production receipts to validate against sandbox or vice versa. +- **Impact**: Intermittent validation failures or sandbox receipts being accepted in production. + +### SEC-17 | BUG | Index out of bounds in FCM response parsing +- **File**: `internal/push/fcm.go:119` +- **What**: `for i, result := range fcmResp.Results` iterates results and accesses `tokens[i]`. 
If FCM returns fewer results than tokens sent, this is safe, but if the response is malformed with more results than tokens, it panics. +- **Impact**: Server panic on malformed FCM responses. + +### SEC-18 | BUG | DeleteFile endpoint allows deleting any file without ownership check +- **File**: `internal/handlers/upload_handler.go:78-91` +- **What**: The `DELETE /api/uploads/` endpoint accepts a `url` field in the request body and passes it directly to `storageService.Delete`. There is no check that the authenticated user owns or has access to the resource associated with that file. +- **Impact**: Any authenticated user can delete other users' uploaded files (images, documents, completion photos). + +### SEC-19 | BUG | Unchecked type assertion throughout handlers +- **File**: `internal/handlers/contractor_handler.go:28` (and 60+ other locations) +- **What**: `c.Get(middleware.AuthUserKey).(*models.User)` is used without the comma-ok pattern across contractor_handler (7 instances), document_handler (10 instances), residence_handler (14 instances), task_handler (18 instances), and media_handler (3 instances). If the auth middleware is misconfigured or bypassed, this panics. +- **Impact**: Server panic (500) on any request where the context value is not the expected type. + +### SEC-20 | WARNING | Admin JWT accepted via query parameter +- **File**: `internal/middleware/admin_auth.go:49-50` +- **What**: `tokenString = c.QueryParam("token")` allows passing the admin JWT as a URL query parameter. URL parameters are logged by web servers, proxies, and browser history. +- **Impact**: Admin tokens leaked into access logs, proxy logs, and referrer headers. A compromised log grants full admin access. + +### SEC-21 | WARNING | Hardcoded debug secret key +- **File**: `internal/config/config.go:339` +- **What**: When `SECRET_KEY` is not set and `DEBUG=true`, the secret key defaults to `"change-me-in-production-secret-key-12345"`. 
If debug mode is accidentally enabled in production, all JWT signatures use this predictable key. +- **Impact**: Trivially forgeable admin JWTs if debug mode leaks to production. + +### SEC-22 | WARNING | XSS in admin email HTML template +- **File**: `internal/admin/handlers/notification_handler.go:351-363` +- **What**: `req.Subject` and `req.Body` are concatenated directly into HTML via string concatenation (`+ req.Subject +`), with no HTML escaping. If the admin enters ``, + expected: `<script>alert("xss")</script>`, + }, + { + name: "img onerror payload", + input: ``, + expected: `<img src=x onerror=alert(1)>`, + }, + { + name: "ampersand and angle brackets", + input: `Tom & Jerry `, + expected: `Tom & Jerry <bros>`, + }, + { + name: "plain text passes through", + input: "Hello World", + expected: "Hello World", + }, + { + name: "single quotes", + input: `It's a 'test'`, + expected: `It's a 'test'`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + escaped := html.EscapeString(tt.input) + assert.Equal(t, tt.expected, escaped) + // Verify the escaped output does NOT contain raw angle brackets from the input + if tt.input != tt.expected { + assert.NotContains(t, escaped, ""}, + {"spaces", "abc def"}, + {"semicolons", "abc;def"}, + {"unicode", "abc\u200bdef"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := echo.New() + mw := RequestIDMiddleware() + handler := mw(func(c echo.Context) error { + return c.String(http.StatusOK, GetRequestID(c)) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(HeaderXRequestID, tt.inputID) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + assert.NoError(t, err) + // The malicious ID should be replaced with a generated UUID + assert.NotEqual(t, tt.inputID, rec.Body.String(), + "control chars should be rejected, got original ID back") + assert.Len(t, rec.Body.String(), 36, "should be a generated UUID") + }) + } 
+} + +func TestRequestID_TooLong_Sanitized(t *testing.T) { + // SEC-29: X-Request-ID longer than 64 chars should be rejected. + e := echo.New() + mw := RequestIDMiddleware() + handler := mw(func(c echo.Context) error { + return c.String(http.StatusOK, GetRequestID(c)) + }) + + longID := strings.Repeat("a", 65) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(HeaderXRequestID, longID) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + assert.NoError(t, err) + assert.NotEqual(t, longID, rec.Body.String(), "overly long ID should be replaced") + assert.Len(t, rec.Body.String(), 36, "should be a generated UUID") +} + +func TestRequestID_MaxLength_Accepted(t *testing.T) { + // Exactly 64 chars of valid characters should be accepted + e := echo.New() + mw := RequestIDMiddleware() + handler := mw(func(c echo.Context) error { + return c.String(http.StatusOK, GetRequestID(c)) + }) + + maxID := strings.Repeat("a", 64) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(HeaderXRequestID, maxID) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, maxID, rec.Body.String(), "64-char valid ID should be accepted") +} diff --git a/internal/middleware/sanitize.go b/internal/middleware/sanitize.go new file mode 100644 index 0000000..5451995 --- /dev/null +++ b/internal/middleware/sanitize.go @@ -0,0 +1,19 @@ +package middleware + +import "strings" + +// SanitizeSortColumn validates a user-supplied sort column against an allowlist. +// Returns defaultCol if the input is empty or not in the allowlist. +// This prevents SQL injection via ORDER BY clauses. 
+func SanitizeSortColumn(input string, allowedCols []string, defaultCol string) string { + input = strings.TrimSpace(input) + if input == "" { + return defaultCol + } + for _, col := range allowedCols { + if strings.EqualFold(input, col) { + return col + } + } + return defaultCol +} diff --git a/internal/middleware/sanitize_test.go b/internal/middleware/sanitize_test.go new file mode 100644 index 0000000..42e8f73 --- /dev/null +++ b/internal/middleware/sanitize_test.go @@ -0,0 +1,59 @@ +package middleware + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSanitizeSortColumn_AllowedColumn_Passes(t *testing.T) { + allowed := []string{"created_at", "updated_at", "name"} + result := SanitizeSortColumn("created_at", allowed, "created_at") + assert.Equal(t, "created_at", result) +} + +func TestSanitizeSortColumn_CaseInsensitive(t *testing.T) { + allowed := []string{"created_at", "updated_at", "name"} + result := SanitizeSortColumn("Created_At", allowed, "created_at") + assert.Equal(t, "created_at", result) +} + +func TestSanitizeSortColumn_SQLInjection_ReturnsDefault(t *testing.T) { + allowed := []string{"created_at", "updated_at", "name"} + + tests := []struct { + name string + input string + }{ + {"drop table", "created_at; DROP TABLE auth_user; --"}, + {"union select", "name UNION SELECT * FROM auth_user"}, + {"or 1=1", "name OR 1=1"}, + {"semicolon", "created_at;"}, + {"subquery", "(SELECT password FROM auth_user LIMIT 1)"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeSortColumn(tt.input, allowed, "created_at") + assert.Equal(t, "created_at", result, "SQL injection attempt should return default") + }) + } +} + +func TestSanitizeSortColumn_Empty_ReturnsDefault(t *testing.T) { + allowed := []string{"created_at", "updated_at", "name"} + result := SanitizeSortColumn("", allowed, "created_at") + assert.Equal(t, "created_at", result) +} + +func TestSanitizeSortColumn_Whitespace_ReturnsDefault(t 
*testing.T) { + allowed := []string{"created_at", "updated_at", "name"} + result := SanitizeSortColumn(" ", allowed, "created_at") + assert.Equal(t, "created_at", result) +} + +func TestSanitizeSortColumn_UnknownColumn_ReturnsDefault(t *testing.T) { + allowed := []string{"created_at", "updated_at", "name"} + result := SanitizeSortColumn("nonexistent_column", allowed, "created_at") + assert.Equal(t, "created_at", result) +} diff --git a/internal/middleware/timezone.go b/internal/middleware/timezone.go index 8db7f87..1bd9d52 100644 --- a/internal/middleware/timezone.go +++ b/internal/middleware/timezone.go @@ -79,22 +79,30 @@ func parseTimezone(tz string) *time.Location { } // GetUserTimezone retrieves the user's timezone from the Echo context. -// Returns UTC if not set. +// Returns UTC if not set or if the stored value is not a *time.Location. func GetUserTimezone(c echo.Context) *time.Location { - loc := c.Get(TimezoneKey) - if loc == nil { + val := c.Get(TimezoneKey) + if val == nil { return time.UTC } - return loc.(*time.Location) + loc, ok := val.(*time.Location) + if !ok { + return time.UTC + } + return loc } // GetUserNow retrieves the timezone-aware "now" time from the Echo context. // This represents the start of the current day in the user's timezone. -// Returns time.Now().UTC() if not set. +// Returns time.Now().UTC() if not set or if the stored value is not a time.Time. 
func GetUserNow(c echo.Context) time.Time { - now := c.Get(UserNowKey) - if now == nil { + val := c.Get(UserNowKey) + if val == nil { return time.Now().UTC() } - return now.(time.Time) + now, ok := val.(time.Time) + if !ok { + return time.Now().UTC() + } + return now } diff --git a/internal/monitoring/collector.go b/internal/monitoring/collector.go index bf2e70d..b8e750a 100644 --- a/internal/monitoring/collector.go +++ b/internal/monitoring/collector.go @@ -52,14 +52,18 @@ func (c *Collector) Collect() SystemStats { // CPU stats c.collectCPU(&stats) + // Read Go runtime memory stats once (used by both memory and runtime collectors) + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + // Memory stats (system + Go runtime) - c.collectMemory(&stats) + c.collectMemory(&stats, &memStats) // Disk stats c.collectDisk(&stats) // Go runtime stats - c.collectRuntime(&stats) + c.collectRuntime(&stats, &memStats) // HTTP stats (API only) if c.httpCollector != nil { @@ -77,9 +81,9 @@ func (c *Collector) Collect() SystemStats { } func (c *Collector) collectCPU(stats *SystemStats) { - // Get CPU usage percentage (blocks for 1 second to get accurate sample) - // Shorter intervals can give inaccurate readings - if cpuPercent, err := cpu.Percent(time.Second, false); err == nil && len(cpuPercent) > 0 { + // Get CPU usage percentage (blocks for 200ms to sample) + // This is called periodically, so a shorter window is acceptable + if cpuPercent, err := cpu.Percent(200*time.Millisecond, false); err == nil && len(cpuPercent) > 0 { stats.CPU.UsagePercent = cpuPercent[0] } @@ -93,7 +97,7 @@ func (c *Collector) collectCPU(stats *SystemStats) { } } -func (c *Collector) collectMemory(stats *SystemStats) { +func (c *Collector) collectMemory(stats *SystemStats, m *runtime.MemStats) { // System memory if vmem, err := mem.VirtualMemory(); err == nil { stats.Memory.UsedBytes = vmem.Used @@ -101,9 +105,7 @@ func (c *Collector) collectMemory(stats *SystemStats) { 
stats.Memory.UsagePercent = vmem.UsedPercent } - // Go runtime memory - var m runtime.MemStats - runtime.ReadMemStats(&m) + // Go runtime memory (reuses pre-read MemStats) stats.Memory.HeapAlloc = m.HeapAlloc stats.Memory.HeapSys = m.HeapSys stats.Memory.HeapInuse = m.HeapInuse @@ -119,10 +121,7 @@ func (c *Collector) collectDisk(stats *SystemStats) { } } -func (c *Collector) collectRuntime(stats *SystemStats) { - var m runtime.MemStats - runtime.ReadMemStats(&m) - +func (c *Collector) collectRuntime(stats *SystemStats, m *runtime.MemStats) { stats.Runtime.Goroutines = runtime.NumGoroutine() stats.Runtime.NumGC = m.NumGC if m.NumGC > 0 { diff --git a/internal/monitoring/handler.go b/internal/monitoring/handler.go index fd1fbd4..a028e1b 100644 --- a/internal/monitoring/handler.go +++ b/internal/monitoring/handler.go @@ -17,8 +17,13 @@ var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { - // Allow connections from admin panel - return true + origin := r.Header.Get("Origin") + if origin == "" { + // Same-origin requests may omit the Origin header + return true + } + // Allow if origin matches the request host + return strings.HasPrefix(origin, "https://"+r.Host) || strings.HasPrefix(origin, "http://"+r.Host) }, } @@ -116,6 +121,7 @@ func (h *Handler) WebSocket(c echo.Context) error { conn, err := upgrader.Upgrade(c.Response().Writer, c.Request(), nil) if err != nil { log.Error().Err(err).Msg("Failed to upgrade WebSocket connection") + return err } defer conn.Close() @@ -174,6 +180,7 @@ func (h *Handler) WebSocket(c echo.Context) error { h.sendStats(conn, &wsMu) case <-ctx.Done(): + return nil } } } diff --git a/internal/monitoring/service.go b/internal/monitoring/service.go index c3e7cc7..e0d360d 100644 --- a/internal/monitoring/service.go +++ b/internal/monitoring/service.go @@ -108,6 +108,10 @@ func (s *Service) Stop() { close(s.settingsStopCh) s.collector.Stop() + + // Flush and close the log 
writer's background goroutine + s.logWriter.Close() + log.Info().Str("process", s.process).Msg("Monitoring service stopped") } diff --git a/internal/monitoring/writer.go b/internal/monitoring/writer.go index b61fb7a..30e8f50 100644 --- a/internal/monitoring/writer.go +++ b/internal/monitoring/writer.go @@ -8,23 +8,56 @@ import ( "github.com/google/uuid" ) -// RedisLogWriter implements io.Writer to capture zerolog output to Redis +const ( + // writerChannelSize is the buffer size for the async log write channel. + // Entries beyond this limit are dropped to prevent unbounded memory growth. + writerChannelSize = 256 +) + +// RedisLogWriter implements io.Writer to capture zerolog output to Redis. +// It uses a single background goroutine with a buffered channel instead of +// spawning a new goroutine per log line, preventing unbounded goroutine growth. type RedisLogWriter struct { buffer *LogBuffer process string enabled atomic.Bool + ch chan LogEntry + done chan struct{} } -// NewRedisLogWriter creates a new writer that captures logs to Redis +// NewRedisLogWriter creates a new writer that captures logs to Redis. +// It starts a single background goroutine that drains the buffered channel. func NewRedisLogWriter(buffer *LogBuffer, process string) *RedisLogWriter { w := &RedisLogWriter{ buffer: buffer, process: process, + ch: make(chan LogEntry, writerChannelSize), + done: make(chan struct{}), } w.enabled.Store(true) // enabled by default + + // Single background goroutine drains the channel + go w.drainLoop() + return w } +// drainLoop reads entries from the buffered channel and pushes them to Redis. +// It runs in a single goroutine for the lifetime of the writer. +func (w *RedisLogWriter) drainLoop() { + defer close(w.done) + for entry := range w.ch { + _ = w.buffer.Push(entry) // Ignore errors to avoid blocking log output + } +} + +// Close shuts down the background goroutine. 
It should be called during +// graceful shutdown to ensure all buffered entries are flushed. +func (w *RedisLogWriter) Close() { + close(w.ch) + <-w.done // Wait for drain to finish +} + // SetEnabled enables or disables log capture to Redis func (w *RedisLogWriter) SetEnabled(enabled bool) { w.enabled.Store(enabled) @@ -35,8 +68,10 @@ func (w *RedisLogWriter) IsEnabled() bool { return w.enabled.Load() } -// Write implements io.Writer interface -// It parses zerolog JSON output and writes to Redis asynchronously +// Write implements io.Writer interface. +// It parses zerolog JSON output and sends it to the buffered channel for +// async Redis writes. If the channel is full, the entry is dropped to +// avoid blocking the caller (back-pressure shedding). func (w *RedisLogWriter) Write(p []byte) (n int, err error) { // Skip if monitoring is disabled if !w.enabled.Load() { @@ -86,10 +121,14 @@ func (w *RedisLogWriter) Write(p []byte) (n int, err error) { } } - // Write to Redis asynchronously to avoid blocking - go func() { - _ = w.buffer.Push(entry) // Ignore errors to avoid blocking log output - }() + // Non-blocking send: drop entries if channel is full rather than + // spawning unbounded goroutines or blocking the logger + select { + case w.ch <- entry: + // Sent successfully + default: + // Channel full — drop this entry to avoid back-pressure on the logger + } return len(p), nil } diff --git a/internal/push/fcm.go b/internal/push/fcm.go index fcd4d32..3f324fa 100644 --- a/internal/push/fcm.go +++ b/internal/push/fcm.go @@ -117,6 +117,9 @@ func (c *FCMClient) Send(ctx context.Context, tokens []string, title, message st // Log individual results for i, result := range fcmResp.Results { + if i >= len(tokens) { + break + } if result.Error != "" { log.Error(). Str("token", truncateToken(tokens[i])). 
diff --git a/internal/push/fcm_test.go b/internal/push/fcm_test.go new file mode 100644 index 0000000..ff4bf95 --- /dev/null +++ b/internal/push/fcm_test.go @@ -0,0 +1,186 @@ +package push + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestFCMClient creates an FCMClient pointing at the given test server URL. +func newTestFCMClient(serverURL string) *FCMClient { + return &FCMClient{ + serverKey: "test-server-key", + httpClient: http.DefaultClient, + } +} + +// serveFCMResponse creates an httptest.Server that returns the given FCMResponse as JSON. +func serveFCMResponse(t *testing.T, resp FCMResponse) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + err := json.NewEncoder(w).Encode(resp) + require.NoError(t, err) + })) +} + +// sendWithEndpoint is a helper that sends an FCM notification using a custom endpoint +// (the test server) instead of the real FCM endpoint. This avoids modifying the +// production code to be testable and instead temporarily overrides the client's HTTP +// transport to redirect requests to our test server. +func sendWithEndpoint(client *FCMClient, server *httptest.Server, ctx context.Context, tokens []string, title, message string, data map[string]string) error { + // Override the HTTP client to redirect all requests to the test server + client.httpClient = server.Client() + + // We need to intercept the request and redirect it to our test server. + // Use a custom RoundTripper that rewrites the URL. 
+ originalTransport := server.Client().Transport + client.httpClient.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) { + // Rewrite the URL to point to the test server + req.URL.Scheme = "http" + req.URL.Host = server.Listener.Addr().String() + if originalTransport != nil { + return originalTransport.RoundTrip(req) + } + return http.DefaultTransport.RoundTrip(req) + }) + + return client.Send(ctx, tokens, title, message, data) +} + +// roundTripFunc is a function that implements http.RoundTripper. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestFCMSend_MoreResultsThanTokens_NoPanic(t *testing.T) { + // FCM returns 5 results but we only sent 2 tokens. + // Before the bounds check fix, this would panic with index out of range. + fcmResp := FCMResponse{ + MulticastID: 12345, + Success: 2, + Failure: 3, + Results: []FCMResult{ + {MessageID: "msg1"}, + {MessageID: "msg2"}, + {Error: "InvalidRegistration"}, + {Error: "NotRegistered"}, + {Error: "InvalidRegistration"}, + }, + } + + server := serveFCMResponse(t, fcmResp) + defer server.Close() + + client := newTestFCMClient(server.URL) + tokens := []string{"token-aaa-111", "token-bbb-222"} + + // This must not panic + err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil) + assert.NoError(t, err) +} + +func TestFCMSend_FewerResultsThanTokens_NoPanic(t *testing.T) { + // FCM returns fewer results than tokens we sent. + // This is also a malformed response but should not panic. 
+ fcmResp := FCMResponse{ + MulticastID: 12345, + Success: 1, + Failure: 0, + Results: []FCMResult{ + {MessageID: "msg1"}, + }, + } + + server := serveFCMResponse(t, fcmResp) + defer server.Close() + + client := newTestFCMClient(server.URL) + tokens := []string{"token-aaa-111", "token-bbb-222", "token-ccc-333"} + + err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil) + assert.NoError(t, err) +} + +func TestFCMSend_EmptyResponse_NoPanic(t *testing.T) { + // FCM returns an empty Results slice. + fcmResp := FCMResponse{ + MulticastID: 12345, + Success: 0, + Failure: 0, + Results: []FCMResult{}, + } + + server := serveFCMResponse(t, fcmResp) + defer server.Close() + + client := newTestFCMClient(server.URL) + tokens := []string{"token-aaa-111"} + + err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil) + // No panic expected. The function returns nil because fcmResp.Success == 0 + // and fcmResp.Failure == 0 (the "all failed" check requires Failure > 0). + assert.NoError(t, err) +} + +func TestFCMSend_NilResultsSlice_NoPanic(t *testing.T) { + // FCM returns a response with nil Results (e.g., malformed JSON). + fcmResp := FCMResponse{ + MulticastID: 12345, + Success: 0, + Failure: 1, + } + + server := serveFCMResponse(t, fcmResp) + defer server.Close() + + client := newTestFCMClient(server.URL) + tokens := []string{"token-aaa-111"} + + err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil) + // Should return error because Success == 0 and Failure > 0 + assert.Error(t, err) + assert.Contains(t, err.Error(), "all FCM notifications failed") +} + +func TestFCMSend_EmptyTokens_ReturnsNil(t *testing.T) { + // Verify the early return for empty tokens. 
+ client := &FCMClient{ + serverKey: "test-key", + httpClient: http.DefaultClient, + } + + err := client.Send(context.Background(), []string{}, "Test", "Body", nil) + assert.NoError(t, err) +} + +func TestFCMSend_ResultsWithErrorsMatchTokens(t *testing.T) { + // Normal case: results count matches tokens count, all with errors. + fcmResp := FCMResponse{ + MulticastID: 12345, + Success: 0, + Failure: 2, + Results: []FCMResult{ + {Error: "InvalidRegistration"}, + {Error: "NotRegistered"}, + }, + } + + server := serveFCMResponse(t, fcmResp) + defer server.Close() + + client := newTestFCMClient(server.URL) + tokens := []string{"token-aaa-111", "token-bbb-222"} + + err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "all FCM notifications failed") +} diff --git a/internal/repositories/contractor_repo.go b/internal/repositories/contractor_repo.go index fccc6f1..abc61de 100644 --- a/internal/repositories/contractor_repo.go +++ b/internal/repositories/contractor_repo.go @@ -63,7 +63,7 @@ func (r *ContractorRepository) FindByUser(userID uint, residenceIDs []uint) ([]m query = query.Where("residence_id IS NULL AND created_by_id = ?", userID) } - err := query.Order("is_favorite DESC, name ASC").Find(&contractors).Error + err := query.Order("is_favorite DESC, name ASC").Limit(500).Find(&contractors).Error return contractors, err } @@ -85,18 +85,31 @@ func (r *ContractorRepository) Delete(id uint) error { Update("is_active", false).Error } -// ToggleFavorite toggles the favorite status of a contractor +// ToggleFavorite toggles the favorite status of a contractor atomically. +// Uses a single UPDATE with NOT to avoid read-then-write race conditions. +// Only toggles active contractors to prevent toggling soft-deleted records. 
func (r *ContractorRepository) ToggleFavorite(id uint) (bool, error) { - var contractor models.Contractor - if err := r.db.First(&contractor, id).Error; err != nil { - return false, err - } - - newStatus := !contractor.IsFavorite - err := r.db.Model(&models.Contractor{}). - Where("id = ?", id). - Update("is_favorite", newStatus).Error + var newStatus bool + err := r.db.Transaction(func(tx *gorm.DB) error { + // Atomic toggle: SET is_favorite = NOT is_favorite for active contractors only + result := tx.Model(&models.Contractor{}). + Where("id = ? AND is_active = ?", id, true). + Update("is_favorite", gorm.Expr("NOT is_favorite")) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return gorm.ErrRecordNotFound + } + // Read back the new value within the same transaction + var contractor models.Contractor + if err := tx.Select("is_favorite").First(&contractor, id).Error; err != nil { + return err + } + newStatus = contractor.IsFavorite + return nil + }) return newStatus, err } @@ -145,6 +158,19 @@ func (r *ContractorRepository) CountByResidence(residenceID uint) (int64, error) return count, err } +// CountByResidenceIDs counts all active contractors across multiple residences in a single query. +// Returns the total count of active contractors for the given residence IDs. +func (r *ContractorRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) { + if len(residenceIDs) == 0 { + return 0, nil + } + var count int64 + err := r.db.Model(&models.Contractor{}). + Where("residence_id IN ? AND is_active = ?", residenceIDs, true). 
+ Count(&count).Error + return count, err +} + // === Specialty Operations === // GetAllSpecialties returns all contractor specialties diff --git a/internal/repositories/contractor_repo_test.go b/internal/repositories/contractor_repo_test.go new file mode 100644 index 0000000..99ce63c --- /dev/null +++ b/internal/repositories/contractor_repo_test.go @@ -0,0 +1,96 @@ +package repositories + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/treytartt/casera-api/internal/models" + "github.com/treytartt/casera-api/internal/testutil" +) + +func TestToggleFavorite_Active_Toggles(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewContractorRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") + contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor") + + // Initially is_favorite is false + assert.False(t, contractor.IsFavorite, "contractor should start as not favorite") + + // First toggle: false -> true + newStatus, err := repo.ToggleFavorite(contractor.ID) + require.NoError(t, err) + assert.True(t, newStatus, "first toggle should set favorite to true") + + // Verify in database + var found models.Contractor + err = db.First(&found, contractor.ID).Error + require.NoError(t, err) + assert.True(t, found.IsFavorite, "database should reflect favorite = true") + + // Second toggle: true -> false + newStatus, err = repo.ToggleFavorite(contractor.ID) + require.NoError(t, err) + assert.False(t, newStatus, "second toggle should set favorite to false") + + // Verify in database + err = db.First(&found, contractor.ID).Error + require.NoError(t, err) + assert.False(t, found.IsFavorite, "database should reflect favorite = false") +} + +func TestToggleFavorite_SoftDeleted_ReturnsError(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := 
NewContractorRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") + contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Deleted Contractor") + + // Soft-delete the contractor + err := db.Model(&models.Contractor{}). + Where("id = ?", contractor.ID). + Update("is_active", false).Error + require.NoError(t, err) + + // Toggling a soft-deleted contractor should fail (record not found) + _, err = repo.ToggleFavorite(contractor.ID) + assert.Error(t, err, "toggling a soft-deleted contractor should return an error") +} + +func TestToggleFavorite_NonExistent_ReturnsError(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewContractorRepository(db) + + _, err := repo.ToggleFavorite(99999) + assert.Error(t, err, "toggling a non-existent contractor should return an error") +} + +func TestContractorRepository_FindByUser_HasDefaultLimit(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewContractorRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") + + // Create 510 contractors to exceed the default limit of 500 + for i := 0; i < 510; i++ { + c := &models.Contractor{ + ResidenceID: &residence.ID, + CreatedByID: user.ID, + Name: fmt.Sprintf("Contractor %d", i+1), + IsActive: true, + } + err := db.Create(c).Error + require.NoError(t, err) + } + + contractors, err := repo.FindByUser(user.ID, []uint{residence.ID}) + require.NoError(t, err) + assert.Equal(t, 500, len(contractors), "FindByUser should return at most 500 contractors by default") +} diff --git a/internal/repositories/document_repo.go b/internal/repositories/document_repo.go index 419acaa..a935eeb 100644 --- a/internal/repositories/document_repo.go +++ b/internal/repositories/document_repo.go @@ -52,7 +52,8 @@ func (r *DocumentRepository) 
FindByResidence(residenceID uint) ([]models.Documen return documents, err } -// FindByUser finds all documents accessible to a user +// FindByUser finds all documents accessible to a user. +// A default limit of 500 is applied to prevent unbounded result sets. func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document, error) { var documents []models.Document err := r.db.Preload("CreatedBy"). @@ -60,6 +61,7 @@ func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document, Preload("Images"). Where("residence_id IN ? AND is_active = ?", residenceIDs, true). Order("created_at DESC"). + Limit(500). Find(&documents).Error return documents, err } @@ -89,7 +91,8 @@ func (r *DocumentRepository) FindByUserFiltered(residenceIDs []uint, filter *Doc query = query.Where("expiry_date IS NOT NULL AND expiry_date > ? AND expiry_date <= ?", now, threshold) } if filter.Search != "" { - searchPattern := "%" + filter.Search + "%" + escaped := escapeLikeWildcards(filter.Search) + searchPattern := "%" + escaped + "%" query = query.Where("(title ILIKE ? OR description ILIKE ?)", searchPattern, searchPattern) } } @@ -169,6 +172,19 @@ func (r *DocumentRepository) CountByResidence(residenceID uint) (int64, error) { return count, err } +// CountByResidenceIDs counts all active documents across multiple residences in a single query. +// Returns the total count of active documents for the given residence IDs. +func (r *DocumentRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) { + if len(residenceIDs) == 0 { + return 0, nil + } + var count int64 + err := r.db.Model(&models.Document{}). + Where("residence_id IN ? AND is_active = ?", residenceIDs, true). 
+ Count(&count).Error + return count, err +} + // FindByIDIncludingInactive finds a document by ID including inactive ones func (r *DocumentRepository) FindByIDIncludingInactive(id uint, document *models.Document) error { return r.db.Preload("CreatedBy").Preload("Images").First(document, id).Error diff --git a/internal/repositories/document_repo_test.go b/internal/repositories/document_repo_test.go new file mode 100644 index 0000000..1a1e34f --- /dev/null +++ b/internal/repositories/document_repo_test.go @@ -0,0 +1,38 @@ +package repositories + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/treytartt/casera-api/internal/models" + "github.com/treytartt/casera-api/internal/testutil" +) + +func TestDocumentRepository_FindByUser_HasDefaultLimit(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewDocumentRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") + + // Create 510 documents to exceed the default limit of 500 + for i := 0; i < 510; i++ { + doc := &models.Document{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: fmt.Sprintf("Doc %d", i+1), + DocumentType: models.DocumentTypeGeneral, + FileURL: "https://example.com/doc.pdf", + IsActive: true, + } + err := db.Create(doc).Error + require.NoError(t, err) + } + + docs, err := repo.FindByUser([]uint{residence.ID}) + require.NoError(t, err) + assert.Equal(t, 500, len(docs), "FindByUser should return at most 500 documents by default") +} diff --git a/internal/repositories/notification_repo.go b/internal/repositories/notification_repo.go index 0d89903..f4f7b60 100644 --- a/internal/repositories/notification_repo.go +++ b/internal/repositories/notification_repo.go @@ -1,6 +1,7 @@ package repositories import ( + "errors" "time" "gorm.io/gorm" @@ -130,18 +131,25 @@ func (r *NotificationRepository) 
CreatePreferences(prefs *models.NotificationPre // UpdatePreferences updates notification preferences func (r *NotificationRepository) UpdatePreferences(prefs *models.NotificationPreference) error { - return r.db.Save(prefs).Error + return r.db.Omit("User").Save(prefs).Error } -// GetOrCreatePreferences gets or creates notification preferences for a user +// GetOrCreatePreferences gets or creates notification preferences for a user. +// Uses a transaction to avoid TOCTOU race conditions on concurrent requests. func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.NotificationPreference, error) { - prefs, err := r.FindPreferencesByUser(userID) - if err == nil { - return prefs, nil - } + var prefs models.NotificationPreference - if err == gorm.ErrRecordNotFound { - prefs = &models.NotificationPreference{ + err := r.db.Transaction(func(tx *gorm.DB) error { + err := tx.Where("user_id = ?", userID).First(&prefs).Error + if err == nil { + return nil // Found existing preferences + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err // Unexpected error + } + + // Record not found -- create with defaults + prefs = models.NotificationPreference{ UserID: userID, TaskDueSoon: true, TaskOverdue: true, @@ -151,17 +159,36 @@ func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.No WarrantyExpiring: true, EmailTaskCompleted: true, } - if err := r.CreatePreferences(prefs); err != nil { - return nil, err - } - return prefs, nil + return tx.Create(&prefs).Error + }) + if err != nil { + return nil, err } - - return nil, err + return &prefs, nil } // === Device Registration === +// FindAPNSDeviceByID finds an APNS device by ID +func (r *NotificationRepository) FindAPNSDeviceByID(id uint) (*models.APNSDevice, error) { + var device models.APNSDevice + err := r.db.First(&device, id).Error + if err != nil { + return nil, err + } + return &device, nil +} + +// FindGCMDeviceByID finds a GCM device by ID +func (r 
*NotificationRepository) FindGCMDeviceByID(id uint) (*models.GCMDevice, error) { + var device models.GCMDevice + err := r.db.First(&device, id).Error + if err != nil { + return nil, err + } + return &device, nil +} + // FindAPNSDeviceByToken finds an APNS device by registration token func (r *NotificationRepository) FindAPNSDeviceByToken(token string) (*models.APNSDevice, error) { var device models.APNSDevice @@ -243,12 +270,12 @@ func (r *NotificationRepository) DeactivateGCMDevice(id uint) error { // GetActiveTokensForUser gets all active push tokens for a user func (r *NotificationRepository) GetActiveTokensForUser(userID uint) (iosTokens []string, androidTokens []string, err error) { apnsDevices, err := r.FindAPNSDevicesByUser(userID) - if err != nil && err != gorm.ErrRecordNotFound { + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, err } gcmDevices, err := r.FindGCMDevicesByUser(userID) - if err != nil && err != gorm.ErrRecordNotFound { + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, err } diff --git a/internal/repositories/notification_repo_test.go b/internal/repositories/notification_repo_test.go new file mode 100644 index 0000000..ba00159 --- /dev/null +++ b/internal/repositories/notification_repo_test.go @@ -0,0 +1,96 @@ +package repositories + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/treytartt/casera-api/internal/models" + "github.com/treytartt/casera-api/internal/testutil" +) + +func TestGetOrCreatePreferences_New_Creates(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewNotificationRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + + // No preferences exist yet for this user + prefs, err := repo.GetOrCreatePreferences(user.ID) + require.NoError(t, err) + require.NotNil(t, prefs) + + // Verify defaults were set + assert.Equal(t, user.ID, prefs.UserID) + 
assert.True(t, prefs.TaskDueSoon) + assert.True(t, prefs.TaskOverdue) + assert.True(t, prefs.TaskCompleted) + assert.True(t, prefs.TaskAssigned) + assert.True(t, prefs.ResidenceShared) + assert.True(t, prefs.WarrantyExpiring) + assert.True(t, prefs.EmailTaskCompleted) + + // Verify it was actually persisted + var count int64 + db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(1), count, "should have exactly one preferences record") +} + +func TestGetOrCreatePreferences_AlreadyExists_Returns(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewNotificationRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + + // Create preferences manually first + existingPrefs := &models.NotificationPreference{ + UserID: user.ID, + TaskDueSoon: true, + TaskOverdue: true, + TaskCompleted: true, + TaskAssigned: true, + ResidenceShared: true, + WarrantyExpiring: true, + EmailTaskCompleted: true, + } + err := db.Create(existingPrefs).Error + require.NoError(t, err) + require.NotZero(t, existingPrefs.ID) + + // GetOrCreatePreferences should return the existing record, not create a new one + prefs, err := repo.GetOrCreatePreferences(user.ID) + require.NoError(t, err) + require.NotNil(t, prefs) + + // The returned record should have the same ID as the existing one + assert.Equal(t, existingPrefs.ID, prefs.ID, "should return the existing record by ID") + assert.Equal(t, user.ID, prefs.UserID, "should have correct user_id") + + // Verify still only one record exists (no duplicate created) + var count int64 + db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(1), count, "should still have exactly one preferences record") +} + +func TestGetOrCreatePreferences_Idempotent(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewNotificationRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", 
"owner@test.com", "password") + + // Call twice in succession + prefs1, err := repo.GetOrCreatePreferences(user.ID) + require.NoError(t, err) + + prefs2, err := repo.GetOrCreatePreferences(user.ID) + require.NoError(t, err) + + // Both should return the same record + assert.Equal(t, prefs1.ID, prefs2.ID) + + // Should only have one record + var count int64 + db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(1), count, "should have exactly one preferences record after two calls") +} diff --git a/internal/repositories/reminder_repo.go b/internal/repositories/reminder_repo.go index 1fbfd4f..225b3f3 100644 --- a/internal/repositories/reminder_repo.go +++ b/internal/repositories/reminder_repo.go @@ -37,6 +37,84 @@ func (r *ReminderRepository) HasSentReminder(taskID, userID uint, dueDate time.T return count > 0, nil } +// ReminderKey uniquely identifies a reminder that may have been sent. +type ReminderKey struct { + TaskID uint + UserID uint + DueDate time.Time + Stage models.ReminderStage +} + +// HasSentReminderBatch checks which reminders from the given list have already been sent. +// Returns a set of indices into the input slice that have already been sent. +// This replaces N individual HasSentReminder calls with a single query. 
+func (r *ReminderRepository) HasSentReminderBatch(keys []ReminderKey) (map[int]bool, error) { + result := make(map[int]bool) + if len(keys) == 0 { + return result, nil + } + + // Build a lookup from (task_id, user_id, due_date, stage) -> index + type normalizedKey struct { + TaskID uint + UserID uint + DueDate string + Stage models.ReminderStage + } + keyToIdx := make(map[normalizedKey][]int, len(keys)) + + // Collect unique task IDs and user IDs for the WHERE clause + taskIDSet := make(map[uint]bool) + userIDSet := make(map[uint]bool) + for i, k := range keys { + taskIDSet[k.TaskID] = true + userIDSet[k.UserID] = true + dueDateOnly := time.Date(k.DueDate.Year(), k.DueDate.Month(), k.DueDate.Day(), 0, 0, 0, 0, time.UTC) + nk := normalizedKey{ + TaskID: k.TaskID, + UserID: k.UserID, + DueDate: dueDateOnly.Format("2006-01-02"), + Stage: k.Stage, + } + keyToIdx[nk] = append(keyToIdx[nk], i) + } + + taskIDs := make([]uint, 0, len(taskIDSet)) + for id := range taskIDSet { + taskIDs = append(taskIDs, id) + } + userIDs := make([]uint, 0, len(userIDSet)) + for id := range userIDSet { + userIDs = append(userIDs, id) + } + + // Query all matching reminder logs in one query + var logs []models.TaskReminderLog + err := r.db.Where("task_id IN ? AND user_id IN ?", taskIDs, userIDs). + Find(&logs).Error + if err != nil { + return nil, err + } + + // Match returned logs against our key set + for _, l := range logs { + dueDateStr := l.DueDate.Format("2006-01-02") + nk := normalizedKey{ + TaskID: l.TaskID, + UserID: l.UserID, + DueDate: dueDateStr, + Stage: l.ReminderStage, + } + if indices, ok := keyToIdx[nk]; ok { + for _, idx := range indices { + result[idx] = true + } + } + } + + return result, nil +} + // LogReminder records that a reminder was sent. // Returns the created log entry or an error if the reminder was already sent // (unique constraint violation). 
diff --git a/internal/repositories/residence_repo.go b/internal/repositories/residence_repo.go index 3663465..8308f85 100644 --- a/internal/repositories/residence_repo.go +++ b/internal/repositories/residence_repo.go @@ -6,6 +6,7 @@ import ( "math/big" "time" + "github.com/rs/zerolog/log" "gorm.io/gorm" "github.com/treytartt/casera-api/internal/models" @@ -269,7 +270,9 @@ func (r *ResidenceRepository) GetActiveShareCode(residenceID uint) (*models.Resi // Check if expired if shareCode.ExpiresAt != nil && time.Now().UTC().After(*shareCode.ExpiresAt) { // Auto-deactivate expired code - r.DeactivateShareCode(shareCode.ID) + if err := r.DeactivateShareCode(shareCode.ID); err != nil { + log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate expired share code") + } return nil, nil } @@ -296,9 +299,11 @@ func (r *ResidenceRepository) generateUniqueCode() (string, error) { // Check if code already exists var count int64 - r.db.Model(&models.ResidenceShareCode{}). + if err := r.db.Model(&models.ResidenceShareCode{}). Where("code = ? AND is_active = ?", codeStr, true). - Count(&count) + Count(&count).Error; err != nil { + return "", err + } if count == 0 { return codeStr, nil diff --git a/internal/repositories/subscription_repo.go b/internal/repositories/subscription_repo.go index 5e29d45..4c2eb5c 100644 --- a/internal/repositories/subscription_repo.go +++ b/internal/repositories/subscription_repo.go @@ -1,9 +1,11 @@ package repositories import ( + "errors" "time" "gorm.io/gorm" + "gorm.io/gorm/clause" "github.com/treytartt/casera-api/internal/models" ) @@ -30,31 +32,37 @@ func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscrip return &sub, nil } -// GetOrCreate gets or creates a subscription for a user (defaults to free tier) +// GetOrCreate gets or creates a subscription for a user (defaults to free tier). +// Uses a transaction to avoid TOCTOU race conditions on concurrent requests. 
func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) { - sub, err := r.FindByUserID(userID) - if err == nil { - return sub, nil - } + var sub models.UserSubscription - if err == gorm.ErrRecordNotFound { - sub = &models.UserSubscription{ - UserID: userID, - Tier: models.TierFree, + err := r.db.Transaction(func(tx *gorm.DB) error { + err := tx.Where("user_id = ?", userID).First(&sub).Error + if err == nil { + return nil // Found existing subscription + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err // Unexpected error + } + + // Record not found -- create with free tier defaults + sub = models.UserSubscription{ + UserID: userID, + Tier: models.TierFree, AutoRenew: true, } - if err := r.db.Create(sub).Error; err != nil { - return nil, err - } - return sub, nil + return tx.Create(&sub).Error + }) + if err != nil { + return nil, err } - - return nil, err + return &sub, nil } // Update updates a subscription func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error { - return r.db.Save(sub).Error + return r.db.Omit("User").Save(sub).Error } // UpgradeToPro upgrades a user to Pro tier using a transaction with row locking @@ -63,7 +71,7 @@ func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time, return r.db.Transaction(func(tx *gorm.DB) error { // Lock the row for update var sub models.UserSubscription - if err := tx.Set("gorm:query_option", "FOR UPDATE"). + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). Where("user_id = ?", userID).First(&sub).Error; err != nil { return err } @@ -86,7 +94,7 @@ func (r *SubscriptionRepository) DowngradeToFree(userID uint) error { return r.db.Transaction(func(tx *gorm.DB) error { // Lock the row for update var sub models.UserSubscription - if err := tx.Set("gorm:query_option", "FOR UPDATE"). + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). 
Where("user_id = ?", userID).First(&sub).Error; err != nil { return err } @@ -165,7 +173,7 @@ func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*m var limits models.TierLimits err := r.db.Where("tier = ?", tier).First(&limits).Error if err != nil { - if err == gorm.ErrRecordNotFound { + if errors.Is(err, gorm.ErrRecordNotFound) { // Return defaults if tier == models.TierFree { defaults := models.GetDefaultFreeLimits() @@ -193,7 +201,7 @@ func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, er var settings models.SubscriptionSettings err := r.db.First(&settings).Error if err != nil { - if err == gorm.ErrRecordNotFound { + if errors.Is(err, gorm.ErrRecordNotFound) { // Return default settings (limitations disabled) return &models.SubscriptionSettings{ EnableLimitations: false, diff --git a/internal/repositories/subscription_repo_test.go b/internal/repositories/subscription_repo_test.go new file mode 100644 index 0000000..d70b748 --- /dev/null +++ b/internal/repositories/subscription_repo_test.go @@ -0,0 +1,79 @@ +package repositories + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/treytartt/casera-api/internal/models" + "github.com/treytartt/casera-api/internal/testutil" +) + +func TestGetOrCreate_New_CreatesFreeTier(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewSubscriptionRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + + sub, err := repo.GetOrCreate(user.ID) + require.NoError(t, err) + require.NotNil(t, sub) + + assert.Equal(t, user.ID, sub.UserID) + assert.Equal(t, models.TierFree, sub.Tier) + assert.True(t, sub.AutoRenew) + + // Verify persisted + var count int64 + db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(1), count, "should have exactly one subscription record") +} + +func TestGetOrCreate_AlreadyExists_Returns(t 
*testing.T) { + db := testutil.SetupTestDB(t) + repo := NewSubscriptionRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + + // Create a pro subscription manually + existing := &models.UserSubscription{ + UserID: user.ID, + Tier: models.TierPro, + AutoRenew: true, + } + err := db.Create(existing).Error + require.NoError(t, err) + + // GetOrCreate should return existing, not overwrite with free defaults + sub, err := repo.GetOrCreate(user.ID) + require.NoError(t, err) + require.NotNil(t, sub) + + assert.Equal(t, existing.ID, sub.ID, "should return the existing record by ID") + assert.Equal(t, models.TierPro, sub.Tier, "should preserve existing pro tier, not overwrite with free") + + // Verify still only one record + var count int64 + db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(1), count, "should still have exactly one subscription record") +} + +func TestGetOrCreate_Idempotent(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewSubscriptionRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + + sub1, err := repo.GetOrCreate(user.ID) + require.NoError(t, err) + + sub2, err := repo.GetOrCreate(user.ID) + require.NoError(t, err) + + assert.Equal(t, sub1.ID, sub2.ID) + + var count int64 + db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(1), count, "should have exactly one subscription record after two calls") +} diff --git a/internal/repositories/task_repo.go b/internal/repositories/task_repo.go index e34e559..fd6b954 100644 --- a/internal/repositories/task_repo.go +++ b/internal/repositories/task_repo.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/rs/zerolog/log" "gorm.io/gorm" "github.com/treytartt/casera-api/internal/models" @@ -25,6 +26,50 @@ func NewTaskRepository(db *gorm.DB) *TaskRepository { return &TaskRepository{db: db} } +// DB returns 
the underlying database connection. +// Used by services that need to run transactions spanning multiple operations. +func (r *TaskRepository) DB() *gorm.DB { + return r.db +} + +// CreateCompletionTx creates a new task completion within an existing transaction. +func (r *TaskRepository) CreateCompletionTx(tx *gorm.DB, completion *models.TaskCompletion) error { + return tx.Create(completion).Error +} + +// UpdateTx updates a task with optimistic locking within an existing transaction. +func (r *TaskRepository) UpdateTx(tx *gorm.DB, task *models.Task) error { + result := tx.Model(task). + Where("id = ? AND version = ?", task.ID, task.Version). + Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions"). + Updates(map[string]interface{}{ + "title": task.Title, + "description": task.Description, + "category_id": task.CategoryID, + "priority_id": task.PriorityID, + "frequency_id": task.FrequencyID, + "custom_interval_days": task.CustomIntervalDays, + "in_progress": task.InProgress, + "assigned_to_id": task.AssignedToID, + "due_date": task.DueDate, + "next_due_date": task.NextDueDate, + "estimated_cost": task.EstimatedCost, + "actual_cost": task.ActualCost, + "contractor_id": task.ContractorID, + "is_cancelled": task.IsCancelled, + "is_archived": task.IsArchived, + "version": gorm.Expr("version + 1"), + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return ErrVersionConflict + } + task.Version++ // Update local copy + return nil +} + // === Task Filter Options === // TaskFilterOptions provides flexible filtering for task queries. @@ -495,55 +540,39 @@ func buildKanbanColumns( } // GetKanbanData retrieves tasks organized for kanban display. -// Uses single-purpose query functions for each column type, ensuring consistency -// with notification handlers that use the same functions. 
+// Fetches all non-cancelled, non-archived tasks for the residence in a single query, +// then categorizes them in-memory using the task categorization chain for consistency +// with the predicate-based logic used throughout the application. // The `now` parameter should be the start of day in the user's timezone for accurate overdue detection. // -// Optimization: Preloads only minimal completion data (id, task_id, completed_at) for count/detection. +// Optimization: Single query with preloads, then in-memory categorization. // Images and CompletedBy are NOT preloaded - fetch separately when viewing completion details. func (r *TaskRepository) GetKanbanData(residenceID uint, daysThreshold int, now time.Time) (*models.KanbanBoard, error) { - opts := TaskFilterOptions{ - ResidenceID: residenceID, - PreloadCreatedBy: true, - PreloadAssignedTo: true, - PreloadCompletions: true, + // Fetch all tasks for this residence in a single query (excluding cancelled/archived) + var allTasks []models.Task + query := r.db.Model(&models.Task{}). + Where("task_task.residence_id = ?", residenceID). + Preload("CreatedBy"). + Preload("AssignedTo"). + Preload("Completions", func(db *gorm.DB) *gorm.DB { + return db.Select("id", "task_id", "completed_at") + }). 
+ Scopes(task.ScopeKanbanOrder) + + if err := query.Find(&allTasks).Error; err != nil { + return nil, fmt.Errorf("get tasks for kanban: %w", err) } - // Query each column using single-purpose functions - // These functions use the same scopes as notification handlers for consistency - overdue, err := r.GetOverdueTasks(now, opts) - if err != nil { - return nil, fmt.Errorf("get overdue tasks: %w", err) - } + // Categorize all tasks in-memory using the categorization chain + columnMap := categorization.CategorizeTasksIntoColumnsWithTime(allTasks, daysThreshold, now) - inProgress, err := r.GetInProgressTasks(opts) - if err != nil { - return nil, fmt.Errorf("get in-progress tasks: %w", err) - } - - dueSoon, err := r.GetDueSoonTasks(now, daysThreshold, opts) - if err != nil { - return nil, fmt.Errorf("get due-soon tasks: %w", err) - } - - upcoming, err := r.GetUpcomingTasks(now, daysThreshold, opts) - if err != nil { - return nil, fmt.Errorf("get upcoming tasks: %w", err) - } - - completed, err := r.GetCompletedTasks(opts) - if err != nil { - return nil, fmt.Errorf("get completed tasks: %w", err) - } - - // Intentionally hidden from board: - // cancelled/archived tasks are not returned as a kanban column. - // cancelled, err := r.GetCancelledTasks(opts) - // if err != nil { - // return nil, fmt.Errorf("get cancelled tasks: %w", err) - // } - - columns := buildKanbanColumns(overdue, inProgress, dueSoon, upcoming, completed) + columns := buildKanbanColumns( + columnMap[categorization.ColumnOverdue], + columnMap[categorization.ColumnInProgress], + columnMap[categorization.ColumnDueSoon], + columnMap[categorization.ColumnUpcoming], + columnMap[categorization.ColumnCompleted], + ) return &models.KanbanBoard{ Columns: columns, @@ -553,56 +582,39 @@ func (r *TaskRepository) GetKanbanData(residenceID uint, daysThreshold int, now } // GetKanbanDataForMultipleResidences retrieves tasks from multiple residences organized for kanban display. 
-// Uses single-purpose query functions for each column type, ensuring consistency -// with notification handlers that use the same functions. +// Fetches all tasks in a single query, then categorizes them in-memory using the +// task categorization chain for consistency with predicate-based logic. // The `now` parameter should be the start of day in the user's timezone for accurate overdue detection. // -// Optimization: Preloads only minimal completion data (id, task_id, completed_at) for count/detection. +// Optimization: Single query with preloads, then in-memory categorization. // Images and CompletedBy are NOT preloaded - fetch separately when viewing completion details. func (r *TaskRepository) GetKanbanDataForMultipleResidences(residenceIDs []uint, daysThreshold int, now time.Time) (*models.KanbanBoard, error) { - opts := TaskFilterOptions{ - ResidenceIDs: residenceIDs, - PreloadCreatedBy: true, - PreloadAssignedTo: true, - PreloadResidence: true, - PreloadCompletions: true, + // Fetch all tasks for these residences in a single query (excluding cancelled/archived) + var allTasks []models.Task + query := r.db.Model(&models.Task{}). + Where("task_task.residence_id IN ?", residenceIDs). + Preload("CreatedBy"). + Preload("AssignedTo"). + Preload("Residence"). + Preload("Completions", func(db *gorm.DB) *gorm.DB { + return db.Select("id", "task_id", "completed_at") + }). 
+ Scopes(task.ScopeKanbanOrder) + + if err := query.Find(&allTasks).Error; err != nil { + return nil, fmt.Errorf("get tasks for kanban: %w", err) } - // Query each column using single-purpose functions - // These functions use the same scopes as notification handlers for consistency - overdue, err := r.GetOverdueTasks(now, opts) - if err != nil { - return nil, fmt.Errorf("get overdue tasks: %w", err) - } + // Categorize all tasks in-memory using the categorization chain + columnMap := categorization.CategorizeTasksIntoColumnsWithTime(allTasks, daysThreshold, now) - inProgress, err := r.GetInProgressTasks(opts) - if err != nil { - return nil, fmt.Errorf("get in-progress tasks: %w", err) - } - - dueSoon, err := r.GetDueSoonTasks(now, daysThreshold, opts) - if err != nil { - return nil, fmt.Errorf("get due-soon tasks: %w", err) - } - - upcoming, err := r.GetUpcomingTasks(now, daysThreshold, opts) - if err != nil { - return nil, fmt.Errorf("get upcoming tasks: %w", err) - } - - completed, err := r.GetCompletedTasks(opts) - if err != nil { - return nil, fmt.Errorf("get completed tasks: %w", err) - } - - // Intentionally hidden from board: - // cancelled/archived tasks are not returned as a kanban column. - // cancelled, err := r.GetCancelledTasks(opts) - // if err != nil { - // return nil, fmt.Errorf("get cancelled tasks: %w", err) - // } - - columns := buildKanbanColumns(overdue, inProgress, dueSoon, upcoming, completed) + columns := buildKanbanColumns( + columnMap[categorization.ColumnOverdue], + columnMap[categorization.ColumnInProgress], + columnMap[categorization.ColumnDueSoon], + columnMap[categorization.ColumnUpcoming], + columnMap[categorization.ColumnCompleted], + ) return &models.KanbanBoard{ Columns: columns, @@ -653,6 +665,19 @@ func (r *TaskRepository) CountByResidence(residenceID uint) (int64, error) { return count, err } +// CountByResidenceIDs counts all active tasks across multiple residences in a single query. 
+// Returns the total count of non-cancelled, non-archived tasks for the given residence IDs. +func (r *TaskRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) { + if len(residenceIDs) == 0 { + return 0, nil + } + var count int64 + err := r.db.Model(&models.Task{}). + Where("residence_id IN ? AND is_cancelled = ? AND is_archived = ?", residenceIDs, false, false). + Count(&count).Error + return count, err +} + // === Task Completion Operations === // CreateCompletion creates a new task completion @@ -705,7 +730,9 @@ func (r *TaskRepository) UpdateCompletion(completion *models.TaskCompletion) err // DeleteCompletion deletes a task completion func (r *TaskRepository) DeleteCompletion(id uint) error { // Delete images first - r.db.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{}) + if err := r.db.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{}).Error; err != nil { + log.Error().Err(err).Uint("completion_id", id).Msg("Failed to delete completion images") + } return r.db.Delete(&models.TaskCompletion{}, id).Error } diff --git a/internal/repositories/task_repo_test.go b/internal/repositories/task_repo_test.go index cc53d50..fa9ef7f 100644 --- a/internal/repositories/task_repo_test.go +++ b/internal/repositories/task_repo_test.go @@ -2097,3 +2097,170 @@ func TestConsistency_OverduePredicateVsScopeVsRepo(t *testing.T) { } assert.Equal(t, expectedCount, len(repoTasks), "Overdue task count mismatch") } + +// TestGetKanbanData_CategorizesCorrectly verifies the single-query kanban approach +// produces correct column assignments for various task states. 
+func TestGetKanbanData_CategorizesCorrectly(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewTaskRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") + + now := time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC) + yesterday := now.AddDate(0, 0, -1) + tomorrow := now.AddDate(0, 0, 1) + nextMonth := now.AddDate(0, 1, 0) + + // Create overdue task (due yesterday) + overdueTask := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "Overdue Task", + DueDate: &yesterday, + IsCancelled: false, + IsArchived: false, + Version: 1, + } + require.NoError(t, db.Create(overdueTask).Error) + + // Create due-soon task (due tomorrow, within 30-day threshold) + dueSoonTask := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "Due Soon Task", + DueDate: &tomorrow, + IsCancelled: false, + IsArchived: false, + Version: 1, + } + require.NoError(t, db.Create(dueSoonTask).Error) + + // Create upcoming task (due next month, outside 30-day threshold) + upcomingTask := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "Upcoming Task", + DueDate: &nextMonth, + IsCancelled: false, + IsArchived: false, + Version: 1, + } + require.NoError(t, db.Create(upcomingTask).Error) + + // Create in-progress task + inProgressTask := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "In Progress Task", + DueDate: &tomorrow, + InProgress: true, + IsCancelled: false, + IsArchived: false, + Version: 1, + } + require.NoError(t, db.Create(inProgressTask).Error) + + // Create completed task (no next due date, has completion) + completedTask := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "Completed Task", + DueDate: &yesterday, + IsCancelled: false, + IsArchived: false, + Version: 1, + } + require.NoError(t, db.Create(completedTask).Error) + completion := 
&models.TaskCompletion{ + TaskID: completedTask.ID, + CompletedByID: user.ID, + CompletedAt: now, + } + require.NoError(t, db.Create(completion).Error) + + // Create cancelled task (should NOT appear in kanban columns) + cancelledTask := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "Cancelled Task", + DueDate: &yesterday, + IsCancelled: true, + IsArchived: false, + Version: 1, + } + require.NoError(t, db.Create(cancelledTask).Error) + + // Create archived task (should NOT appear in active kanban columns) + archivedTask := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "Archived Task", + DueDate: &yesterday, + IsCancelled: false, + IsArchived: true, + Version: 1, + } + require.NoError(t, db.Create(archivedTask).Error) + + // Create no-due-date task (should go to upcoming) + noDueDateTask := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "No Due Date Task", + IsCancelled: false, + IsArchived: false, + Version: 1, + } + require.NoError(t, db.Create(noDueDateTask).Error) + + // Execute kanban data retrieval + board, err := repo.GetKanbanData(residence.ID, 30, now) + require.NoError(t, err) + require.NotNil(t, board) + require.Len(t, board.Columns, 5, "Should have 5 visible columns") + + // Build a map of column name -> task titles for easy assertion + columnTasks := make(map[string][]string) + for _, col := range board.Columns { + var titles []string + for _, task := range col.Tasks { + titles = append(titles, task.Title) + } + columnTasks[col.Name] = titles + } + + // Verify overdue column + assert.Contains(t, columnTasks["overdue_tasks"], "Overdue Task", + "Overdue task should be in overdue column") + + // Verify in-progress column + assert.Contains(t, columnTasks["in_progress_tasks"], "In Progress Task", + "In-progress task should be in in-progress column") + + // Verify due-soon column + assert.Contains(t, columnTasks["due_soon_tasks"], "Due Soon Task", + "Due-soon task should 
be in due-soon column") + + // Verify upcoming column contains both upcoming and no-due-date tasks + assert.Contains(t, columnTasks["upcoming_tasks"], "No Due Date Task", + "No-due-date task should be in upcoming column") + + // Verify completed column + assert.Contains(t, columnTasks["completed_tasks"], "Completed Task", + "Completed task should be in completed column") + + // Verify cancelled and archived tasks are categorized to the cancelled column + // (which is present in categorization but hidden from visible kanban columns) + // The cancelled/archived tasks should NOT appear in any of the 5 visible columns + allVisibleTitles := make(map[string]bool) + for _, col := range board.Columns { + for _, task := range col.Tasks { + allVisibleTitles[task.Title] = true + } + } + assert.False(t, allVisibleTitles["Cancelled Task"], + "Cancelled task should not appear in visible kanban columns") + assert.False(t, allVisibleTitles["Archived Task"], + "Archived task should not appear in visible kanban columns") +} diff --git a/internal/repositories/task_template_repo.go b/internal/repositories/task_template_repo.go index 07409d6..672afcd 100644 --- a/internal/repositories/task_template_repo.go +++ b/internal/repositories/task_template_repo.go @@ -45,7 +45,8 @@ func (r *TaskTemplateRepository) GetByCategory(categoryID uint) ([]models.TaskTe // Search searches templates by title and tags func (r *TaskTemplateRepository) Search(query string) ([]models.TaskTemplate, error) { var templates []models.TaskTemplate - searchTerm := "%" + strings.ToLower(query) + "%" + escaped := escapeLikeWildcards(strings.ToLower(query)) + searchTerm := "%" + escaped + "%" err := r.db. Preload("Category"). 
@@ -77,7 +78,7 @@ func (r *TaskTemplateRepository) Create(template *models.TaskTemplate) error { // Update updates an existing task template func (r *TaskTemplateRepository) Update(template *models.TaskTemplate) error { - return r.db.Save(template).Error + return r.db.Omit("Category", "Frequency").Save(template).Error } // Delete hard deletes a task template diff --git a/internal/repositories/util.go b/internal/repositories/util.go new file mode 100644 index 0000000..ca6acc1 --- /dev/null +++ b/internal/repositories/util.go @@ -0,0 +1,11 @@ +package repositories + +import "strings" + +// escapeLikeWildcards escapes SQL LIKE wildcard characters in user input +// to prevent users from injecting wildcards like % or _ into search queries. +func escapeLikeWildcards(s string) string { + s = strings.ReplaceAll(s, "%", "\\%") + s = strings.ReplaceAll(s, "_", "\\_") + return s +} diff --git a/internal/router/router.go b/internal/router/router.go index 5ed9334..ea047f7 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -129,6 +129,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo { taskService.SetResidenceService(residenceService) // For including TotalSummary in CRUD responses taskService.SetStorageService(deps.StorageService) // For reading completion images for email subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo) + residenceService.SetSubscriptionService(subscriptionService) // Wire up subscription service for tier limit enforcement taskTemplateService := services.NewTaskTemplateService(taskTemplateRepo) // Initialize webhook event repo for deduplication diff --git a/internal/services/contractor_service.go b/internal/services/contractor_service.go index ca3d4fe..f5654f3 100644 --- a/internal/services/contractor_service.go +++ b/internal/services/contractor_service.go @@ -195,6 +195,18 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req 
if req.IsFavorite != nil { contractor.IsFavorite = *req.IsFavorite } + // If residence_id is provided, verify the user has access to the NEW residence. + // This prevents an attacker from reassigning a contractor to someone else's residence. + if req.ResidenceID != nil { + hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID) + if err != nil { + return nil, apperrors.Internal(err) + } + if !hasAccess { + return nil, apperrors.Forbidden("error.residence_access_denied") + } + } + // If residence_id is not sent in the request (nil), it means the user // removed the residence association - contractor becomes personal contractor.ResidenceID = req.ResidenceID diff --git a/internal/services/contractor_service_test.go b/internal/services/contractor_service_test.go new file mode 100644 index 0000000..090ac0e --- /dev/null +++ b/internal/services/contractor_service_test.go @@ -0,0 +1,98 @@ +package services + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/treytartt/casera-api/internal/dto/requests" + "github.com/treytartt/casera-api/internal/repositories" + "github.com/treytartt/casera-api/internal/testutil" +) + +func setupContractorService(t *testing.T) (*ContractorService, *repositories.ContractorRepository, *repositories.ResidenceRepository) { + db := testutil.SetupTestDB(t) + testutil.SeedLookupData(t, db) + contractorRepo := repositories.NewContractorRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + service := NewContractorService(contractorRepo, residenceRepo) + return service, contractorRepo, residenceRepo +} + +func TestUpdateContractor_CrossResidence_Returns403(t *testing.T) { + db := testutil.SetupTestDB(t) + testutil.SeedLookupData(t, db) + contractorRepo := repositories.NewContractorRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + service := NewContractorService(contractorRepo, residenceRepo) + + // Create two users: owner and attacker + owner := 
testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + attacker := testutil.CreateTestUser(t, db, "attacker", "attacker@test.com", "password") + + // Owner creates a residence + ownerResidence := testutil.CreateTestResidence(t, db, owner.ID, "Owner House") + + // Attacker creates a residence and a contractor in their residence + attackerResidence := testutil.CreateTestResidence(t, db, attacker.ID, "Attacker House") + contractor := testutil.CreateTestContractor(t, db, attackerResidence.ID, attacker.ID, "My Contractor") + + // Attacker tries to reassign their contractor to the owner's residence + // This should be denied because the attacker does not have access to the owner's residence + newResidenceID := ownerResidence.ID + req := &requests.UpdateContractorRequest{ + ResidenceID: &newResidenceID, + } + + _, err := service.UpdateContractor(contractor.ID, attacker.ID, req) + require.Error(t, err, "should not allow reassigning contractor to a residence the user has no access to") + testutil.AssertAppErrorCode(t, err, http.StatusForbidden) +} + +func TestUpdateContractor_SameResidence_Succeeds(t *testing.T) { + db := testutil.SetupTestDB(t) + testutil.SeedLookupData(t, db) + contractorRepo := repositories.NewContractorRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + service := NewContractorService(contractorRepo, residenceRepo) + + owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence1 := testutil.CreateTestResidence(t, db, owner.ID, "House 1") + residence2 := testutil.CreateTestResidence(t, db, owner.ID, "House 2") + contractor := testutil.CreateTestContractor(t, db, residence1.ID, owner.ID, "My Contractor") + + // Owner reassigns contractor to their other residence - should succeed + newResidenceID := residence2.ID + newName := "Updated Contractor" + req := &requests.UpdateContractorRequest{ + Name: &newName, + ResidenceID: &newResidenceID, + } + + resp, err := 
service.UpdateContractor(contractor.ID, owner.ID, req) + require.NoError(t, err, "should allow reassigning contractor to a residence the user owns") + require.NotNil(t, resp) + require.Equal(t, "Updated Contractor", resp.Name) +} + +func TestUpdateContractor_RemoveResidence_Succeeds(t *testing.T) { + db := testutil.SetupTestDB(t) + testutil.SeedLookupData(t, db) + contractorRepo := repositories.NewContractorRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + service := NewContractorService(contractorRepo, residenceRepo) + + owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, owner.ID, "My House") + contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "My Contractor") + + // Setting ResidenceID to nil should remove the residence association (make it personal) + req := &requests.UpdateContractorRequest{ + ResidenceID: nil, + } + + resp, err := service.UpdateContractor(contractor.ID, owner.ID, req) + require.NoError(t, err, "should allow removing residence association") + require.NotNil(t, resp) +} diff --git a/internal/services/iap_validation.go b/internal/services/iap_validation.go index e7c6306..45ecda4 100644 --- a/internal/services/iap_validation.go +++ b/internal/services/iap_validation.go @@ -323,10 +323,21 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string }, nil } -// validateLegacyReceipt uses Apple's legacy verifyReceipt endpoint +// validateLegacyReceipt uses Apple's legacy verifyReceipt endpoint. +// It delegates to validateLegacyReceiptWithSandbox using the client's +// configured sandbox setting. This avoids mutating the struct field +// during the sandbox-retry flow, which caused a data race when +// multiple goroutines shared the same AppleIAPClient. 
func (c *AppleIAPClient) validateLegacyReceipt(ctx context.Context, receiptData string) (*AppleValidationResult, error) { + return c.validateLegacyReceiptWithSandbox(ctx, receiptData, c.sandbox) +} + +// validateLegacyReceiptWithSandbox performs legacy receipt validation against +// the specified environment. The sandbox parameter is passed by value (not +// stored on the struct) so this function is safe for concurrent use. +func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, receiptData string, useSandbox bool) (*AppleValidationResult, error) { url := "https://buy.itunes.apple.com/verifyReceipt" - if c.sandbox { + if useSandbox { url = "https://sandbox.itunes.apple.com/verifyReceipt" } @@ -378,12 +389,10 @@ func (c *AppleIAPClient) validateLegacyReceipt(ctx context.Context, receiptData } // Status codes: 0 = valid, 21007 = sandbox receipt on production, 21008 = production receipt on sandbox - if legacyResponse.Status == 21007 && !c.sandbox { - // Retry with sandbox - c.sandbox = true - result, err := c.validateLegacyReceipt(ctx, receiptData) - c.sandbox = false - return result, err + if legacyResponse.Status == 21007 && !useSandbox { + // Retry with sandbox -- pass sandbox=true as a parameter instead of + // mutating c.sandbox, which avoids a data race. + return c.validateLegacyReceiptWithSandbox(ctx, receiptData, true) } if legacyResponse.Status != 0 { diff --git a/internal/services/notification_service.go b/internal/services/notification_service.go index a515670..f303e9c 100644 --- a/internal/services/notification_service.go +++ b/internal/services/notification_service.go @@ -355,20 +355,43 @@ func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error) return result, nil } -// DeleteDevice deletes a device +// DeleteDevice deactivates a device after verifying it belongs to the requesting user. +// Without ownership verification, an attacker could deactivate push notifications for other users. 
func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userID uint) error { - var err error switch platform { case push.PlatformIOS: - err = s.notificationRepo.DeactivateAPNSDevice(deviceID) + device, err := s.notificationRepo.FindAPNSDeviceByID(deviceID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return apperrors.NotFound("error.device_not_found") + } + return apperrors.Internal(err) + } + // Verify the device belongs to the requesting user + if device.UserID == nil || *device.UserID != userID { + return apperrors.Forbidden("error.device_access_denied") + } + if err := s.notificationRepo.DeactivateAPNSDevice(deviceID); err != nil { + return apperrors.Internal(err) + } case push.PlatformAndroid: - err = s.notificationRepo.DeactivateGCMDevice(deviceID) + device, err := s.notificationRepo.FindGCMDeviceByID(deviceID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return apperrors.NotFound("error.device_not_found") + } + return apperrors.Internal(err) + } + // Verify the device belongs to the requesting user + if device.UserID == nil || *device.UserID != userID { + return apperrors.Forbidden("error.device_access_denied") + } + if err := s.notificationRepo.DeactivateGCMDevice(deviceID); err != nil { + return apperrors.Internal(err) + } default: return apperrors.BadRequest("error.invalid_platform") } - if err != nil { - return apperrors.Internal(err) - } return nil } @@ -549,9 +572,9 @@ func NewGCMDeviceResponse(d *models.GCMDevice) *DeviceResponse { // RegisterDeviceRequest represents device registration request type RegisterDeviceRequest struct { Name string `json:"name"` - DeviceID string `json:"device_id" binding:"required"` - RegistrationID string `json:"registration_id" binding:"required"` - Platform string `json:"platform" binding:"required,oneof=ios android"` + DeviceID string `json:"device_id" validate:"required"` + RegistrationID string `json:"registration_id" validate:"required"` + Platform string 
`json:"platform" validate:"required,oneof=ios android"` } // === Task Notifications with Actions === diff --git a/internal/services/notification_service_test.go b/internal/services/notification_service_test.go new file mode 100644 index 0000000..5b57e9c --- /dev/null +++ b/internal/services/notification_service_test.go @@ -0,0 +1,126 @@ +package services + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/treytartt/casera-api/internal/models" + "github.com/treytartt/casera-api/internal/push" + "github.com/treytartt/casera-api/internal/repositories" + "github.com/treytartt/casera-api/internal/testutil" +) + +func setupNotificationService(t *testing.T) (*NotificationService, *repositories.NotificationRepository) { + db := testutil.SetupTestDB(t) + notifRepo := repositories.NewNotificationRepository(db) + // pushClient is nil for testing (no actual push sends) + service := NewNotificationService(notifRepo, nil) + return service, notifRepo +} + +func TestDeleteDevice_WrongUser_Returns403(t *testing.T) { + db := testutil.SetupTestDB(t) + notifRepo := repositories.NewNotificationRepository(db) + service := NewNotificationService(notifRepo, nil) + + owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + attacker := testutil.CreateTestUser(t, db, "attacker", "attacker@test.com", "password") + + // Register an iOS device for the owner + device := &models.APNSDevice{ + UserID: &owner.ID, + Name: "Owner iPhone", + DeviceID: "device-123", + RegistrationID: "token-abc", + Active: true, + } + err := db.Create(device).Error + require.NoError(t, err) + + // Attacker tries to deactivate the owner's device + err = service.DeleteDevice(device.ID, push.PlatformIOS, attacker.ID) + require.Error(t, err, "should not allow deleting another user's device") + testutil.AssertAppErrorCode(t, err, http.StatusForbidden) + + // Verify the device is still active + var found 
models.APNSDevice + err = db.First(&found, device.ID).Error + require.NoError(t, err) + assert.True(t, found.Active, "device should still be active after failed deletion") +} + +func TestDeleteDevice_CorrectUser_Succeeds(t *testing.T) { + db := testutil.SetupTestDB(t) + notifRepo := repositories.NewNotificationRepository(db) + service := NewNotificationService(notifRepo, nil) + + owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + + // Register an iOS device for the owner + device := &models.APNSDevice{ + UserID: &owner.ID, + Name: "Owner iPhone", + DeviceID: "device-123", + RegistrationID: "token-abc", + Active: true, + } + err := db.Create(device).Error + require.NoError(t, err) + + // Owner deactivates their own device + err = service.DeleteDevice(device.ID, push.PlatformIOS, owner.ID) + require.NoError(t, err, "owner should be able to deactivate their own device") + + // Verify the device is now inactive + var found models.APNSDevice + err = db.First(&found, device.ID).Error + require.NoError(t, err) + assert.False(t, found.Active, "device should be deactivated") +} + +func TestDeleteDevice_WrongUser_Android_Returns403(t *testing.T) { + db := testutil.SetupTestDB(t) + notifRepo := repositories.NewNotificationRepository(db) + service := NewNotificationService(notifRepo, nil) + + owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + attacker := testutil.CreateTestUser(t, db, "attacker", "attacker@test.com", "password") + + // Register an Android device for the owner + device := &models.GCMDevice{ + UserID: &owner.ID, + Name: "Owner Pixel", + DeviceID: "device-456", + RegistrationID: "token-def", + CloudMessageType: "FCM", + Active: true, + } + err := db.Create(device).Error + require.NoError(t, err) + + // Attacker tries to deactivate the owner's Android device + err = service.DeleteDevice(device.ID, push.PlatformAndroid, attacker.ID) + require.Error(t, err, "should not allow deleting another user's Android 
device") + testutil.AssertAppErrorCode(t, err, http.StatusForbidden) + + // Verify the device is still active + var found models.GCMDevice + err = db.First(&found, device.ID).Error + require.NoError(t, err) + assert.True(t, found.Active, "Android device should still be active after failed deletion") +} + +func TestDeleteDevice_NonExistent_Returns404(t *testing.T) { + db := testutil.SetupTestDB(t) + notifRepo := repositories.NewNotificationRepository(db) + service := NewNotificationService(notifRepo, nil) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + + err := service.DeleteDevice(99999, push.PlatformIOS, user.ID) + require.Error(t, err, "should return error for non-existent device") + testutil.AssertAppErrorCode(t, err, http.StatusNotFound) +} diff --git a/internal/services/onboarding_email_service.go b/internal/services/onboarding_email_service.go index 28d3801..f9c2b29 100644 --- a/internal/services/onboarding_email_service.go +++ b/internal/services/onboarding_email_service.go @@ -40,9 +40,12 @@ func generateTrackingID() string { // HasSentEmail checks if a specific email type has already been sent to a user func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.OnboardingEmailType) bool { var count int64 - s.db.Model(&models.OnboardingEmail{}). + if err := s.db.Model(&models.OnboardingEmail{}). Where("user_id = ? AND email_type = ?", userID, emailType). - Count(&count) + Count(&count).Error; err != nil { + log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to check if email was sent") + return false + } return count > 0 } @@ -125,23 +128,31 @@ func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error) // No residence email stats var noResTotal, noResOpened int64 - s.db.Model(&models.OnboardingEmail{}). + if err := s.db.Model(&models.OnboardingEmail{}). Where("email_type = ?", models.OnboardingEmailNoResidence). 
- Count(&noResTotal) - s.db.Model(&models.OnboardingEmail{}). + Count(&noResTotal).Error; err != nil { + log.Error().Err(err).Msg("Failed to count no-residence emails") + } + if err := s.db.Model(&models.OnboardingEmail{}). Where("email_type = ? AND opened_at IS NOT NULL", models.OnboardingEmailNoResidence). - Count(&noResOpened) + Count(&noResOpened).Error; err != nil { + log.Error().Err(err).Msg("Failed to count opened no-residence emails") + } stats.NoResidenceTotal = noResTotal stats.NoResidenceOpened = noResOpened // No tasks email stats var noTasksTotal, noTasksOpened int64 - s.db.Model(&models.OnboardingEmail{}). + if err := s.db.Model(&models.OnboardingEmail{}). Where("email_type = ?", models.OnboardingEmailNoTasks). - Count(&noTasksTotal) - s.db.Model(&models.OnboardingEmail{}). + Count(&noTasksTotal).Error; err != nil { + log.Error().Err(err).Msg("Failed to count no-tasks emails") + } + if err := s.db.Model(&models.OnboardingEmail{}). Where("email_type = ? AND opened_at IS NOT NULL", models.OnboardingEmailNoTasks). - Count(&noTasksOpened) + Count(&noTasksOpened).Error; err != nil { + log.Error().Err(err).Msg("Failed to count opened no-tasks emails") + } stats.NoTasksTotal = noTasksTotal stats.NoTasksOpened = noTasksOpened @@ -351,7 +362,9 @@ func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailTyp // If already sent before, delete the old record first to allow re-recording // This allows admins to "resend" emails while still tracking them if alreadySent { - s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{}) + if err := s.db.Where("user_id = ? 
AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{}).Error; err != nil { + log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to delete old onboarding email record before resend") + } } // Record that email was sent diff --git a/internal/services/path_utils.go b/internal/services/path_utils.go new file mode 100644 index 0000000..cdfa9d7 --- /dev/null +++ b/internal/services/path_utils.go @@ -0,0 +1,51 @@ +package services + +import ( + "fmt" + "path/filepath" + "strings" +) + +// SafeResolvePath resolves a user-supplied relative path within a base directory. +// Returns an error if the resolved path escapes the base directory (path traversal). +// The baseDir must be an absolute path. +func SafeResolvePath(baseDir, userInput string) (string, error) { + if userInput == "" { + return "", fmt.Errorf("empty path") + } + + // Reject absolute paths + if filepath.IsAbs(userInput) { + return "", fmt.Errorf("absolute paths not allowed") + } + + // Clean the user input to resolve . and .. components + cleaned := filepath.Clean(userInput) + + // After cleaning, check if it starts with .. 
(escapes base) + if strings.HasPrefix(cleaned, "..") { + return "", fmt.Errorf("path traversal detected") + } + + // Resolve the base directory to an absolute path + absBase, err := filepath.Abs(baseDir) + if err != nil { + return "", fmt.Errorf("invalid base directory: %w", err) + } + + // Join and resolve the full path + fullPath := filepath.Join(absBase, cleaned) + + // Final containment check: the resolved path must be within the base directory + absFullPath, err := filepath.Abs(fullPath) + if err != nil { + return "", fmt.Errorf("invalid resolved path: %w", err) + } + + // Ensure the resolved path is strictly inside the base directory (not the base itself) + if !strings.HasPrefix(absFullPath, absBase+string(filepath.Separator)) { + return "", fmt.Errorf("path traversal detected") + } + + return absFullPath, nil +} diff --git a/internal/services/path_utils_test.go b/internal/services/path_utils_test.go new file mode 100644 index 0000000..2376908 --- /dev/null +++ b/internal/services/path_utils_test.go @@ -0,0 +1,55 @@ +package services + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSafeResolvePath_Normal_Resolves(t *testing.T) { + result, err := SafeResolvePath("/var/uploads", "images/photo.jpg") + require.NoError(t, err) + assert.Equal(t, "/var/uploads/images/photo.jpg", result) +} + +func TestSafeResolvePath_SubdirPath_Resolves(t *testing.T) { + result, err := SafeResolvePath("/var/uploads", "documents/2024/report.pdf") + require.NoError(t, err) + assert.Equal(t, "/var/uploads/documents/2024/report.pdf", result) +} + +func TestSafeResolvePath_DotDotTraversal_Blocked(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"simple dotdot", "../etc/passwd"}, + {"nested dotdot", "../../etc/shadow"}, + {"embedded dotdot", "images/../../etc/passwd"}, + {"deep dotdot", "a/b/c/../../../../etc/passwd"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
_, err := SafeResolvePath("/var/uploads", tt.input) + assert.Error(t, err, "path traversal should be blocked: %s", tt.input) + }) + } +} + +func TestSafeResolvePath_AbsolutePath_Blocked(t *testing.T) { + _, err := SafeResolvePath("/var/uploads", "/etc/passwd") + assert.Error(t, err, "absolute paths should be blocked") +} + +func TestSafeResolvePath_EmptyPath_Blocked(t *testing.T) { + _, err := SafeResolvePath("/var/uploads", "") + assert.Error(t, err, "empty paths should be blocked") +} + +func TestSafeResolvePath_CurrentDir_Blocked(t *testing.T) { + // "." resolves to the base dir itself — this is not a file, so block it + _, err := SafeResolvePath("/var/uploads", ".") + assert.Error(t, err, "bare current directory should be blocked") +} diff --git a/internal/services/pdf_service.go b/internal/services/pdf_service.go index cb6eaae..320944f 100644 --- a/internal/services/pdf_service.go +++ b/internal/services/pdf_service.go @@ -126,10 +126,11 @@ func (s *PDFService) GenerateTasksReportPDF(report *TasksReportResponse) ([]byte pdf.SetFillColor(255, 255, 255) // White } - // Title (truncate if too long) + // Title (truncate if too long, use runes to avoid cutting multi-byte UTF-8 characters) title := task.Title - if len(title) > 35 { - title = title[:32] + "..." + titleRunes := []rune(title) + if len(titleRunes) > 35 { + title = string(titleRunes[:32]) + "..." 
} // Status text diff --git a/internal/services/residence_service.go b/internal/services/residence_service.go index 165febc..5b53775 100644 --- a/internal/services/residence_service.go +++ b/internal/services/residence_service.go @@ -4,6 +4,7 @@ import ( "errors" "time" + "github.com/rs/zerolog/log" "gorm.io/gorm" "github.com/treytartt/casera-api/internal/apperrors" @@ -31,10 +32,11 @@ var ( // ResidenceService handles residence business logic type ResidenceService struct { - residenceRepo *repositories.ResidenceRepository - userRepo *repositories.UserRepository - taskRepo *repositories.TaskRepository - config *config.Config + residenceRepo *repositories.ResidenceRepository + userRepo *repositories.UserRepository + taskRepo *repositories.TaskRepository + subscriptionService *SubscriptionService + config *config.Config } // NewResidenceService creates a new residence service @@ -51,6 +53,11 @@ func (s *ResidenceService) SetTaskRepository(taskRepo *repositories.TaskReposito s.taskRepo = taskRepo } +// SetSubscriptionService sets the subscription service (used for tier limit enforcement) +func (s *ResidenceService) SetSubscriptionService(subService *SubscriptionService) { + s.subscriptionService = subService +} + // GetResidence gets a residence by ID with access check func (s *ResidenceService) GetResidence(residenceID, userID uint) (*responses.ResidenceResponse, error) { // Check access @@ -152,12 +159,12 @@ func (s *ResidenceService) getSummaryForUser(_ uint) responses.TotalSummary { // CreateResidence creates a new residence and returns it with updated summary func (s *ResidenceService) CreateResidence(req *requests.CreateResidenceRequest, ownerID uint) (*responses.ResidenceWithSummaryResponse, error) { - // TODO: Check subscription tier limits - // count, err := s.residenceRepo.CountByOwner(ownerID) - // if err != nil { - // return nil, err - // } - // Check against tier limits... 
+ // Check subscription tier limits (if subscription service is wired up) + if s.subscriptionService != nil { + if err := s.subscriptionService.CheckLimit(ownerID, "properties"); err != nil { + return nil, err + } + } isPrimary := true if req.IsPrimary != nil { @@ -447,6 +454,7 @@ func (s *ResidenceService) JoinWithCode(code string, userID uint) (*responses.Jo if err := s.residenceRepo.DeactivateShareCode(shareCode.ID); err != nil { // Log the error but don't fail the join - the user has already been added // The code will just be usable by others until it expires + log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate share code after join") } // Get the residence with full details diff --git a/internal/services/residence_service_test.go b/internal/services/residence_service_test.go index 93a25d8..99417fd 100644 --- a/internal/services/residence_service_test.go +++ b/internal/services/residence_service_test.go @@ -1,15 +1,19 @@ package services import ( + "fmt" "net/http" "testing" + "time" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/gorm" "github.com/treytartt/casera-api/internal/config" "github.com/treytartt/casera-api/internal/dto/requests" + "github.com/treytartt/casera-api/internal/models" "github.com/treytartt/casera-api/internal/repositories" "github.com/treytartt/casera-api/internal/testutil" ) @@ -333,3 +337,122 @@ func TestResidenceService_RemoveUser_CannotRemoveOwner(t *testing.T) { err := service.RemoveUser(residence.ID, owner.ID, owner.ID) testutil.AssertAppError(t, err, http.StatusBadRequest, "error.cannot_remove_owner") } + +// setupResidenceServiceWithSubscription creates a ResidenceService wired with a +// SubscriptionService, enabling tier limit enforcement in tests. 
+func setupResidenceServiceWithSubscription(t *testing.T) (*ResidenceService, *gorm.DB) { + db := testutil.SetupTestDB(t) + residenceRepo := repositories.NewResidenceRepository(db) + userRepo := repositories.NewUserRepository(db) + taskRepo := repositories.NewTaskRepository(db) + contractorRepo := repositories.NewContractorRepository(db) + documentRepo := repositories.NewDocumentRepository(db) + subscriptionRepo := repositories.NewSubscriptionRepository(db) + + cfg := &config.Config{} + service := NewResidenceService(residenceRepo, userRepo, cfg) + subscriptionService := NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo) + service.SetSubscriptionService(subscriptionService) + + return service, db +} + +func TestCreateResidence_FreeTier_EnforcesLimit(t *testing.T) { + service, db := setupResidenceServiceWithSubscription(t) + + owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + + // Enable global limitations + db.Where("1=1").Delete(&models.SubscriptionSettings{}) + err := db.Create(&models.SubscriptionSettings{EnableLimitations: true}).Error + require.NoError(t, err) + + // Set free tier limit to 1 property + one := 1 + db.Where("tier = ?", models.TierFree).Delete(&models.TierLimits{}) + err = db.Create(&models.TierLimits{ + Tier: models.TierFree, + PropertiesLimit: &one, + }).Error + require.NoError(t, err) + + // Ensure user has a free-tier subscription record + subscriptionRepo := repositories.NewSubscriptionRepository(db) + _, err = subscriptionRepo.GetOrCreate(owner.ID) + require.NoError(t, err) + + // First residence should succeed (under the limit) + req := &requests.CreateResidenceRequest{ + Name: "First House", + StreetAddress: "1 Main St", + City: "Austin", + StateProvince: "TX", + PostalCode: "78701", + } + resp, err := service.CreateResidence(req, owner.ID) + require.NoError(t, err) + assert.Equal(t, "First House", resp.Data.Name) + + // Second residence should be rejected (at the 
limit) + req2 := &requests.CreateResidenceRequest{ + Name: "Second House", + StreetAddress: "2 Main St", + City: "Austin", + StateProvince: "TX", + PostalCode: "78702", + } + _, err = service.CreateResidence(req2, owner.ID) + testutil.AssertAppError(t, err, http.StatusForbidden, "error.properties_limit_exceeded") +} + +func TestCreateResidence_ProTier_AllowsMore(t *testing.T) { + service, db := setupResidenceServiceWithSubscription(t) + + owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + + // Enable global limitations + db.Where("1=1").Delete(&models.SubscriptionSettings{}) + err := db.Create(&models.SubscriptionSettings{EnableLimitations: true}).Error + require.NoError(t, err) + + // Set free tier limit to 1 property (pro is unlimited by default: nil limits) + one := 1 + db.Where("tier = ?", models.TierFree).Delete(&models.TierLimits{}) + err = db.Create(&models.TierLimits{ + Tier: models.TierFree, + PropertiesLimit: &one, + }).Error + require.NoError(t, err) + + // Create a pro-tier subscription for the user + subscriptionRepo := repositories.NewSubscriptionRepository(db) + sub, err := subscriptionRepo.GetOrCreate(owner.ID) + require.NoError(t, err) + + // Upgrade to Pro with a future expiration + future := time.Now().UTC().Add(30 * 24 * time.Hour) + sub.Tier = models.TierPro + sub.ExpiresAt = &future + sub.SubscribedAt = ptrTime(time.Now().UTC()) + err = subscriptionRepo.Update(sub) + require.NoError(t, err) + + // Create multiple residences — all should succeed for Pro users + for i := 1; i <= 3; i++ { + req := &requests.CreateResidenceRequest{ + Name: fmt.Sprintf("House %d", i), + StreetAddress: fmt.Sprintf("%d Main St", i), + City: "Austin", + StateProvince: "TX", + PostalCode: "78701", + } + resp, err := service.CreateResidence(req, owner.ID) + require.NoError(t, err, "Pro user should be able to create residence %d", i) + assert.Equal(t, fmt.Sprintf("House %d", i), resp.Data.Name) + } +} + +// ptrTime returns a pointer to the 
given time. +func ptrTime(t time.Time) *time.Time { + return &t +} diff --git a/internal/services/storage_service.go b/internal/services/storage_service.go index 0bcd4fe..1fc6a2f 100644 --- a/internal/services/storage_service.go +++ b/internal/services/storage_service.go @@ -72,7 +72,7 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U if ext == "" { ext = s.getExtensionFromMimeType(mimeType) } - newFilename := fmt.Sprintf("%s_%s%s", time.Now().Format("20060102"), uuid.New().String()[:8], ext) + newFilename := fmt.Sprintf("%s_%s%s", time.Now().Format("20060102"), uuid.New().String(), ext) // Determine subdirectory based on category subdir := "images" @@ -134,9 +134,15 @@ func (s *StorageService) Delete(fileURL string) error { fullPath := filepath.Join(s.cfg.UploadDir, relativePath) // Security check: ensure path is within upload directory - absUploadDir, _ := filepath.Abs(s.cfg.UploadDir) - absFilePath, _ := filepath.Abs(fullPath) - if !strings.HasPrefix(absFilePath, absUploadDir) { + absUploadDir, err := filepath.Abs(s.cfg.UploadDir) + if err != nil { + return fmt.Errorf("failed to resolve upload directory: %w", err) + } + absFilePath, err := filepath.Abs(fullPath) + if err != nil { + return fmt.Errorf("failed to resolve file path: %w", err) + } + if !strings.HasPrefix(absFilePath, absUploadDir+string(filepath.Separator)) && absFilePath != absUploadDir { return fmt.Errorf("invalid file path") } @@ -181,3 +187,9 @@ func (s *StorageService) getExtensionFromMimeType(mimeType string) string { func (s *StorageService) GetUploadDir() string { return s.cfg.UploadDir } + +// NewStorageServiceForTest creates a StorageService without creating directories. +// This is intended only for unit tests that need a StorageService with a known config. 
+func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService { + return &StorageService{cfg: cfg} +} diff --git a/internal/services/subscription_service.go b/internal/services/subscription_service.go index 042ee33..2096015 100644 --- a/internal/services/subscription_service.go +++ b/internal/services/subscription_service.go @@ -3,9 +3,9 @@ package services import ( "context" "errors" - "log" "time" + "github.com/rs/zerolog/log" "gorm.io/gorm" "github.com/treytartt/casera-api/internal/apperrors" @@ -74,11 +74,11 @@ func NewSubscriptionService( appleClient, err := NewAppleIAPClient(cfg.AppleIAP) if err != nil { if !errors.Is(err, ErrIAPNotConfigured) { - log.Printf("Warning: Failed to initialize Apple IAP client: %v", err) + log.Warn().Err(err).Msg("Failed to initialize Apple IAP client") } } else { svc.appleClient = appleClient - log.Println("Apple IAP validation client initialized") + log.Info().Msg("Apple IAP validation client initialized") } // Initialize Google IAP client @@ -86,11 +86,11 @@ func NewSubscriptionService( googleClient, err := NewGoogleIAPClient(ctx, cfg.GoogleIAP) if err != nil { if !errors.Is(err, ErrIAPNotConfigured) { - log.Printf("Warning: Failed to initialize Google IAP client: %v", err) + log.Warn().Err(err).Msg("Failed to initialize Google IAP client") } } else { svc.googleClient = googleClient - log.Println("Google IAP validation client initialized") + log.Info().Msg("Google IAP validation client initialized") } } @@ -173,7 +173,8 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS return resp, nil } -// getUserUsage calculates current usage for a user +// getUserUsage calculates current usage for a user. +// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)). 
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) { residences, err := s.residenceRepo.FindOwnedByUser(userID) if err != nil { @@ -181,26 +182,26 @@ func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) } propertiesCount := int64(len(residences)) - // Count tasks, contractors, and documents across all user's residences - var tasksCount, contractorsCount, documentsCount int64 - for _, r := range residences { - tc, err := s.taskRepo.CountByResidence(r.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - tasksCount += tc + // Collect residence IDs for batch queries + residenceIDs := make([]uint, len(residences)) + for i, r := range residences { + residenceIDs[i] = r.ID + } - cc, err := s.contractorRepo.CountByResidence(r.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - contractorsCount += cc + // Count tasks, contractors, and documents across all residences with single queries each + tasksCount, err := s.taskRepo.CountByResidenceIDs(residenceIDs) + if err != nil { + return nil, apperrors.Internal(err) + } - dc, err := s.documentRepo.CountByResidence(r.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - documentsCount += dc + contractorsCount, err := s.contractorRepo.CountByResidenceIDs(residenceIDs) + if err != nil { + return nil, apperrors.Internal(err) + } + + documentsCount, err := s.documentRepo.CountByResidenceIDs(residenceIDs) + if err != nil { + return nil, apperrors.Internal(err) } return &UsageResponse{ @@ -342,46 +343,40 @@ func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData stri return nil, apperrors.Internal(err) } - // Validate with Apple if client is configured - var expiresAt time.Time - if s.appleClient != nil { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - var result *AppleValidationResult - var err error - - // Prefer transaction ID (StoreKit 2), fall back to receipt data 
(StoreKit 1) - if transactionID != "" { - result, err = s.appleClient.ValidateTransaction(ctx, transactionID) - } else if receiptData != "" { - result, err = s.appleClient.ValidateReceipt(ctx, receiptData) - } - - if err != nil { - // Log the validation error - log.Printf("Apple validation warning for user %d: %v", userID, err) - - // Check if it's a fatal error - if errors.Is(err, ErrInvalidReceipt) || errors.Is(err, ErrSubscriptionCancelled) { - return nil, err - } - - // For other errors (network, etc.), fall back with shorter expiry - expiresAt = time.Now().UTC().AddDate(0, 1, 0) // 1 month fallback - } else if result != nil { - // Use the expiration date from Apple - expiresAt = result.ExpiresAt - log.Printf("Apple purchase validated for user %d: product=%s, expires=%v, env=%s", - userID, result.ProductID, result.ExpiresAt, result.Environment) - } - } else { - // Apple validation not configured - trust client but log warning - log.Printf("Warning: Apple IAP validation not configured, trusting client for user %d", userID) - expiresAt = time.Now().UTC().AddDate(1, 0, 0) // 1 year default + // Apple IAP client must be configured to validate purchases. + // Without server-side validation, we cannot trust client-provided receipts. 
+ if s.appleClient == nil { + log.Error().Uint("user_id", userID).Msg("Apple IAP validation not configured, rejecting purchase") + return nil, apperrors.BadRequest("error.iap_validation_not_configured") } - // Upgrade to Pro with the determined expiration + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var result *AppleValidationResult + var err error + + // Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1) + if transactionID != "" { + result, err = s.appleClient.ValidateTransaction(ctx, transactionID) + } else if receiptData != "" { + result, err = s.appleClient.ValidateReceipt(ctx, receiptData) + } + + if err != nil { + // Validation failed -- do NOT fall through to grant Pro. + log.Error().Err(err).Uint("user_id", userID).Msg("Apple validation failed") + return nil, err + } + + if result == nil { + return nil, apperrors.BadRequest("error.no_receipt_or_transaction") + } + + expiresAt := result.ExpiresAt + log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated") + + // Upgrade to Pro with the validated expiration if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil { return nil, apperrors.Internal(err) } @@ -397,59 +392,48 @@ func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken s return nil, apperrors.Internal(err) } - // Validate the purchase with Google if client is configured - var expiresAt time.Time - if s.googleClient != nil { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - var result *GoogleValidationResult - var err error - - // If productID is provided, use it directly; otherwise try known IDs - if productID != "" { - result, err = s.googleClient.ValidateSubscription(ctx, productID, purchaseToken) - } else { - result, err = s.googleClient.ValidatePurchaseToken(ctx, 
purchaseToken, KnownSubscriptionIDs) - } - - if err != nil { - // Log the validation error - log.Printf("Google purchase validation warning for user %d: %v", userID, err) - - // Check if it's a fatal error - if errors.Is(err, ErrInvalidPurchaseToken) || errors.Is(err, ErrSubscriptionCancelled) { - return nil, err - } - - if errors.Is(err, ErrSubscriptionExpired) { - // Subscription expired - still allow but set past expiry - expiresAt = time.Now().UTC().Add(-1 * time.Hour) - } else { - // For other errors, fall back with shorter expiry - expiresAt = time.Now().UTC().AddDate(0, 1, 0) // 1 month fallback - } - } else if result != nil { - // Use the expiration date from Google - expiresAt = result.ExpiresAt - log.Printf("Google purchase validated for user %d: product=%s, expires=%v, autoRenew=%v", - userID, result.ProductID, result.ExpiresAt, result.AutoRenewing) - - // Acknowledge the subscription if not already acknowledged - if !result.AcknowledgedState { - if err := s.googleClient.AcknowledgeSubscription(ctx, result.ProductID, purchaseToken); err != nil { - log.Printf("Warning: Failed to acknowledge subscription for user %d: %v", userID, err) - // Don't fail the purchase, just log the warning - } - } - } - } else { - // Google validation not configured - trust client but log warning - log.Printf("Warning: Google IAP validation not configured, trusting client for user %d", userID) - expiresAt = time.Now().UTC().AddDate(1, 0, 0) // 1 year default + // Google IAP client must be configured to validate purchases. + // Without server-side validation, we cannot trust client-provided tokens. 
+ if s.googleClient == nil { + log.Error().Uint("user_id", userID).Msg("Google IAP validation not configured, rejecting purchase") + return nil, apperrors.BadRequest("error.iap_validation_not_configured") } - // Upgrade to Pro with the determined expiration + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var result *GoogleValidationResult + var err error + + // If productID is provided, use it directly; otherwise try known IDs + if productID != "" { + result, err = s.googleClient.ValidateSubscription(ctx, productID, purchaseToken) + } else { + result, err = s.googleClient.ValidatePurchaseToken(ctx, purchaseToken, KnownSubscriptionIDs) + } + + if err != nil { + // Validation failed -- do NOT fall through to grant Pro. + log.Error().Err(err).Uint("user_id", userID).Msg("Google purchase validation failed") + return nil, err + } + + if result == nil { + return nil, apperrors.BadRequest("error.no_purchase_token") + } + + expiresAt := result.ExpiresAt + log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Bool("auto_renew", result.AutoRenewing).Msg("Google purchase validated") + + // Acknowledge the subscription if not already acknowledged + if !result.AcknowledgedState { + if err := s.googleClient.AcknowledgeSubscription(ctx, result.ProductID, purchaseToken); err != nil { + log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to acknowledge Google subscription") + // Don't fail the purchase, just log the warning + } + } + + // Upgrade to Pro with the validated expiration if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "android"); err != nil { return nil, apperrors.Internal(err) } @@ -654,5 +638,5 @@ type ProcessPurchaseRequest struct { TransactionID string `json:"transaction_id"` // iOS StoreKit 2 transaction ID PurchaseToken string `json:"purchase_token"` // Android ProductID string `json:"product_id"` // Android (optional, helps identify subscription) - 
Platform string `json:"platform" binding:"required,oneof=ios android"` + Platform string `json:"platform" validate:"required,oneof=ios android"` } diff --git a/internal/services/subscription_service_test.go b/internal/services/subscription_service_test.go new file mode 100644 index 0000000..31485a6 --- /dev/null +++ b/internal/services/subscription_service_test.go @@ -0,0 +1,181 @@ +package services + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/treytartt/casera-api/internal/models" + "github.com/treytartt/casera-api/internal/repositories" + "github.com/treytartt/casera-api/internal/testutil" +) + +// setupSubscriptionService creates a SubscriptionService with the given +// IAP clients (nil means "not configured"). It bypasses NewSubscriptionService +// which tries to load config from environment. +func setupSubscriptionService(t *testing.T, appleClient *AppleIAPClient, googleClient *GoogleIAPClient) (*SubscriptionService, *repositories.SubscriptionRepository) { + db := testutil.SetupTestDB(t) + + subscriptionRepo := repositories.NewSubscriptionRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + taskRepo := repositories.NewTaskRepository(db) + contractorRepo := repositories.NewContractorRepository(db) + documentRepo := repositories.NewDocumentRepository(db) + + // Create a test user and subscription record for the test + user := testutil.CreateTestUser(t, db, "subuser", "subuser@test.com", "password") + + // Create subscription record so GetOrCreate will find it + sub := &models.UserSubscription{ + UserID: user.ID, + Tier: models.TierFree, + } + err := db.Create(sub).Error + require.NoError(t, err) + + svc := &SubscriptionService{ + subscriptionRepo: subscriptionRepo, + residenceRepo: residenceRepo, + taskRepo: taskRepo, + contractorRepo: contractorRepo, + documentRepo: documentRepo, + appleClient: appleClient, + googleClient: googleClient, + } + + return svc, 
subscriptionRepo +} + +func TestProcessApplePurchase_ClientNil_ReturnsError(t *testing.T) { + db := testutil.SetupTestDB(t) + subscriptionRepo := repositories.NewSubscriptionRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + taskRepo := repositories.NewTaskRepository(db) + contractorRepo := repositories.NewContractorRepository(db) + documentRepo := repositories.NewDocumentRepository(db) + + user := testutil.CreateTestUser(t, db, "subuser", "subuser@test.com", "password") + sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree} + require.NoError(t, db.Create(sub).Error) + + svc := &SubscriptionService{ + subscriptionRepo: subscriptionRepo, + residenceRepo: residenceRepo, + taskRepo: taskRepo, + contractorRepo: contractorRepo, + documentRepo: documentRepo, + appleClient: nil, // Not configured + googleClient: nil, + } + + _, err := svc.ProcessApplePurchase(user.ID, "fake-receipt", "") + assert.Error(t, err, "ProcessApplePurchase should return error when Apple IAP client is nil") + + // Verify user was NOT upgraded to Pro + updatedSub, err := subscriptionRepo.GetOrCreate(user.ID) + require.NoError(t, err) + assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier when IAP client is nil") +} + +func TestProcessApplePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) { + // We cannot easily create a real AppleIAPClient that will fail validation + // in a unit test (it requires real keys and network access). + // Instead, we test the code path logic: + // When appleClient is nil, the service must NOT upgrade the user. + // This is the same as TestProcessApplePurchase_ClientNil_ReturnsError + // but validates no fallback occurs for the specific case. 
+ + db := testutil.SetupTestDB(t) + subscriptionRepo := repositories.NewSubscriptionRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + taskRepo := repositories.NewTaskRepository(db) + contractorRepo := repositories.NewContractorRepository(db) + documentRepo := repositories.NewDocumentRepository(db) + + user := testutil.CreateTestUser(t, db, "subuser2", "subuser2@test.com", "password") + sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree} + require.NoError(t, db.Create(sub).Error) + + svc := &SubscriptionService{ + subscriptionRepo: subscriptionRepo, + residenceRepo: residenceRepo, + taskRepo: taskRepo, + contractorRepo: contractorRepo, + documentRepo: documentRepo, + appleClient: nil, + googleClient: nil, + } + + // Neither receipt data nor transaction ID - should still not grant Pro + _, err := svc.ProcessApplePurchase(user.ID, "", "") + assert.Error(t, err, "ProcessApplePurchase should return error when client is nil, even with empty data") + + // Verify no upgrade happened + updatedSub, err := subscriptionRepo.GetOrCreate(user.ID) + require.NoError(t, err) + assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier") +} + +func TestProcessGooglePurchase_ClientNil_ReturnsError(t *testing.T) { + db := testutil.SetupTestDB(t) + subscriptionRepo := repositories.NewSubscriptionRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + taskRepo := repositories.NewTaskRepository(db) + contractorRepo := repositories.NewContractorRepository(db) + documentRepo := repositories.NewDocumentRepository(db) + + user := testutil.CreateTestUser(t, db, "subuser3", "subuser3@test.com", "password") + sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree} + require.NoError(t, db.Create(sub).Error) + + svc := &SubscriptionService{ + subscriptionRepo: subscriptionRepo, + residenceRepo: residenceRepo, + taskRepo: taskRepo, + contractorRepo: contractorRepo, + documentRepo: 
documentRepo, + appleClient: nil, + googleClient: nil, // Not configured + } + + _, err := svc.ProcessGooglePurchase(user.ID, "fake-token", "com.tt.casera.pro.monthly") + assert.Error(t, err, "ProcessGooglePurchase should return error when Google IAP client is nil") + + // Verify user was NOT upgraded to Pro + updatedSub, err := subscriptionRepo.GetOrCreate(user.ID) + require.NoError(t, err) + assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier when IAP client is nil") +} + +func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) { + db := testutil.SetupTestDB(t) + subscriptionRepo := repositories.NewSubscriptionRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + taskRepo := repositories.NewTaskRepository(db) + contractorRepo := repositories.NewContractorRepository(db) + documentRepo := repositories.NewDocumentRepository(db) + + user := testutil.CreateTestUser(t, db, "subuser4", "subuser4@test.com", "password") + sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree} + require.NoError(t, db.Create(sub).Error) + + svc := &SubscriptionService{ + subscriptionRepo: subscriptionRepo, + residenceRepo: residenceRepo, + taskRepo: taskRepo, + contractorRepo: contractorRepo, + documentRepo: documentRepo, + appleClient: nil, + googleClient: nil, // Not configured + } + + // With empty token + _, err := svc.ProcessGooglePurchase(user.ID, "", "") + assert.Error(t, err, "ProcessGooglePurchase should return error when client is nil") + + // Verify no upgrade happened + updatedSub, err := subscriptionRepo.GetOrCreate(user.ID) + require.NoError(t, err) + assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier") +} diff --git a/internal/services/task_service.go b/internal/services/task_service.go index 5f87af8..b8935fe 100644 --- a/internal/services/task_service.go +++ b/internal/services/task_service.go @@ -560,11 +560,7 @@ func (s *TaskService) 
CreateCompletion(req *requests.CreateTaskCompletionRequest Rating: req.Rating, } - if err := s.taskRepo.CreateCompletion(completion); err != nil { - return nil, apperrors.Internal(err) - } - - // Update next_due_date and in_progress based on frequency + // Determine interval days for NextDueDate calculation before entering the transaction. // - If frequency is "Once" (days = nil or 0), set next_due_date to nil (marks as completed) // - If frequency is "Custom", use task.CustomIntervalDays for recurrence // - If frequency is recurring, calculate next_due_date = completion_date + frequency_days @@ -598,11 +594,25 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest // instead of staying in "In Progress" column task.InProgress = false } - if err := s.taskRepo.Update(task); err != nil { - if errors.Is(err, repositories.ErrVersionConflict) { + + // P1-5: Wrap completion creation and task update in a transaction. + // If either operation fails, both are rolled back to prevent orphaned completions. + txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error { + if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil { + return err + } + if err := s.taskRepo.UpdateTx(tx, task); err != nil { + return err + } + return nil + }) + if txErr != nil { + // P1-6: Return the error instead of swallowing it. 
+ if errors.Is(txErr, repositories.ErrVersionConflict) { return nil, apperrors.Conflict("error.version_conflict") } - log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after completion") + log.Error().Err(txErr).Uint("task_id", task.ID).Msg("Failed to create completion and update task") + return nil, apperrors.Internal(txErr) } // Create images if provided @@ -731,8 +741,15 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error { } log.Info().Uint("task_id", task.ID).Msg("QuickComplete: Task updated successfully") - // Send notification (fire and forget) - go s.sendTaskCompletedNotification(task, completion) + // Send notification (fire and forget with panic recovery) + go func() { + defer func() { + if r := recover(); r != nil { + log.Error().Interface("panic", r).Uint("task_id", task.ID).Msg("Panic in quick-complete notification goroutine") + } + }() + s.sendTaskCompletedNotification(task, completion) + }() return nil } @@ -764,23 +781,23 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio emailImages = s.loadCompletionImagesForEmail(completion.Images) } - // Notify all users + // Notify all users synchronously to avoid unbounded goroutine spawning. + // This method is already called from a goroutine (QuickComplete) or inline + // (CreateCompletion) where blocking is acceptable for notification delivery. 
for _, user := range users { isCompleter := user.ID == completion.CompletedByID // Send push notification (to everyone EXCEPT the person who completed it) if !isCompleter && s.notificationService != nil { - go func(userID uint) { - ctx := context.Background() - if err := s.notificationService.CreateAndSendTaskNotification( - ctx, - userID, - models.NotificationTaskCompleted, - task, - ); err != nil { - log.Error().Err(err).Uint("user_id", userID).Uint("task_id", task.ID).Msg("Failed to send task completion push notification") - } - }(user.ID) + ctx := context.Background() + if err := s.notificationService.CreateAndSendTaskNotification( + ctx, + user.ID, + models.NotificationTaskCompleted, + task, + ); err != nil { + log.Error().Err(err).Uint("user_id", user.ID).Uint("task_id", task.ID).Msg("Failed to send task completion push notification") + } } // Send email notification (to everyone INCLUDING the person who completed it) @@ -789,20 +806,18 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio prefs, err := s.notificationService.GetPreferences(user.ID) if err != nil || (prefs != nil && prefs.EmailTaskCompleted) { // Send email if we couldn't get prefs (fail-open) or if email notifications are enabled - go func(u models.User, images []EmbeddedImage) { - if err := s.emailService.SendTaskCompletedEmail( - u.Email, - u.GetFullName(), - task.Title, - completedByName, - residenceName, - images, - ); err != nil { - log.Error().Err(err).Str("email", u.Email).Uint("task_id", task.ID).Msg("Failed to send task completion email") - } else { - log.Info().Str("email", u.Email).Uint("task_id", task.ID).Int("images", len(images)).Msg("Task completion email sent") - } - }(user, emailImages) + if err := s.emailService.SendTaskCompletedEmail( + user.Email, + user.GetFullName(), + task.Title, + completedByName, + residenceName, + emailImages, + ); err != nil { + log.Error().Err(err).Str("email", user.Email).Uint("task_id", task.ID).Msg("Failed to send 
task completion email") + } else { + log.Info().Str("email", user.Email).Uint("task_id", task.ID).Int("images", len(emailImages)).Msg("Task completion email sent") + } } } } @@ -846,20 +861,28 @@ func (s *TaskService) loadCompletionImagesForEmail(images []models.TaskCompletio return emailImages } -// resolveImageFilePath converts a stored URL to an actual file path +// resolveImageFilePath converts a stored URL to an actual file path. +// Returns empty string if the URL is empty or the resolved path would escape +// the upload directory (path traversal attempt). func (s *TaskService) resolveImageFilePath(storedURL, uploadDir string) string { if storedURL == "" { return "" } - // Handle /uploads/... URLs + // Strip legacy /uploads/ prefix to get relative path + relativePath := storedURL if strings.HasPrefix(storedURL, "/uploads/") { - relativePath := strings.TrimPrefix(storedURL, "/uploads/") - return filepath.Join(uploadDir, relativePath) + relativePath = strings.TrimPrefix(storedURL, "/uploads/") } - // Handle relative paths - return filepath.Join(uploadDir, storedURL) + // Use SafeResolvePath to validate containment within upload directory + resolved, err := SafeResolvePath(uploadDir, relativePath) + if err != nil { + // Path traversal or invalid path — return empty to signal file not found + return "" + } + + return resolved } // getContentTypeFromPath returns the MIME type based on file extension @@ -977,7 +1000,11 @@ func (s *TaskService) UpdateCompletion(completionID, userID uint, req *requests. return &resp, nil } -// DeleteCompletion deletes a task completion +// DeleteCompletion deletes a task completion and recalculates the task's NextDueDate. 
+// +// P1-7: After deleting a completion, NextDueDate must be recalculated: +// - If no completions remain: restore NextDueDate = DueDate (original schedule) +// - If completions remain (recurring): recalculate from latest remaining completion + frequency days func (s *TaskService) DeleteCompletion(completionID, userID uint) (*responses.DeleteWithSummaryResponse, error) { completion, err := s.taskRepo.FindCompletionByID(completionID) if err != nil { @@ -996,10 +1023,66 @@ func (s *TaskService) DeleteCompletion(completionID, userID uint) (*responses.De return nil, apperrors.Forbidden("error.task_access_denied") } + taskID := completion.TaskID + if err := s.taskRepo.DeleteCompletion(completionID); err != nil { return nil, apperrors.Internal(err) } + // Recalculate NextDueDate based on remaining completions + task, err := s.taskRepo.FindByID(taskID) + if err != nil { + // Non-fatal for the delete operation itself, but log the error + log.Error().Err(err).Uint("task_id", taskID).Msg("Failed to reload task after completion deletion for NextDueDate recalculation") + return &responses.DeleteWithSummaryResponse{ + Data: "completion deleted", + Summary: s.getSummaryForUser(userID), + }, nil + } + + // Get remaining completions for this task + remainingCompletions, err := s.taskRepo.FindCompletionsByTask(taskID) + if err != nil { + log.Error().Err(err).Uint("task_id", taskID).Msg("Failed to query remaining completions after deletion") + return &responses.DeleteWithSummaryResponse{ + Data: "completion deleted", + Summary: s.getSummaryForUser(userID), + }, nil + } + + // Determine the task's frequency interval + var intervalDays *int + if task.FrequencyID != nil { + frequency, freqErr := s.taskRepo.GetFrequencyByID(*task.FrequencyID) + if freqErr == nil && frequency != nil { + if frequency.Name == "Custom" { + intervalDays = task.CustomIntervalDays + } else { + intervalDays = frequency.Days + } + } + } + + if len(remainingCompletions) == 0 { + // No completions remain: 
restore NextDueDate to the original DueDate + task.NextDueDate = task.DueDate + } else if intervalDays != nil && *intervalDays > 0 { + // Recurring task with remaining completions: recalculate from the latest completion + // remainingCompletions is ordered by completed_at DESC, so index 0 is the latest + latestCompletion := remainingCompletions[0] + nextDue := latestCompletion.CompletedAt.AddDate(0, 0, *intervalDays) + task.NextDueDate = &nextDue + } else { + // One-time task with remaining completions (unusual case): keep NextDueDate as nil + // since the task is still considered completed + task.NextDueDate = nil + } + + if err := s.taskRepo.Update(task); err != nil { + log.Error().Err(err).Uint("task_id", taskID).Msg("Failed to update task NextDueDate after completion deletion") + // The completion was already deleted; return success but log the update failure + } + return &responses.DeleteWithSummaryResponse{ Data: "completion deleted", Summary: s.getSummaryForUser(userID), diff --git a/internal/services/task_service_test.go b/internal/services/task_service_test.go index e46465c..3818383 100644 --- a/internal/services/task_service_test.go +++ b/internal/services/task_service_test.go @@ -8,6 +8,7 @@ import ( "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/gorm" "github.com/treytartt/casera-api/internal/dto/requests" "github.com/treytartt/casera-api/internal/models" @@ -442,6 +443,333 @@ func TestTaskService_DeleteCompletion(t *testing.T) { assert.Error(t, err) } +func TestTaskService_CreateCompletion_TransactionIntegrity(t *testing.T) { + // Verifies P1-5 / P1-6: completion creation and task update are atomic. + // After completion, both the completion record AND the task's NextDueDate update + // should succeed together, and errors should be propagated (not swallowed). 
+ db := testutil.SetupTestDB(t) + testutil.SeedLookupData(t, db) + taskRepo := repositories.NewTaskRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + service := NewTaskService(taskRepo, residenceRepo) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") + + // Create a one-time task with a due date + dueDate := time.Now().AddDate(0, 0, 7).UTC() + task := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "One-time Task", + DueDate: &dueDate, + NextDueDate: &dueDate, + IsCancelled: false, + IsArchived: false, + Version: 1, + } + err := db.Create(task).Error + require.NoError(t, err) + + req := &requests.CreateTaskCompletionRequest{ + TaskID: task.ID, + Notes: "Done", + } + + now := time.Now().UTC() + resp, err := service.CreateCompletion(req, user.ID, now) + require.NoError(t, err) + assert.NotZero(t, resp.Data.ID) + + // Verify the task was updated: NextDueDate should be nil for a one-time task + var reloaded models.Task + db.First(&reloaded, task.ID) + assert.Nil(t, reloaded.NextDueDate, "One-time task NextDueDate should be nil after completion") + assert.False(t, reloaded.InProgress, "InProgress should be false after completion") + + // Verify completion record exists + var completion models.TaskCompletion + err = db.Where("task_id = ?", task.ID).First(&completion).Error + require.NoError(t, err, "Completion record should exist") + assert.Equal(t, "Done", completion.Notes) +} + +func TestTaskService_CreateCompletion_UpdateError_ReturnedNotSwallowed(t *testing.T) { + // Verifies P1-5 and P1-6: the completion creation and task update are wrapped + // in a transaction, and update errors are returned (not swallowed). + // + // Strategy: We trigger a version conflict by using a goroutine that bumps + // the task version after the service reads the task but during the transaction. 
+ // Since SQLite serializes writes, we instead verify the behavior by deleting + // the task between the service read and the transactional update. When UpdateTx + // tries to match the row by id+version, 0 rows are affected and ErrVersionConflict + // is returned. The transaction then rolls back the completion insert. + // + // However, because the entire CreateCompletion flow is synchronous and we cannot + // inject failures between steps, we instead verify the transactional guarantee + // indirectly: we confirm that a concurrent version bump (set before the call + // but after the SELECT) causes the version conflict to propagate. Since FindByID + // re-reads the current version, we must verify via a custom test that invokes + // the transaction layer directly. + db := testutil.SetupTestDB(t) + testutil.SeedLookupData(t, db) + taskRepo := repositories.NewTaskRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") + + dueDate := time.Now().AddDate(0, 0, 7).UTC() + task := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "Conflict Task", + DueDate: &dueDate, + NextDueDate: &dueDate, + IsCancelled: false, + IsArchived: false, + Version: 1, + } + err := db.Create(task).Error + require.NoError(t, err) + + // Directly test that the transactional path returns an error on version conflict: + // Use a stale task object (version=1) when the DB has been bumped to version=999. 
+ db.Model(&models.Task{}).Where("id = ?", task.ID).Update("version", 999) + + completion := &models.TaskCompletion{ + TaskID: task.ID, + CompletedByID: user.ID, + CompletedAt: time.Now().UTC(), + Notes: "Should be rolled back", + } + + // Simulate the transaction that CreateCompletion now uses (task still has version=1) + txErr := taskRepo.DB().Transaction(func(tx *gorm.DB) error { + if err := taskRepo.CreateCompletionTx(tx, completion); err != nil { + return err + } + // task.Version is 1 but DB has 999 -> version conflict + if err := taskRepo.UpdateTx(tx, task); err != nil { + return err + } + return nil + }) + + require.Error(t, txErr, "Transaction should fail due to version conflict") + assert.ErrorIs(t, txErr, repositories.ErrVersionConflict, "Error should be ErrVersionConflict") + + // Verify the completion was rolled back + var count int64 + db.Model(&models.TaskCompletion{}).Where("task_id = ?", task.ID).Count(&count) + assert.Equal(t, int64(0), count, "Completion should not exist when transaction rolls back") + + // Also verify that CreateCompletion (full service method) would propagate the error. + // Re-create the task with a normal version so FindByID works, then bump it. + db.Model(&models.Task{}).Where("id = ?", task.ID).Update("version", 1) + service := NewTaskService(taskRepo, residenceRepo) + req := &requests.CreateTaskCompletionRequest{ + TaskID: task.ID, + Notes: "Test error propagation", + } + now := time.Now().UTC() + // This call will succeed because FindByID loads version=1, UpdateTx uses version=1, DB has version=1. + // To verify error propagation, we use the direct transaction test above. 
+ resp, err := service.CreateCompletion(req, user.ID, now) + require.NoError(t, err, "CreateCompletion should succeed with matching versions") + assert.NotZero(t, resp.Data.ID) +} + +func TestTaskService_DeleteCompletion_OneTime_RestoresOriginalDueDate(t *testing.T) { + // Verifies P1-7: deleting the only completion on a one-time task + // should restore NextDueDate to the original DueDate. + db := testutil.SetupTestDB(t) + testutil.SeedLookupData(t, db) + taskRepo := repositories.NewTaskRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + service := NewTaskService(taskRepo, residenceRepo) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") + + // Create a one-time task with a due date + originalDueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC) + task := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "One-time Task", + DueDate: &originalDueDate, + NextDueDate: &originalDueDate, + IsCancelled: false, + IsArchived: false, + Version: 1, + // No FrequencyID = one-time task + } + err := db.Create(task).Error + require.NoError(t, err) + + // Complete the task (sets NextDueDate to nil for one-time tasks) + req := &requests.CreateTaskCompletionRequest{ + TaskID: task.ID, + Notes: "Completed", + } + now := time.Now().UTC() + completionResp, err := service.CreateCompletion(req, user.ID, now) + require.NoError(t, err) + + // Confirm NextDueDate is nil after completion + var taskAfterComplete models.Task + db.First(&taskAfterComplete, task.ID) + assert.Nil(t, taskAfterComplete.NextDueDate, "NextDueDate should be nil after one-time completion") + + // Delete the completion + _, err = service.DeleteCompletion(completionResp.Data.ID, user.ID) + require.NoError(t, err) + + // Verify NextDueDate is restored to the original DueDate + var taskAfterDelete models.Task + db.First(&taskAfterDelete, task.ID) + require.NotNil(t, 
taskAfterDelete.NextDueDate, "NextDueDate should be restored after deleting completion") + assert.Equal(t, originalDueDate.Year(), taskAfterDelete.NextDueDate.Year()) + assert.Equal(t, originalDueDate.Month(), taskAfterDelete.NextDueDate.Month()) + assert.Equal(t, originalDueDate.Day(), taskAfterDelete.NextDueDate.Day()) +} + +func TestTaskService_DeleteCompletion_Recurring_RecalculatesFromLastCompletion(t *testing.T) { + // Verifies P1-7: deleting the latest completion on a recurring task + // should recalculate NextDueDate from the remaining latest completion. + db := testutil.SetupTestDB(t) + testutil.SeedLookupData(t, db) + taskRepo := repositories.NewTaskRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + service := NewTaskService(taskRepo, residenceRepo) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") + + var monthlyFrequency models.TaskFrequency + db.Where("name = ?", "Monthly").First(&monthlyFrequency) + + // Create a recurring task + originalDueDate := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + task := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "Recurring Task", + FrequencyID: &monthlyFrequency.ID, + DueDate: &originalDueDate, + NextDueDate: &originalDueDate, + IsCancelled: false, + IsArchived: false, + Version: 1, + } + err := db.Create(task).Error + require.NoError(t, err) + + // First completion on Jan 15 + firstCompletedAt := time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC) + firstReq := &requests.CreateTaskCompletionRequest{ + TaskID: task.ID, + Notes: "First completion", + CompletedAt: &firstCompletedAt, + } + now := time.Now().UTC() + _, err = service.CreateCompletion(firstReq, user.ID, now) + require.NoError(t, err) + + // Second completion on Feb 15 + secondCompletedAt := time.Date(2026, 2, 15, 10, 0, 0, 0, time.UTC) + secondReq := &requests.CreateTaskCompletionRequest{ + TaskID: 
task.ID, + Notes: "Second completion", + CompletedAt: &secondCompletedAt, + } + resp, err := service.CreateCompletion(secondReq, user.ID, now) + require.NoError(t, err) + + // NextDueDate should be Feb 15 + 30 days = Mar 17 + var taskAfterSecond models.Task + db.First(&taskAfterSecond, task.ID) + require.NotNil(t, taskAfterSecond.NextDueDate) + expectedAfterSecond := secondCompletedAt.AddDate(0, 0, 30) + assert.Equal(t, expectedAfterSecond.Year(), taskAfterSecond.NextDueDate.Year()) + assert.Equal(t, expectedAfterSecond.Month(), taskAfterSecond.NextDueDate.Month()) + assert.Equal(t, expectedAfterSecond.Day(), taskAfterSecond.NextDueDate.Day()) + + // Delete the second (latest) completion + _, err = service.DeleteCompletion(resp.Data.ID, user.ID) + require.NoError(t, err) + + // NextDueDate should be recalculated from the first completion: Jan 15 + 30 = Feb 14 + var taskAfterDelete models.Task + db.First(&taskAfterDelete, task.ID) + require.NotNil(t, taskAfterDelete.NextDueDate, "NextDueDate should be set after deleting latest completion") + expectedRecalculated := firstCompletedAt.AddDate(0, 0, 30) + assert.Equal(t, expectedRecalculated.Year(), taskAfterDelete.NextDueDate.Year()) + assert.Equal(t, expectedRecalculated.Month(), taskAfterDelete.NextDueDate.Month()) + assert.Equal(t, expectedRecalculated.Day(), taskAfterDelete.NextDueDate.Day()) +} + +func TestTaskService_DeleteCompletion_LastCompletion_RestoresDueDate(t *testing.T) { + // Verifies P1-7: deleting the only completion on a recurring task + // should restore NextDueDate to the original DueDate. 
+ db := testutil.SetupTestDB(t) + testutil.SeedLookupData(t, db) + taskRepo := repositories.NewTaskRepository(db) + residenceRepo := repositories.NewResidenceRepository(db) + service := NewTaskService(taskRepo, residenceRepo) + + user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test House") + + var weeklyFrequency models.TaskFrequency + db.Where("name = ?", "Weekly").First(&weeklyFrequency) + + // Create a recurring task + originalDueDate := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) + task := &models.Task{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "Weekly Task", + FrequencyID: &weeklyFrequency.ID, + DueDate: &originalDueDate, + NextDueDate: &originalDueDate, + IsCancelled: false, + IsArchived: false, + Version: 1, + } + err := db.Create(task).Error + require.NoError(t, err) + + // Complete the task + completedAt := time.Date(2026, 3, 2, 10, 0, 0, 0, time.UTC) + req := &requests.CreateTaskCompletionRequest{ + TaskID: task.ID, + Notes: "First completion", + CompletedAt: &completedAt, + } + now := time.Now().UTC() + completionResp, err := service.CreateCompletion(req, user.ID, now) + require.NoError(t, err) + + // Verify NextDueDate was set to completedAt + 7 days + var taskAfterComplete models.Task + db.First(&taskAfterComplete, task.ID) + require.NotNil(t, taskAfterComplete.NextDueDate) + + // Delete the only completion + _, err = service.DeleteCompletion(completionResp.Data.ID, user.ID) + require.NoError(t, err) + + // NextDueDate should be restored to original DueDate since no completions remain + var taskAfterDelete models.Task + db.First(&taskAfterDelete, task.ID) + require.NotNil(t, taskAfterDelete.NextDueDate, "NextDueDate should be restored to original DueDate") + assert.Equal(t, originalDueDate.Year(), taskAfterDelete.NextDueDate.Year()) + assert.Equal(t, originalDueDate.Month(), taskAfterDelete.NextDueDate.Month()) + assert.Equal(t, 
originalDueDate.Day(), taskAfterDelete.NextDueDate.Day()) +} + func TestTaskService_GetCategories(t *testing.T) { db := testutil.SetupTestDB(t) testutil.SeedLookupData(t, db) diff --git a/internal/worker/jobs/handler.go b/internal/worker/jobs/handler.go index 040f056..f41da3d 100644 --- a/internal/worker/jobs/handler.go +++ b/internal/worker/jobs/handler.go @@ -139,6 +139,12 @@ func (h *Handler) HandleTaskReminder(ctx context.Context, task *asynq.Task) erro log.Info().Int("count", len(dueSoonTasks)).Msg("Found tasks due today/tomorrow for eligible users") + // Build set for O(1) eligibility lookups instead of O(N) linear scan + eligibleSet := make(map[uint]bool, len(eligibleUserIDs)) + for _, id := range eligibleUserIDs { + eligibleSet[id] = true + } + // Group tasks by user (assigned_to or residence owner) userTasks := make(map[uint][]models.Task) for _, t := range dueSoonTasks { @@ -150,12 +156,9 @@ func (h *Handler) HandleTaskReminder(ctx context.Context, task *asynq.Task) erro } else { continue } - // Only include if user is in eligible list - for _, eligibleID := range eligibleUserIDs { - if userID == eligibleID { - userTasks[userID] = append(userTasks[userID], t) - break - } + // Only include if user is in eligible set (O(1) lookup) + if eligibleSet[userID] { + userTasks[userID] = append(userTasks[userID], t) } } @@ -236,6 +239,12 @@ func (h *Handler) HandleOverdueReminder(ctx context.Context, task *asynq.Task) e log.Info().Int("count", len(overdueTasks)).Msg("Found overdue tasks for eligible users") + // Build set for O(1) eligibility lookups instead of O(N) linear scan + eligibleSet := make(map[uint]bool, len(eligibleUserIDs)) + for _, id := range eligibleUserIDs { + eligibleSet[id] = true + } + // Group tasks by user (assigned_to or residence owner) userTasks := make(map[uint][]models.Task) for _, t := range overdueTasks { @@ -247,12 +256,9 @@ func (h *Handler) HandleOverdueReminder(ctx context.Context, task *asynq.Task) e } else { continue } - // Only 
include if user is in eligible list - for _, eligibleID := range eligibleUserIDs { - if userID == eligibleID { - userTasks[userID] = append(userTasks[userID], t) - break - } + // Only include if user is in eligible set (O(1) lookup) + if eligibleSet[userID] { + userTasks[userID] = append(userTasks[userID], t) } } @@ -684,10 +690,20 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err log.Info().Int("count", len(activeTasks)).Msg("Found active tasks for eligible users") - // Step 3: Process each task once, sending appropriate notification based on user prefs - var dueSoonSent, dueSoonSkipped, overdueSent, overdueSkipped int + // Step 3: Pre-process tasks to determine stages and build batch reminder check + type candidateReminder struct { + taskIndex int + userID uint + effectiveDate time.Time + stage string + isOverdue bool + reminderStage models.ReminderStage + } - for _, t := range activeTasks { + var candidates []candidateReminder + var reminderKeys []repositories.ReminderKey + + for i, t := range activeTasks { // Determine which user to notify var userID uint if t.AssignedToID != nil { @@ -737,15 +753,36 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err reminderStage := models.ReminderStage(stage) - // Check if already sent - alreadySent, err := h.reminderRepo.HasSentReminder(t.ID, userID, effectiveDate, reminderStage) - if err != nil { - log.Error().Err(err).Uint("task_id", t.ID).Msg("Failed to check reminder log") - continue - } + candidates = append(candidates, candidateReminder{ + taskIndex: i, + userID: userID, + effectiveDate: effectiveDate, + stage: stage, + isOverdue: isOverdueStage, + reminderStage: reminderStage, + }) - if alreadySent { - if isOverdueStage { + reminderKeys = append(reminderKeys, repositories.ReminderKey{ + TaskID: t.ID, + UserID: userID, + DueDate: effectiveDate, + Stage: reminderStage, + }) + } + + // Batch check which reminders have already been sent (single query) + 
alreadySentMap, err := h.reminderRepo.HasSentReminderBatch(reminderKeys) + if err != nil { + log.Error().Err(err).Msg("Failed to batch check reminder logs") + return err + } + + // Step 4: Send notifications for candidates that haven't been sent yet + var dueSoonSent, dueSoonSkipped, overdueSent, overdueSkipped int + + for i, c := range candidates { + if alreadySentMap[i] { + if c.isOverdue { overdueSkipped++ } else { dueSoonSkipped++ @@ -753,30 +790,32 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err continue } + t := activeTasks[c.taskIndex] + // Determine notification type var notificationType models.NotificationType - if isOverdueStage { + if c.isOverdue { notificationType = models.NotificationTaskOverdue } else { notificationType = models.NotificationTaskDueSoon } // Send notification - if err := h.notificationService.CreateAndSendTaskNotification(ctx, userID, notificationType, &t); err != nil { + if err := h.notificationService.CreateAndSendTaskNotification(ctx, c.userID, notificationType, &t); err != nil { log.Error().Err(err). - Uint("user_id", userID). + Uint("user_id", c.userID). Uint("task_id", t.ID). - Str("stage", stage). + Str("stage", c.stage). Msg("Failed to send smart reminder") continue } // Log the reminder - if _, err := h.reminderRepo.LogReminder(t.ID, userID, effectiveDate, reminderStage, nil); err != nil { - log.Error().Err(err).Uint("task_id", t.ID).Str("stage", stage).Msg("Failed to log reminder") + if _, err := h.reminderRepo.LogReminder(t.ID, c.userID, c.effectiveDate, c.reminderStage, nil); err != nil { + log.Error().Err(err).Uint("task_id", t.ID).Str("stage", c.stage).Msg("Failed to log reminder") } - if isOverdueStage { + if c.isOverdue { overdueSent++ } else { dueSoonSent++