Harden API security: input validation, safe auth extraction, new tests, and deploy config
Comprehensive security hardening from audit findings: - Add validation tags to all DTO request structs (max lengths, ranges, enums) - Replace unsafe type assertions with MustGetAuthUser helper across all handlers - Remove query-param token auth from admin middleware (prevents URL token leakage) - Add request validation calls in handlers that were missing c.Validate() - Remove goroutines in handlers (timezone update now synchronous) - Add sanitize middleware and path traversal protection (path_utils) - Stop resetting admin passwords on migration restart - Warn on well-known default SECRET_KEY - Add ~30 new test files covering security regressions, auth safety, repos, and services - Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
5
.deploy_prod
Executable file
5
.deploy_prod
Executable file
@@ -0,0 +1,5 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
exec "${SCRIPT_DIR}/deploy/scripts/deploy_prod.sh" "$@"
|
||||
38
audit-digest-1.md
Normal file
38
audit-digest-1.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Digest 1: cmd/, admin/dto, admin/handlers (first 15 files)
|
||||
|
||||
## Systemic Issues (across all admin handlers)
|
||||
- **SQL Injection via SortBy**: Every admin list handler concatenates `filters.SortBy` directly into GORM `Order()` without allowlist validation
|
||||
- **Unchecked Count() errors**: Every paginated handler ignores GORM Count error returns
|
||||
- **Unchecked post-mutation Preload errors**: After Save/Create, handlers reload with Preload but ignore errors
|
||||
- **`binding` vs `validate` tag mismatch**: Some request DTOs use `binding` (Gin) instead of `validate` (Echo)
|
||||
- **Direct DB access**: All admin handlers bypass Service layer, accessing `*gorm.DB` directly
|
||||
- **Unsafe type assertions**: `c.Get(key).(*models.AdminUser)` without comma-ok
|
||||
|
||||
## Per-File Highlights
|
||||
|
||||
### cmd/api/main.go - App entry point, wires dependencies
|
||||
### cmd/worker/main.go - Background worker entry point
|
||||
|
||||
### admin/handlers/admin_user_handler.go (347 lines)
|
||||
- N+1 query: `toUserResponse` does 2 extra DB queries per user (residence count, task count)
|
||||
- Line 64: SortBy SQL injection
|
||||
- Line 173: Unchecked profile creation error (user created without profile)
|
||||
|
||||
### admin/handlers/apple_social_auth_handler.go - CRUD for Apple social auth records
|
||||
- Same systemic SQL injection and unchecked errors
|
||||
|
||||
### admin/handlers/auth_handler.go - Admin login/session management
|
||||
### admin/handlers/auth_token_handler.go - Auth token CRUD
|
||||
### admin/handlers/completion_handler.go - Task completion CRUD
|
||||
### admin/handlers/completion_image_handler.go - Completion image CRUD
|
||||
### admin/handlers/confirmation_code_handler.go - Email confirmation code CRUD
|
||||
### admin/handlers/contractor_handler.go - Contractor CRUD
|
||||
### admin/handlers/dashboard_handler.go - Admin dashboard stats
|
||||
|
||||
### admin/handlers/device_handler.go (317 lines)
|
||||
- Exposes device push tokens (RegistrationID) in API responses
|
||||
- Lines 293-296: Unchecked Count errors in GetStats
|
||||
|
||||
### admin/handlers/document_handler.go (394 lines)
|
||||
- Lines 176-183: Date parsing errors silently ignored
|
||||
- Line 379: Precision loss from decimal.Float64() discarded
|
||||
51
audit-digest-2.md
Normal file
51
audit-digest-2.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Digest 2: admin/handlers (remaining 15 files)
|
||||
|
||||
### admin/handlers/document_image_handler.go (245 lines)
|
||||
- N+1: toResponse queries DB per image in List
|
||||
- Same SortBy SQL injection
|
||||
|
||||
### admin/handlers/feature_benefit_handler.go (231 lines)
|
||||
- `binding` tags instead of `validate` - required fields never enforced
|
||||
|
||||
### admin/handlers/limitations_handler.go (451 lines)
|
||||
- Line 37: Unchecked Create error for default settings
|
||||
- Line 191-197: UpdateTierLimits overwrites ALL fields even for partial updates
|
||||
|
||||
### admin/handlers/lookup_handler.go (877 lines)
|
||||
- **CRITICAL**: Lines 30-32, 50-52, etc.: refreshXxxCache checks `if cache == nil {}` with EMPTY body, then calls cache.CacheXxx() — nil pointer panic when cache is nil
|
||||
- Line 792: Hardcoded join table name "task_contractor_specialties"
|
||||
|
||||
### admin/handlers/notification_handler.go (419 lines)
|
||||
- Line 351-363: HTML template built by string concatenation with user-provided subject/body — XSS in admin emails
|
||||
|
||||
### admin/handlers/notification_prefs_handler.go (347 lines)
|
||||
- Line 154: Unchecked user lookup — deleted user produces zero-value username/email
|
||||
|
||||
### admin/handlers/onboarding_handler.go (343 lines)
|
||||
- Line 304: Internal error details leaked to client
|
||||
|
||||
### admin/handlers/password_reset_code_handler.go (161 lines)
|
||||
- **BUG**: Line 85: `code.ResetToken[:8] + "..." + code.ResetToken[len-4:]` panics if token < 8 chars
|
||||
|
||||
### admin/handlers/promotion_handler.go (304 lines)
|
||||
- `binding` tags: required fields never enforced
|
||||
|
||||
### admin/handlers/residence_handler.go (371 lines)
|
||||
- Lines 121-122: Unchecked Count errors for task/document counts
|
||||
|
||||
### admin/handlers/settings_handler.go (794 lines)
|
||||
- Line 378: Raw SQL execution from seed files (no parameterization)
|
||||
- Line 529-793: ClearAllData is destructive with no double-auth check
|
||||
- Line 536-539: Panic in ClearAllData silently swallowed
|
||||
|
||||
### admin/handlers/share_code_handler.go (225 lines)
|
||||
- Line 155-162: `IsActive` as non-pointer bool — absent field defaults to false, deactivating codes
|
||||
|
||||
### admin/handlers/subscription_handler.go (237 lines)
|
||||
- **BUG**: Line 40-41: JOIN uses "users" but actual table is "auth_user" — query fails on PostgreSQL
|
||||
|
||||
### admin/handlers/task_handler.go (401 lines)
|
||||
- Line 247-296: Admin Create bypasses service layer — no business logic applied
|
||||
|
||||
### admin/handlers/task_template_handler.go (347 lines)
|
||||
- Lines 29-31: Same nil cache panic as lookup_handler.go
|
||||
30
audit-digest-3.md
Normal file
30
audit-digest-3.md
Normal file
@@ -0,0 +1,30 @@
|
||||
# Digest 3: admin routes, apperrors, config, database, dto/requests, dto/responses (first half)
|
||||
|
||||
### admin/routes.go (483 lines)
|
||||
- Route ordering: users DELETE "/:id" before "/bulk" — "/bulk" matches as id param
|
||||
- Line 454: Uses os.Getenv instead of Viper config
|
||||
- Line 460-462: url.Parse failure returns silently, no logging
|
||||
- Line 467-469: Proxy errors not surfaced (always returns nil)
|
||||
|
||||
### apperrors/errors.go (98 lines) - Clean. Error types with Wrap/Unwrap.
|
||||
### apperrors/handler.go (67 lines) - c.JSON error returns discarded (minor)
|
||||
|
||||
### config/config.go (427 lines)
|
||||
- Line 339: Hardcoded debug secret key "change-me-in-production-secret-key-12345"
|
||||
- Lines 311-312: Comments say wrong UTC times (says 8PM/9AM, actually 2PM/3PM)
|
||||
|
||||
### database/database.go (468 lines)
|
||||
- **SECURITY**: Line 372-382: Hardcoded bcrypt hash for GoAdmin with password "admin" — migration RESETS password every run
|
||||
- **SECURITY**: Line 447-463: Hardcoded admin@mycrib.com / admin123 — password reset on every migration
|
||||
- Line 182+: Multiple db.Exec errors unchecked for index creation
|
||||
- Line 100-102: WithTransaction coupled to global db variable
|
||||
|
||||
### dto/requests/auth.go (66 lines) - LoginRequest min=1 password (intentional for login)
|
||||
### dto/requests/contractor.go (37 lines) - Rating has no min/max validation bounds
|
||||
### dto/requests/document.go (47 lines) - ImageURLs no length limit, Description no max length
|
||||
### dto/requests/residence.go (59 lines) - Bedrooms/SquareFootage accept negative values, ExpiresInHours no validation
|
||||
### dto/requests/task.go (110 lines) - Rating no bounds, ImageURLs no length limit, CustomIntervalDays no min
|
||||
|
||||
### dto/responses/auth.go (190 lines) - Clean. Proper nil checks on Profile.
|
||||
### dto/responses/contractor.go (131 lines) - TaskCount depends on preloaded Tasks association
|
||||
### dto/responses/document.go (126 lines) - Line 101: CreatedBy accessed as value type (fragile if changed to pointer)
|
||||
48
audit-digest-4.md
Normal file
48
audit-digest-4.md
Normal file
@@ -0,0 +1,48 @@
|
||||
# Digest 4: dto/responses (remaining), echohelpers, handlers (first half)
|
||||
|
||||
### dto/responses/residence.go (215 lines) - NewResidenceResponse no nil check on param. Owner zero-value if not preloaded.
|
||||
### dto/responses/task_template.go (135 lines) - No nil check on template param
|
||||
### dto/responses/task.go (399 lines) - No nil checks on params in factory functions
|
||||
### dto/responses/user.go (20 lines) - Clean data types
|
||||
### echohelpers/helpers.go (46 lines) - Clean utilities
|
||||
### echohelpers/pagination.go (33 lines) - Clean, properly bounded
|
||||
|
||||
### handlers/auth_handler.go (379 lines)
|
||||
- **ARCHITECTURE**: Lines 83, 178, 207, 241, 329, 370: SIX goroutine spawns for email — violates "no goroutines in handlers" rule
|
||||
- Line 308-312: AppError constructed directly instead of factory function
|
||||
|
||||
### handlers/contractor_handler.go (154 lines)
|
||||
- Line 28+: Unchecked type assertions throughout (7 instances)
|
||||
- Line 31: Raw err.Error() returned to client
|
||||
- Line 55: CreateContractor missing c.Validate() call
|
||||
|
||||
### handlers/document_handler.go (336 lines)
|
||||
- Line 37+: Unchecked type assertions (10 instances)
|
||||
- Line 92-93: Raw error leaked to client
|
||||
- Line 137: No DocumentType validation — any string accepted
|
||||
- Lines 187, 217: Missing c.Validate() calls
|
||||
|
||||
### handlers/media_handler.go (172 lines)
|
||||
- **SECURITY**: Line 156-171: resolveFilePath uses filepath.Join with user-influenced data — PATH TRAVERSAL vulnerability. TrimPrefix doesn't sanitize ../
|
||||
- Line 19-22: Handler accesses repositories directly, bypasses service layer
|
||||
|
||||
### handlers/notification_handler.go (200 lines)
|
||||
- Line 29-40: No upper bound on limit — unbounded query with limit=999999999
|
||||
- Line 168: Silent default to "ios" platform
|
||||
|
||||
### handlers/residence_handler.go (365 lines)
|
||||
- Line 38+: Unchecked type assertions (14 instances)
|
||||
- Lines 187, 209, 303: Bind errors silently discarded
|
||||
- Line 224: JoinWithCode missing c.Validate()
|
||||
|
||||
### handlers/static_data_handler.go (152 lines) - Uses interface{} instead of concrete types
|
||||
### handlers/subscription_handler.go (176 lines) - Lines 97, 150: Missing c.Validate() calls
|
||||
- Line 159-163: RestoreSubscription doesn't validate receipt/transaction ID presence
|
||||
|
||||
### handlers/subscription_webhook_handler.go (821 lines)
|
||||
- **SECURITY**: Line 190-192: Apple JWS payload decoded WITHOUT signature verification
|
||||
- **SECURITY**: Line 787-793: VerifyGooglePubSubToken ALWAYS returns true — webhook unauthenticated
|
||||
- Line 639-643: Subscription duration guessed by string matching product ID
|
||||
- Line 657, 694: Hardcoded 1-month extension regardless of actual plan
|
||||
- Line 759, 772: Unchecked type assertions in VerifyAppleSignature
|
||||
- Line 162: Apple renewal info error silently discarded
|
||||
45
audit-digest-5.md
Normal file
45
audit-digest-5.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# Digest 5: handlers (remaining), i18n, middleware, models (first half)
|
||||
|
||||
### handlers/task_handler.go (440 lines)
|
||||
- Line 35+: Unchecked type assertions (18 locations)
|
||||
- Line 42: Fire-and-forget goroutine for UpdateUserTimezone — no error handling, no context
|
||||
- Lines 112-115, 134-137: Missing c.Validate() calls
|
||||
- Line 317: 32MB multipart limit with no per-file size check
|
||||
|
||||
### handlers/task_template_handler.go (98 lines)
|
||||
- Line 59: No max length on search query — slow LIKE queries possible
|
||||
|
||||
### handlers/tracking_handler.go (46 lines)
|
||||
- Line 25: Package-level base64 decode error discarded
|
||||
- Lines 34-36: Fire-and-forget goroutine — violates no-goroutines rule
|
||||
|
||||
### handlers/upload_handler.go (93 lines)
|
||||
- Line 31: User-controlled `category` param passed to storage — potential path traversal
|
||||
- Line 80: `binding` tag instead of `validate`
|
||||
- No file type or size validation at handler level
|
||||
|
||||
### handlers/user_handler.go (76 lines) - Unchecked type assertions
|
||||
|
||||
### i18n/i18n.go (87 lines)
|
||||
- Line 16: Global Bundle is nil until Init() — NewLocalizer dereferences without nil check
|
||||
- Line 37: MustParseMessageFileBytes panics on malformed translation files
|
||||
- Line 83: MustT panics on missing translations
|
||||
|
||||
### i18n/middleware.go (127 lines) - Clean
|
||||
|
||||
### middleware/admin_auth.go (133 lines)
|
||||
- **SECURITY**: Line 50: Admin JWT accepted via query param — tokens leak into server/proxy logs
|
||||
- Line 124: Unchecked type assertion
|
||||
|
||||
### middleware/auth.go (229 lines)
|
||||
- **BUG**: Line 66: `token[:8]` panics if token is fewer than 8 characters
|
||||
- Line 104: cacheUserID error silently discarded
|
||||
- Line 209: Unchecked type assertion
|
||||
|
||||
### middleware/logger.go (54 lines) - Clean
|
||||
### middleware/request_id.go (44 lines) - Line 21: Client-supplied X-Request-ID accepted without validation (log injection)
|
||||
### middleware/timezone.go (101 lines) - Lines 88, 99: Unchecked type assertions
|
||||
|
||||
### models/admin.go (64 lines) - Line 38: No max password length check; bcrypt truncates at 72 bytes
|
||||
### models/base.go (39 lines) - Clean GORM hooks
|
||||
### models/contractor.go (54 lines) - *float64 mapped to decimal(2,1) — minor precision mismatch
|
||||
40
audit-digest-6.md
Normal file
40
audit-digest-6.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Digest 6: models (remaining), monitoring
|
||||
|
||||
### models/document.go (100 lines) - Clean
|
||||
### models/notification.go (141 lines) - Clean
|
||||
### models/onboarding_email.go (35 lines) - Clean
|
||||
### models/reminder_log.go (92 lines) - Clean
|
||||
|
||||
### models/residence.go (106 lines)
|
||||
- Lines 65-70: GetAllUsers/HasAccess assumes Owner and Users are preloaded — returns wrong results if not
|
||||
|
||||
### models/subscription.go (169 lines)
|
||||
- Lines 57-65: IsActive()/IsPro() don't account for IsFree admin override field — misleading method names
|
||||
|
||||
### models/task_template.go (23 lines) - Clean
|
||||
### models/task.go (317 lines)
|
||||
- Lines 189-264: GetKanbanColumnWithTimezone duplicates categorization chain logic — maintenance drift risk
|
||||
- Lines 158-182: IsDueSoon uses time.Now() internally — non-deterministic, harder to test
|
||||
|
||||
### models/user.go (268 lines)
|
||||
- Line 101: crypto/rand.Read error unchecked (safe in practice since Go 1.20)
|
||||
- Line 164-172: GenerateConfirmationCode has slight distribution bias (negligible)
|
||||
|
||||
### monitoring/buffer.go (166 lines) - Line 75: Corrupted Redis data silently dropped
|
||||
### monitoring/collector.go (201 lines)
|
||||
- Line 82: cpu.Percent blocks 1 second per collection
|
||||
- Line 96-110: ReadMemStats called TWICE per cycle (also in collectRuntime)
|
||||
|
||||
### monitoring/handler.go (195 lines)
|
||||
- **SECURITY**: Line 19-22: WebSocket CheckOrigin always returns true
|
||||
- **BUG**: Line 117-119: After upgrader.Upgrade fails, execution continues to conn.Close() — nil pointer panic
|
||||
- **BUG**: Line 177: Missing return after ctx.Done() — goroutine spins
|
||||
- Lines 183-184: GetAllStats error silently ignored
|
||||
- Line 192: WriteJSON error unchecked
|
||||
|
||||
### monitoring/middleware.go (220 lines) - Clean
|
||||
### monitoring/models.go (129 lines) - Clean
|
||||
|
||||
### monitoring/service.go (196 lines)
|
||||
- Line 121: Hardcoded primary key 1 for singleton settings row
|
||||
- Line 191-194: MetricsMiddleware returns nil when httpCollector is nil — caller must nil-check
|
||||
57
audit-digest-7.md
Normal file
57
audit-digest-7.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# Digest 7: monitoring/writer, notifications, push, repositories
|
||||
|
||||
### monitoring/writer.go (96 lines)
|
||||
- Line 90-92: Unbounded fire-and-forget goroutines for Redis push — no rate limiting
|
||||
|
||||
### notifications/reminder_config.go (64 lines) - Clean config data
|
||||
### notifications/reminder_schedule.go (199 lines)
|
||||
- Line 112: Integer truncation of float division — DST off-by-one possible
|
||||
- Line 148-161: Custom itoa reimplements strconv.Itoa
|
||||
|
||||
### push/apns.go (209 lines)
|
||||
- Line 44: Double-negative logic — both Production=false and Sandbox=false defaults to production
|
||||
|
||||
### push/client.go (158 lines)
|
||||
- Line 89-105: SendToAll last-error-wins — cannot tell which platform failed
|
||||
- Line 150-157: HealthCheck always returns nil — useless health check
|
||||
|
||||
### push/fcm.go (140 lines)
|
||||
- Line 16: Legacy FCM HTTP API (deprecated by Google)
|
||||
- Line 119-126: If FCM returns fewer results than tokens, index out of bounds panic
|
||||
|
||||
### repositories/admin_repo.go (108 lines)
|
||||
- Line 92: Negative page produces negative offset
|
||||
|
||||
### repositories/contractor_repo.go (166 lines)
|
||||
- **RACE**: Line 89-101: ToggleFavorite read-then-write without transaction
|
||||
- Line 91: ToggleFavorite doesn't filter is_active — can toggle deleted contractors
|
||||
|
||||
### repositories/document_repo.go (201 lines)
|
||||
- Line 92: LIKE wildcards in user input not escaped
|
||||
- Line 12: DocumentFilter.ResidenceID field defined but never used
|
||||
|
||||
### repositories/notification_repo.go (267 lines)
|
||||
- **RACE**: Line 137-161: GetOrCreatePreferences race — concurrent calls both create, duplicate key error
|
||||
- Line 143: Uses == instead of errors.Is for ErrRecordNotFound
|
||||
|
||||
### repositories/reminder_repo.go (126 lines)
|
||||
- Line 115-122: rows.Err() not checked after iteration loop
|
||||
|
||||
### repositories/residence_repo.go (344 lines)
|
||||
- Line 272: DeactivateShareCode error silently ignored
|
||||
- Line 298-301: Count error unchecked in generateUniqueCode — potential duplicate codes
|
||||
- Line 125-128: PostgreSQL-specific ON CONFLICT — fails on SQLite in tests
|
||||
|
||||
### repositories/subscription_repo.go (257 lines)
|
||||
- **RACE**: Line 40: GetOrCreate race condition (same as notification_repo)
|
||||
- Line 66: GORM v1 pattern `gorm:query_option` for FOR UPDATE — may not work in GORM v2
|
||||
- Line 129: LIKE search on receipt data blobs — inefficient, no index
|
||||
- Lines 40, 168, 196: Uses == instead of errors.Is
|
||||
|
||||
### repositories/task_repo.go (765 lines)
|
||||
- Line 707-709: DeleteCompletion ignores image deletion error
|
||||
- Line 62-101: applyFilterOptions applies NO scope when no filter set — could query all tasks
|
||||
|
||||
### repositories/task_template_repo.go (124 lines)
|
||||
- Line 48: LIKE wildcard escape issue
|
||||
- Line 79-81: Save without Omit could corrupt Category/Frequency lookup data
|
||||
35
audit-digest-8.md
Normal file
35
audit-digest-8.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Digest 8: repositories (remaining), router, services (first half)
|
||||
|
||||
### repositories/user_repo.go - Standard GORM CRUD
|
||||
### repositories/webhook_event_repo.go - Webhook event storage
|
||||
|
||||
### router/router.go - Route registration wiring
|
||||
|
||||
### services/apple_auth.go - Apple Sign In JWT validation
|
||||
### services/auth_service.go - Token management, password hashing, email verification
|
||||
|
||||
### services/cache_service.go - Redis caching for lookups
|
||||
### services/contractor_service.go - Contractor CRUD via repository
|
||||
|
||||
### services/document_service.go - Document management
|
||||
### services/email_service.go - SMTP email sending
|
||||
|
||||
### services/google_auth.go - Google OAuth token validation
|
||||
### services/iap_validation.go - Apple/Google receipt validation
|
||||
|
||||
### services/notification_service.go - Push notifications, preferences
|
||||
|
||||
### services/onboarding_email_service.go (371 lines)
|
||||
- **ARCHITECTURE**: Direct *gorm.DB access — bypasses repository layer entirely
|
||||
- Line 43-46: HasSentEmail ignores Count error — could send duplicate emails
|
||||
- Line 128-133: GetEmailStats ignores 4 Count errors
|
||||
- Line 170: Raw SQL references "auth_user" table
|
||||
- Line 354: Delete error silently ignored
|
||||
|
||||
### services/pdf_service.go (179 lines)
|
||||
- **BUG**: Line 131-133: Byte-level truncation of title — breaks multi-byte UTF-8 (CJK, emoji)
|
||||
|
||||
### services/residence_service.go (648 lines)
|
||||
- Line 155: TODO comment — subscription tier limit check commented out (free tier bypass)
|
||||
- Line 447-450: Empty if block — DeactivateShareCode error completely ignored
|
||||
- Line 625: Status only set for in-progress tasks; all others have empty string
|
||||
49
audit-digest-9.md
Normal file
49
audit-digest-9.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# Digest 9: services (remaining), task package, testutil, validator, worker, pkg
|
||||
|
||||
### services/storage_service.go (184 lines)
|
||||
- Line 75: UUID truncated to 8 chars — increased collision risk
|
||||
- **SECURITY**: Line 137-138: filepath.Abs errors ignored — path traversal check could be bypassed
|
||||
|
||||
### services/subscription_service.go (659 lines)
|
||||
- **PERFORMANCE**: Line 186-204: N+1 queries in getUserUsage — 3 queries per residence
|
||||
- **SECURITY/BUSINESS**: Line 371: Apple validation failure grants 1-month free Pro
|
||||
- **SECURITY/BUSINESS**: Line 381: Apple validation not configured grants 1-year free Pro
|
||||
- **SECURITY/BUSINESS**: Line 429, 449: Same for Google — errors/misconfiguration grant free Pro
|
||||
- Line 7: Uses stdlib "log" instead of zerolog
|
||||
|
||||
### services/task_button_types.go (85 lines) - Clean, uses predicates correctly
|
||||
### services/task_service.go (1092 lines)
|
||||
- **DATA INTEGRITY**: Line 601: If task update fails after completion creation, error only logged not returned — stale NextDueDate/InProgress
|
||||
- Line 735: Goroutine in QuickComplete (service method) — inconsistent with synchronous CreateCompletion
|
||||
- Line 773: Unbounded goroutine creation per user for notifications
|
||||
- Line 790: Fail-open email notification on error (intentional but risky)
|
||||
- **SECURITY**: Line 857-862: resolveImageFilePath has NO path traversal validation
|
||||
|
||||
### services/task_template_service.go (70 lines) - Errors returned raw (not wrapped with apperrors)
|
||||
### services/user_service.go (88 lines) - Returns nil instead of empty slice (JSON null vs [])
|
||||
|
||||
### task/categorization/chain.go (359 lines) - Clean chain-of-responsibility
|
||||
### task/predicates/predicates.go (217 lines)
|
||||
- Line 210: IsRecurring requires Frequency preloaded — returns false without it
|
||||
|
||||
### task/scopes/scopes.go (270 lines)
|
||||
- Line 118: ScopeOverdue doesn't exclude InProgress — differs from categorization chain
|
||||
|
||||
### task/task.go (261 lines) - Clean facade re-exports
|
||||
|
||||
### testutil/testutil.go (359 lines)
|
||||
- Line 86: json.Marshal error ignored in MakeRequest
|
||||
- Line 92: http.NewRequest error ignored
|
||||
|
||||
### validator/validator.go (103 lines) - Clean
|
||||
|
||||
### worker/jobs/email_jobs.go (118 lines) - Clean
|
||||
### worker/jobs/handler.go (810 lines)
|
||||
- Lines 95-106, 193-204: Direct DB access bypasses repository layer
|
||||
- Line 627-635: Raw SQL with fmt.Sprintf (not currently user-supplied but fragile)
|
||||
- Line 154, 251: O(N*M) lookup instead of map
|
||||
|
||||
### worker/scheduler.go (240 lines)
|
||||
- Line 200-212: Cron schedules at fixed UTC times may conflict with smart reminder system — potential duplicate notifications
|
||||
|
||||
### pkg/utils/logger.go (132 lines) - Panic recovery bypasses apperrors.HTTPErrorHandler
|
||||
18
deploy/.gitignore
vendored
Normal file
18
deploy/.gitignore
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
# Local deploy inputs (copy from *.example files)
|
||||
cluster.env
|
||||
registry.env
|
||||
prod.env
|
||||
|
||||
# Local secret material
|
||||
secrets/*.txt
|
||||
secrets/*.p8
|
||||
|
||||
# Keep templates and docs tracked
|
||||
!*.example
|
||||
!README.md
|
||||
!shit_deploy_cant_do.md
|
||||
!swarm-stack.prod.yml
|
||||
!scripts/
|
||||
!scripts/**
|
||||
!secrets/*.example
|
||||
!secrets/README.md
|
||||
134
deploy/README.md
Normal file
134
deploy/README.md
Normal file
@@ -0,0 +1,134 @@
|
||||
# Deploy Folder
|
||||
|
||||
This folder is the full production deploy toolkit for `myCribAPI-go`.
|
||||
|
||||
Run deploy with:
|
||||
|
||||
```bash
|
||||
./.deploy_prod
|
||||
```
|
||||
|
||||
The script will refuse to run until all required values are set.
|
||||
|
||||
## First-Time Prerequisite: Create The Swarm Cluster
|
||||
|
||||
You must do this once before `./.deploy_prod` can work.
|
||||
|
||||
1. SSH to manager #1 and initialize Swarm:
|
||||
|
||||
```bash
|
||||
docker swarm init --advertise-addr <manager1-private-ip>
|
||||
```
|
||||
|
||||
2. On manager #1, get join commands:
|
||||
|
||||
```bash
|
||||
docker swarm join-token manager
|
||||
docker swarm join-token worker
|
||||
```
|
||||
|
||||
3. SSH to each additional node and run the appropriate `docker swarm join ...` command.
|
||||
|
||||
4. Verify from manager #1:
|
||||
|
||||
```bash
|
||||
docker node ls
|
||||
```
|
||||
|
||||
## Security Requirements Before Public Launch
|
||||
|
||||
Use this as a mandatory checklist before you route production traffic.
|
||||
|
||||
### 1) Firewall Rules (Node-Level)
|
||||
|
||||
Apply firewall rules to all Swarm nodes:
|
||||
|
||||
- SSH port (for example `2222/tcp`): your IP only
|
||||
- `80/tcp`, `443/tcp`: Hetzner LB only (or Cloudflare IP ranges only if no LB)
|
||||
- `2377/tcp`: Swarm nodes only
|
||||
- `7946/tcp,udp`: Swarm nodes only
|
||||
- `4789/udp`: Swarm nodes only
|
||||
- Everything else: blocked
|
||||
|
||||
### 2) SSH Hardening
|
||||
|
||||
On each node, harden `/etc/ssh/sshd_config`:
|
||||
|
||||
```text
|
||||
Port 2222
|
||||
PermitRootLogin no
|
||||
PasswordAuthentication no
|
||||
PubkeyAuthentication yes
|
||||
AllowUsers deploy
|
||||
```
|
||||
|
||||
### 3) Cloudflare Origin Lockdown
|
||||
|
||||
- Keep public DNS records proxied (orange cloud on).
|
||||
- Point Cloudflare to LB, not node IPs.
|
||||
- Do not publish Swarm node IPs in DNS.
|
||||
- Enforce firewall source restrictions so public traffic cannot bypass Cloudflare/LB.
|
||||
|
||||
### 4) Secrets Policy
|
||||
|
||||
- Keep runtime secrets in Docker Swarm secrets only.
|
||||
- Do not put production secrets in git or plain `.env` files.
|
||||
- `./.deploy_prod` already creates versioned Swarm secrets from files in `deploy/secrets/`.
|
||||
- Rotate secrets after incidents or credential exposure.
|
||||
|
||||
### 5) Data Path Security
|
||||
|
||||
- Neon/Postgres: `DB_SSLMODE=require`, strong DB password, Neon IP allowlist limited to node IPs.
|
||||
- Backblaze B2: HTTPS only, scoped app keys (not master key), least-privilege bucket access.
|
||||
- Swarm overlay: encrypted network enabled in stack (`driver_opts.encrypted: "true"`).
|
||||
|
||||
### 6) Dozzle Hardening
|
||||
|
||||
- Keep Dozzle private (no public DNS/ingress).
|
||||
- Put auth/SSO in front (Cloudflare Access or equivalent).
|
||||
- Prefer a Docker socket proxy with restricted read-only scope.
|
||||
|
||||
### 7) Backup + Restore Readiness
|
||||
|
||||
- Postgres PITR path tested in staging.
|
||||
- Redis persistence enabled and restore path tested.
|
||||
- Written runbook for restore and secret rotation.
|
||||
- Named owner for incident response.
|
||||
|
||||
## Files You Fill In
|
||||
|
||||
Paste your values into these files:
|
||||
|
||||
- `deploy/cluster.env`
|
||||
- `deploy/registry.env`
|
||||
- `deploy/prod.env`
|
||||
- `deploy/secrets/postgres_password.txt`
|
||||
- `deploy/secrets/secret_key.txt`
|
||||
- `deploy/secrets/email_host_password.txt`
|
||||
- `deploy/secrets/fcm_server_key.txt`
|
||||
- `deploy/secrets/apns_auth_key.p8`
|
||||
|
||||
If one is missing, the deploy script auto-copies it from its `.example` template and exits so you can fill it.
|
||||
|
||||
## What `./.deploy_prod` Does
|
||||
|
||||
1. Validates all required config files and credentials.
|
||||
2. Builds and pushes `api`, `worker`, and `admin` images.
|
||||
3. Uploads deploy bundle to your Swarm manager over SSH.
|
||||
4. Creates versioned Docker secrets on the manager.
|
||||
5. Deploys the stack with `docker stack deploy --with-registry-auth`.
|
||||
6. Waits until service replicas converge.
|
||||
7. Runs an HTTP health check (if `DEPLOY_HEALTHCHECK_URL` is set).
|
||||
|
||||
## Useful Flags
|
||||
|
||||
Environment flags:
|
||||
|
||||
- `SKIP_BUILD=1 ./.deploy_prod` to deploy already-pushed images.
|
||||
- `SKIP_HEALTHCHECK=1 ./.deploy_prod` to skip final URL check.
|
||||
- `DEPLOY_TAG=<tag> ./.deploy_prod` to deploy a specific image tag.
|
||||
|
||||
## Important
|
||||
|
||||
- `deploy/shit_deploy_cant_do.md` lists the manual tasks this script cannot automate.
|
||||
- Keep real credentials and secret files out of git.
|
||||
22
deploy/cluster.env.example
Normal file
22
deploy/cluster.env.example
Normal file
@@ -0,0 +1,22 @@
|
||||
# Swarm manager connection
|
||||
DEPLOY_MANAGER_HOST=CHANGEME_MANAGER_IP_OR_HOSTNAME
|
||||
DEPLOY_MANAGER_USER=deploy
|
||||
DEPLOY_MANAGER_SSH_PORT=22
|
||||
DEPLOY_SSH_KEY_PATH=~/.ssh/id_ed25519
|
||||
|
||||
# Stack settings
|
||||
DEPLOY_STACK_NAME=casera
|
||||
DEPLOY_REMOTE_DIR=/opt/casera/deploy
|
||||
DEPLOY_WAIT_SECONDS=420
|
||||
DEPLOY_HEALTHCHECK_URL=https://api.casera.app/api/health/
|
||||
|
||||
# Replicas and published ports
|
||||
API_REPLICAS=3
|
||||
WORKER_REPLICAS=2
|
||||
ADMIN_REPLICAS=1
|
||||
API_PORT=8000
|
||||
ADMIN_PORT=3000
|
||||
DOZZLE_PORT=9999
|
||||
|
||||
# Build behavior
|
||||
PUSH_LATEST_TAG=true
|
||||
73
deploy/prod.env.example
Normal file
73
deploy/prod.env.example
Normal file
@@ -0,0 +1,73 @@
|
||||
# API service settings
|
||||
DEBUG=false
|
||||
ALLOWED_HOSTS=api.casera.app,casera.app
|
||||
CORS_ALLOWED_ORIGINS=https://casera.app,https://admin.casera.app
|
||||
TIMEZONE=UTC
|
||||
BASE_URL=https://casera.app
|
||||
PORT=8000
|
||||
|
||||
# Admin service settings
|
||||
NEXT_PUBLIC_API_URL=https://api.casera.app
|
||||
ADMIN_PANEL_URL=https://admin.casera.app
|
||||
|
||||
# Database (Neon recommended)
|
||||
DB_HOST=CHANGEME_NEON_HOST
|
||||
DB_PORT=5432
|
||||
POSTGRES_USER=CHANGEME_DB_USER
|
||||
POSTGRES_DB=casera
|
||||
DB_SSLMODE=require
|
||||
DB_MAX_OPEN_CONNS=25
|
||||
DB_MAX_IDLE_CONNS=10
|
||||
DB_MAX_LIFETIME=600s
|
||||
|
||||
# Redis (in stack defaults to redis://redis:6379/0)
|
||||
REDIS_URL=redis://redis:6379/0
|
||||
REDIS_DB=0
|
||||
|
||||
# Email (password goes in deploy/secrets/email_host_password.txt)
|
||||
EMAIL_HOST=smtp.gmail.com
|
||||
EMAIL_PORT=587
|
||||
EMAIL_USE_TLS=true
|
||||
EMAIL_HOST_USER=CHANGEME_EMAIL_USER
|
||||
DEFAULT_FROM_EMAIL=Casera <noreply@casera.app>
|
||||
|
||||
# Push notifications
|
||||
# APNS private key goes in deploy/secrets/apns_auth_key.p8
|
||||
APNS_AUTH_KEY_ID=CHANGEME_APNS_KEY_ID
|
||||
APNS_TEAM_ID=CHANGEME_APNS_TEAM_ID
|
||||
APNS_TOPIC=com.tt.casera
|
||||
APNS_USE_SANDBOX=false
|
||||
APNS_PRODUCTION=true
|
||||
|
||||
# Worker schedules (UTC)
|
||||
TASK_REMINDER_HOUR=14
|
||||
OVERDUE_REMINDER_HOUR=15
|
||||
DAILY_DIGEST_HOUR=3
|
||||
|
||||
# Storage
|
||||
STORAGE_UPLOAD_DIR=/app/uploads
|
||||
STORAGE_BASE_URL=/uploads
|
||||
STORAGE_MAX_FILE_SIZE=10485760
|
||||
STORAGE_ALLOWED_TYPES=image/jpeg,image/png,image/gif,image/webp,application/pdf
|
||||
|
||||
# Feature flags
|
||||
FEATURE_PUSH_ENABLED=true
|
||||
FEATURE_EMAIL_ENABLED=true
|
||||
FEATURE_WEBHOOKS_ENABLED=true
|
||||
FEATURE_ONBOARDING_EMAILS_ENABLED=true
|
||||
FEATURE_PDF_REPORTS_ENABLED=true
|
||||
FEATURE_WORKER_ENABLED=true
|
||||
|
||||
# Optional auth/iap values
|
||||
APPLE_CLIENT_ID=
|
||||
APPLE_TEAM_ID=
|
||||
GOOGLE_CLIENT_ID=
|
||||
GOOGLE_ANDROID_CLIENT_ID=
|
||||
GOOGLE_IOS_CLIENT_ID=
|
||||
APPLE_IAP_KEY_ID=
|
||||
APPLE_IAP_ISSUER_ID=
|
||||
APPLE_IAP_BUNDLE_ID=
|
||||
APPLE_IAP_KEY_PATH=
|
||||
APPLE_IAP_SANDBOX=false
|
||||
GOOGLE_IAP_PACKAGE_NAME=
|
||||
GOOGLE_IAP_SERVICE_ACCOUNT_PATH=
|
||||
11
deploy/registry.env.example
Normal file
11
deploy/registry.env.example
Normal file
@@ -0,0 +1,11 @@
|
||||
# Container registry used for deploy images.
|
||||
# For GHCR:
|
||||
# REGISTRY=ghcr.io
|
||||
# REGISTRY_NAMESPACE=<github-username-or-org>
|
||||
# REGISTRY_USERNAME=<github-username>
|
||||
# REGISTRY_TOKEN=<github-pat-with-read:packages,write:packages>
|
||||
|
||||
REGISTRY=ghcr.io
|
||||
REGISTRY_NAMESPACE=CHANGEME_NAMESPACE
|
||||
REGISTRY_USERNAME=CHANGEME_USERNAME
|
||||
REGISTRY_TOKEN=CHANGEME_TOKEN
|
||||
397
deploy/scripts/deploy_prod.sh
Executable file
397
deploy/scripts/deploy_prod.sh
Executable file
@@ -0,0 +1,397 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
DEPLOY_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||
REPO_DIR="$(cd "${DEPLOY_DIR}/.." && pwd)"
|
||||
|
||||
STACK_TEMPLATE="${DEPLOY_DIR}/swarm-stack.prod.yml"
|
||||
CLUSTER_ENV="${DEPLOY_DIR}/cluster.env"
|
||||
REGISTRY_ENV="${DEPLOY_DIR}/registry.env"
|
||||
PROD_ENV="${DEPLOY_DIR}/prod.env"
|
||||
|
||||
SECRET_POSTGRES="${DEPLOY_DIR}/secrets/postgres_password.txt"
|
||||
SECRET_APP_KEY="${DEPLOY_DIR}/secrets/secret_key.txt"
|
||||
SECRET_EMAIL_PASS="${DEPLOY_DIR}/secrets/email_host_password.txt"
|
||||
SECRET_FCM_KEY="${DEPLOY_DIR}/secrets/fcm_server_key.txt"
|
||||
SECRET_APNS_KEY="${DEPLOY_DIR}/secrets/apns_auth_key.p8"
|
||||
|
||||
SKIP_BUILD="${SKIP_BUILD:-0}"
|
||||
SKIP_HEALTHCHECK="${SKIP_HEALTHCHECK:-0}"
|
||||
|
||||
log() {
|
||||
printf '[deploy] %s\n' "$*"
|
||||
}
|
||||
|
||||
warn() {
|
||||
printf '[deploy][warn] %s\n' "$*" >&2
|
||||
}
|
||||
|
||||
die() {
|
||||
printf '[deploy][error] %s\n' "$*" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
require_cmd() {
|
||||
command -v "$1" >/dev/null 2>&1 || die "Missing required command: $1"
|
||||
}
|
||||
|
||||
contains_placeholder() {
|
||||
local value="$1"
|
||||
[[ -z "${value}" ]] && return 0
|
||||
local lowered
|
||||
lowered="$(printf '%s' "${value}" | tr '[:upper:]' '[:lower:]')"
|
||||
case "${lowered}" in
|
||||
*changeme*|*replace_me*|*example.com*|*your-*|*todo*|*fill_me*|*paste_here*)
|
||||
return 0
|
||||
;;
|
||||
*)
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
ensure_file_from_example() {
|
||||
local path="$1"
|
||||
local example="${path}.example"
|
||||
if [[ -f "${path}" ]]; then
|
||||
return
|
||||
fi
|
||||
if [[ -f "${example}" ]]; then
|
||||
cp "${example}" "${path}"
|
||||
die "Created ${path} from template. Fill it in and rerun."
|
||||
fi
|
||||
die "Missing required file: ${path}"
|
||||
}
|
||||
|
||||
require_var() {
|
||||
local name="$1"
|
||||
local value="${!name:-}"
|
||||
[[ -n "${value}" ]] || die "Missing required value: ${name}"
|
||||
if contains_placeholder "${value}"; then
|
||||
die "Value still uses placeholder text: ${name}=${value}"
|
||||
fi
|
||||
}
|
||||
|
||||
require_secret_file() {
|
||||
local path="$1"
|
||||
local label="$2"
|
||||
ensure_file_from_example "${path}"
|
||||
local contents
|
||||
contents="$(tr -d '\r' < "${path}" | sed 's/[[:space:]]*$//')"
|
||||
[[ -n "${contents}" ]] || die "Secret file is empty: ${path}"
|
||||
if contains_placeholder "${contents}"; then
|
||||
die "Secret file still has placeholder text (${label}): ${path}"
|
||||
fi
|
||||
}
|
||||
|
||||
print_usage() {
|
||||
cat <<'EOF'
|
||||
Usage:
|
||||
./.deploy_prod
|
||||
|
||||
Optional environment flags:
|
||||
SKIP_BUILD=1 Deploy existing image tags without rebuilding/pushing.
|
||||
SKIP_HEALTHCHECK=1 Skip final HTTP health check.
|
||||
DEPLOY_TAG=<tag> Override image tag (default: git short sha).
|
||||
EOF
|
||||
}
|
||||
|
||||
while (($# > 0)); do
|
||||
case "$1" in
|
||||
-h|--help)
|
||||
print_usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
die "Unknown argument: $1"
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
require_cmd docker
|
||||
require_cmd ssh
|
||||
require_cmd scp
|
||||
require_cmd git
|
||||
require_cmd awk
|
||||
require_cmd sed
|
||||
require_cmd grep
|
||||
require_cmd mktemp
|
||||
require_cmd date
|
||||
require_cmd curl
|
||||
|
||||
ensure_file_from_example "${CLUSTER_ENV}"
|
||||
ensure_file_from_example "${REGISTRY_ENV}"
|
||||
ensure_file_from_example "${PROD_ENV}"
|
||||
|
||||
require_secret_file "${SECRET_POSTGRES}" "Postgres password"
|
||||
require_secret_file "${SECRET_APP_KEY}" "SECRET_KEY"
|
||||
require_secret_file "${SECRET_EMAIL_PASS}" "SMTP password"
|
||||
require_secret_file "${SECRET_FCM_KEY}" "FCM server key"
|
||||
require_secret_file "${SECRET_APNS_KEY}" "APNS private key"
|
||||
|
||||
set -a
|
||||
# shellcheck disable=SC1090
|
||||
source "${CLUSTER_ENV}"
|
||||
# shellcheck disable=SC1090
|
||||
source "${REGISTRY_ENV}"
|
||||
# shellcheck disable=SC1090
|
||||
source "${PROD_ENV}"
|
||||
set +a
|
||||
|
||||
DEPLOY_MANAGER_SSH_PORT="${DEPLOY_MANAGER_SSH_PORT:-22}"
|
||||
DEPLOY_STACK_NAME="${DEPLOY_STACK_NAME:-casera}"
|
||||
DEPLOY_REMOTE_DIR="${DEPLOY_REMOTE_DIR:-/opt/casera/deploy}"
|
||||
DEPLOY_WAIT_SECONDS="${DEPLOY_WAIT_SECONDS:-420}"
|
||||
DEPLOY_TAG="${DEPLOY_TAG:-$(git -C "${REPO_DIR}" rev-parse --short HEAD)}"
|
||||
PUSH_LATEST_TAG="${PUSH_LATEST_TAG:-true}"
|
||||
|
||||
require_var DEPLOY_MANAGER_HOST
|
||||
require_var DEPLOY_MANAGER_USER
|
||||
require_var DEPLOY_STACK_NAME
|
||||
require_var DEPLOY_REMOTE_DIR
|
||||
require_var REGISTRY
|
||||
require_var REGISTRY_NAMESPACE
|
||||
require_var REGISTRY_USERNAME
|
||||
require_var REGISTRY_TOKEN
|
||||
|
||||
require_var ALLOWED_HOSTS
|
||||
require_var CORS_ALLOWED_ORIGINS
|
||||
require_var BASE_URL
|
||||
require_var NEXT_PUBLIC_API_URL
|
||||
require_var DB_HOST
|
||||
require_var DB_PORT
|
||||
require_var POSTGRES_USER
|
||||
require_var POSTGRES_DB
|
||||
require_var DB_SSLMODE
|
||||
require_var REDIS_URL
|
||||
require_var EMAIL_HOST
|
||||
require_var EMAIL_PORT
|
||||
require_var EMAIL_HOST_USER
|
||||
require_var DEFAULT_FROM_EMAIL
|
||||
require_var APNS_AUTH_KEY_ID
|
||||
require_var APNS_TEAM_ID
|
||||
require_var APNS_TOPIC
|
||||
|
||||
if [[ ! "$(tr -d '\r\n' < "${SECRET_APNS_KEY}")" =~ BEGIN[[:space:]]+PRIVATE[[:space:]]+KEY ]]; then
|
||||
die "APNS key file does not look like a private key: ${SECRET_APNS_KEY}"
|
||||
fi
|
||||
|
||||
app_secret_len="$(tr -d '\r\n' < "${SECRET_APP_KEY}" | wc -c | tr -d ' ')"
|
||||
if (( app_secret_len < 32 )); then
|
||||
die "deploy/secrets/secret_key.txt must be at least 32 characters."
|
||||
fi
|
||||
|
||||
REGISTRY_PREFIX="${REGISTRY%/}/${REGISTRY_NAMESPACE#/}"
|
||||
API_IMAGE="${REGISTRY_PREFIX}/casera-api:${DEPLOY_TAG}"
|
||||
WORKER_IMAGE="${REGISTRY_PREFIX}/casera-worker:${DEPLOY_TAG}"
|
||||
ADMIN_IMAGE="${REGISTRY_PREFIX}/casera-admin:${DEPLOY_TAG}"
|
||||
|
||||
SSH_KEY_PATH="${DEPLOY_SSH_KEY_PATH:-}"
|
||||
if [[ -n "${SSH_KEY_PATH}" ]]; then
|
||||
SSH_KEY_PATH="${SSH_KEY_PATH/#\~/${HOME}}"
|
||||
fi
|
||||
|
||||
SSH_TARGET="${DEPLOY_MANAGER_USER}@${DEPLOY_MANAGER_HOST}"
|
||||
SSH_OPTS=(-p "${DEPLOY_MANAGER_SSH_PORT}")
|
||||
SCP_OPTS=(-P "${DEPLOY_MANAGER_SSH_PORT}")
|
||||
if [[ -n "${SSH_KEY_PATH}" ]]; then
|
||||
SSH_OPTS+=(-i "${SSH_KEY_PATH}")
|
||||
SCP_OPTS+=(-i "${SSH_KEY_PATH}")
|
||||
fi
|
||||
|
||||
log "Validating SSH access to ${SSH_TARGET}"
|
||||
if ! ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "echo ok" >/dev/null 2>&1; then
|
||||
die "SSH connection failed to ${SSH_TARGET}"
|
||||
fi
|
||||
|
||||
remote_swarm_state="$(ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "docker info --format '{{.Swarm.LocalNodeState}} {{.Swarm.ControlAvailable}}'" 2>/dev/null || true)"
|
||||
if [[ -z "${remote_swarm_state}" ]]; then
|
||||
die "Could not read Docker Swarm state on manager. Is Docker installed/running?"
|
||||
fi
|
||||
|
||||
if [[ "${remote_swarm_state}" != "active true" ]]; then
|
||||
die "Remote node must be a Swarm manager. Got: ${remote_swarm_state}"
|
||||
fi
|
||||
|
||||
if [[ "${SKIP_BUILD}" != "1" ]]; then
|
||||
log "Logging in to ${REGISTRY}"
|
||||
printf '%s' "${REGISTRY_TOKEN}" | docker login "${REGISTRY}" -u "${REGISTRY_USERNAME}" --password-stdin >/dev/null
|
||||
|
||||
log "Building API image ${API_IMAGE}"
|
||||
docker build --target api -t "${API_IMAGE}" "${REPO_DIR}"
|
||||
log "Building Worker image ${WORKER_IMAGE}"
|
||||
docker build --target worker -t "${WORKER_IMAGE}" "${REPO_DIR}"
|
||||
log "Building Admin image ${ADMIN_IMAGE}"
|
||||
docker build --target admin -t "${ADMIN_IMAGE}" "${REPO_DIR}"
|
||||
|
||||
log "Pushing deploy images"
|
||||
docker push "${API_IMAGE}"
|
||||
docker push "${WORKER_IMAGE}"
|
||||
docker push "${ADMIN_IMAGE}"
|
||||
|
||||
if [[ "${PUSH_LATEST_TAG}" == "true" ]]; then
|
||||
log "Updating :latest tags"
|
||||
docker tag "${API_IMAGE}" "${REGISTRY_PREFIX}/casera-api:latest"
|
||||
docker tag "${WORKER_IMAGE}" "${REGISTRY_PREFIX}/casera-worker:latest"
|
||||
docker tag "${ADMIN_IMAGE}" "${REGISTRY_PREFIX}/casera-admin:latest"
|
||||
docker push "${REGISTRY_PREFIX}/casera-api:latest"
|
||||
docker push "${REGISTRY_PREFIX}/casera-worker:latest"
|
||||
docker push "${REGISTRY_PREFIX}/casera-admin:latest"
|
||||
fi
|
||||
else
|
||||
warn "SKIP_BUILD=1 set. Using prebuilt images for tag: ${DEPLOY_TAG}"
|
||||
fi
|
||||
|
||||
DEPLOY_ID_RAW="${DEPLOY_TAG}-$(date +%Y%m%d%H%M%S)"
|
||||
DEPLOY_ID="$(printf '%s' "${DEPLOY_ID_RAW}" | tr -c 'a-zA-Z0-9_-' '_')"
|
||||
|
||||
POSTGRES_PASSWORD_SECRET="${DEPLOY_STACK_NAME}_postgres_password_${DEPLOY_ID}"
|
||||
SECRET_KEY_SECRET="${DEPLOY_STACK_NAME}_secret_key_${DEPLOY_ID}"
|
||||
EMAIL_HOST_PASSWORD_SECRET="${DEPLOY_STACK_NAME}_email_host_password_${DEPLOY_ID}"
|
||||
FCM_SERVER_KEY_SECRET="${DEPLOY_STACK_NAME}_fcm_server_key_${DEPLOY_ID}"
|
||||
APNS_AUTH_KEY_SECRET="${DEPLOY_STACK_NAME}_apns_auth_key_${DEPLOY_ID}"
|
||||
|
||||
TMP_DIR="$(mktemp -d)"
|
||||
cleanup() {
|
||||
rm -rf "${TMP_DIR}"
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
cp "${STACK_TEMPLATE}" "${TMP_DIR}/swarm-stack.prod.yml"
|
||||
cp "${PROD_ENV}" "${TMP_DIR}/prod.env"
|
||||
cp "${REGISTRY_ENV}" "${TMP_DIR}/registry.env"
|
||||
mkdir -p "${TMP_DIR}/secrets"
|
||||
cp "${SECRET_POSTGRES}" "${TMP_DIR}/secrets/postgres_password.txt"
|
||||
cp "${SECRET_APP_KEY}" "${TMP_DIR}/secrets/secret_key.txt"
|
||||
cp "${SECRET_EMAIL_PASS}" "${TMP_DIR}/secrets/email_host_password.txt"
|
||||
cp "${SECRET_FCM_KEY}" "${TMP_DIR}/secrets/fcm_server_key.txt"
|
||||
cp "${SECRET_APNS_KEY}" "${TMP_DIR}/secrets/apns_auth_key.p8"
|
||||
|
||||
cat > "${TMP_DIR}/runtime.env" <<EOF
|
||||
API_IMAGE=${API_IMAGE}
|
||||
WORKER_IMAGE=${WORKER_IMAGE}
|
||||
ADMIN_IMAGE=${ADMIN_IMAGE}
|
||||
|
||||
API_REPLICAS=${API_REPLICAS:-3}
|
||||
WORKER_REPLICAS=${WORKER_REPLICAS:-2}
|
||||
ADMIN_REPLICAS=${ADMIN_REPLICAS:-1}
|
||||
API_PORT=${API_PORT:-8000}
|
||||
ADMIN_PORT=${ADMIN_PORT:-3000}
|
||||
DOZZLE_PORT=${DOZZLE_PORT:-9999}
|
||||
|
||||
POSTGRES_PASSWORD_SECRET=${POSTGRES_PASSWORD_SECRET}
|
||||
SECRET_KEY_SECRET=${SECRET_KEY_SECRET}
|
||||
EMAIL_HOST_PASSWORD_SECRET=${EMAIL_HOST_PASSWORD_SECRET}
|
||||
FCM_SERVER_KEY_SECRET=${FCM_SERVER_KEY_SECRET}
|
||||
APNS_AUTH_KEY_SECRET=${APNS_AUTH_KEY_SECRET}
|
||||
EOF
|
||||
|
||||
log "Uploading deploy bundle to ${SSH_TARGET}:${DEPLOY_REMOTE_DIR}"
|
||||
ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "mkdir -p '${DEPLOY_REMOTE_DIR}/secrets'"
|
||||
scp "${SCP_OPTS[@]}" "${TMP_DIR}/swarm-stack.prod.yml" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/swarm-stack.prod.yml"
|
||||
scp "${SCP_OPTS[@]}" "${TMP_DIR}/prod.env" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/prod.env"
|
||||
scp "${SCP_OPTS[@]}" "${TMP_DIR}/registry.env" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/registry.env"
|
||||
scp "${SCP_OPTS[@]}" "${TMP_DIR}/runtime.env" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/runtime.env"
|
||||
scp "${SCP_OPTS[@]}" "${TMP_DIR}/secrets/postgres_password.txt" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/secrets/postgres_password.txt"
|
||||
scp "${SCP_OPTS[@]}" "${TMP_DIR}/secrets/secret_key.txt" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/secrets/secret_key.txt"
|
||||
scp "${SCP_OPTS[@]}" "${TMP_DIR}/secrets/email_host_password.txt" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/secrets/email_host_password.txt"
|
||||
scp "${SCP_OPTS[@]}" "${TMP_DIR}/secrets/fcm_server_key.txt" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/secrets/fcm_server_key.txt"
|
||||
scp "${SCP_OPTS[@]}" "${TMP_DIR}/secrets/apns_auth_key.p8" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/secrets/apns_auth_key.p8"
|
||||
|
||||
log "Creating Docker secrets and deploying stack on manager"
|
||||
ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "bash -s -- '${DEPLOY_REMOTE_DIR}' '${DEPLOY_STACK_NAME}'" <<'EOF'
|
||||
set -euo pipefail
|
||||
|
||||
REMOTE_DIR="$1"
|
||||
STACK_NAME="$2"
|
||||
|
||||
set -a
|
||||
# shellcheck disable=SC1090
|
||||
source "${REMOTE_DIR}/registry.env"
|
||||
# shellcheck disable=SC1090
|
||||
source "${REMOTE_DIR}/prod.env"
|
||||
# shellcheck disable=SC1090
|
||||
source "${REMOTE_DIR}/runtime.env"
|
||||
set +a
|
||||
|
||||
create_secret() {
|
||||
local name="$1"
|
||||
local src="$2"
|
||||
if docker secret inspect "${name}" >/dev/null 2>&1; then
|
||||
echo "[remote] secret exists: ${name}"
|
||||
else
|
||||
docker secret create "${name}" "${src}" >/dev/null
|
||||
echo "[remote] created secret: ${name}"
|
||||
fi
|
||||
}
|
||||
|
||||
printf '%s' "${REGISTRY_TOKEN}" | docker login "${REGISTRY}" -u "${REGISTRY_USERNAME}" --password-stdin >/dev/null
|
||||
rm -f "${REMOTE_DIR}/registry.env"
|
||||
|
||||
create_secret "${POSTGRES_PASSWORD_SECRET}" "${REMOTE_DIR}/secrets/postgres_password.txt"
|
||||
create_secret "${SECRET_KEY_SECRET}" "${REMOTE_DIR}/secrets/secret_key.txt"
|
||||
create_secret "${EMAIL_HOST_PASSWORD_SECRET}" "${REMOTE_DIR}/secrets/email_host_password.txt"
|
||||
create_secret "${FCM_SERVER_KEY_SECRET}" "${REMOTE_DIR}/secrets/fcm_server_key.txt"
|
||||
create_secret "${APNS_AUTH_KEY_SECRET}" "${REMOTE_DIR}/secrets/apns_auth_key.p8"
|
||||
|
||||
rm -f "${REMOTE_DIR}/secrets/postgres_password.txt"
|
||||
rm -f "${REMOTE_DIR}/secrets/secret_key.txt"
|
||||
rm -f "${REMOTE_DIR}/secrets/email_host_password.txt"
|
||||
rm -f "${REMOTE_DIR}/secrets/fcm_server_key.txt"
|
||||
rm -f "${REMOTE_DIR}/secrets/apns_auth_key.p8"
|
||||
|
||||
set -a
|
||||
# shellcheck disable=SC1090
|
||||
source "${REMOTE_DIR}/prod.env"
|
||||
# shellcheck disable=SC1090
|
||||
source "${REMOTE_DIR}/runtime.env"
|
||||
set +a
|
||||
|
||||
docker stack deploy --with-registry-auth -c "${REMOTE_DIR}/swarm-stack.prod.yml" "${STACK_NAME}"
|
||||
EOF
|
||||
|
||||
log "Waiting for stack convergence (${DEPLOY_WAIT_SECONDS}s max)"
|
||||
start_epoch="$(date +%s)"
|
||||
while true; do
|
||||
services="$(ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "docker stack services '${DEPLOY_STACK_NAME}' --format '{{.Name}} {{.Replicas}}'" 2>/dev/null || true)"
|
||||
|
||||
if [[ -n "${services}" ]]; then
|
||||
all_ready=1
|
||||
while IFS=' ' read -r svc replicas; do
|
||||
[[ -z "${svc}" ]] && continue
|
||||
current="${replicas%%/*}"
|
||||
desired="${replicas##*/}"
|
||||
if [[ "${desired}" == "0" ]]; then
|
||||
continue
|
||||
fi
|
||||
if [[ "${current}" != "${desired}" ]]; then
|
||||
all_ready=0
|
||||
fi
|
||||
done <<< "${services}"
|
||||
|
||||
if [[ "${all_ready}" -eq 1 ]]; then
|
||||
break
|
||||
fi
|
||||
fi
|
||||
|
||||
now_epoch="$(date +%s)"
|
||||
elapsed=$((now_epoch - start_epoch))
|
||||
if (( elapsed >= DEPLOY_WAIT_SECONDS )); then
|
||||
die "Timed out waiting for stack to converge. Check: ssh ${SSH_TARGET} docker stack services ${DEPLOY_STACK_NAME}"
|
||||
fi
|
||||
|
||||
sleep 10
|
||||
done
|
||||
|
||||
if [[ "${SKIP_HEALTHCHECK}" != "1" && -n "${DEPLOY_HEALTHCHECK_URL:-}" ]]; then
|
||||
log "Running health check: ${DEPLOY_HEALTHCHECK_URL}"
|
||||
curl -fsS --max-time 20 "${DEPLOY_HEALTHCHECK_URL}" >/dev/null
|
||||
fi
|
||||
|
||||
log "Deploy completed successfully."
|
||||
log "Stack: ${DEPLOY_STACK_NAME}"
|
||||
log "Images:"
|
||||
log " ${API_IMAGE}"
|
||||
log " ${WORKER_IMAGE}"
|
||||
log " ${ADMIN_IMAGE}"
|
||||
11
deploy/secrets/README.md
Normal file
11
deploy/secrets/README.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# Secrets Directory
|
||||
|
||||
Create these files (copy from `.example` files):
|
||||
|
||||
- `deploy/secrets/postgres_password.txt`
|
||||
- `deploy/secrets/secret_key.txt`
|
||||
- `deploy/secrets/email_host_password.txt`
|
||||
- `deploy/secrets/fcm_server_key.txt`
|
||||
- `deploy/secrets/apns_auth_key.p8`
|
||||
|
||||
These are consumed by `./.deploy_prod` and converted into Docker Swarm secrets.
|
||||
3
deploy/secrets/apns_auth_key.p8.example
Normal file
3
deploy/secrets/apns_auth_key.p8.example
Normal file
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
CHANGEME_APNS_PRIVATE_KEY
|
||||
-----END PRIVATE KEY-----
|
||||
1
deploy/secrets/email_host_password.txt.example
Normal file
1
deploy/secrets/email_host_password.txt.example
Normal file
@@ -0,0 +1 @@
|
||||
CHANGEME_SMTP_PASSWORD
|
||||
1
deploy/secrets/fcm_server_key.txt.example
Normal file
1
deploy/secrets/fcm_server_key.txt.example
Normal file
@@ -0,0 +1 @@
|
||||
CHANGEME_FCM_SERVER_KEY
|
||||
1
deploy/secrets/postgres_password.txt.example
Normal file
1
deploy/secrets/postgres_password.txt.example
Normal file
@@ -0,0 +1 @@
|
||||
CHANGEME_DATABASE_PASSWORD
|
||||
1
deploy/secrets/secret_key.txt.example
Normal file
1
deploy/secrets/secret_key.txt.example
Normal file
@@ -0,0 +1 @@
|
||||
CHANGEME_SECRET_KEY_MIN_32_CHARS
|
||||
67
deploy/shit_deploy_cant_do.md
Normal file
67
deploy/shit_deploy_cant_do.md
Normal file
@@ -0,0 +1,67 @@
|
||||
# Shit Deploy Can't Do
|
||||
|
||||
This is everything `./.deploy_prod` cannot safely automate for you.
|
||||
|
||||
## 1. Create Infrastructure
|
||||
|
||||
Step:
|
||||
Create Hetzner servers, networking, and load balancer.
|
||||
|
||||
Reason:
|
||||
The script only deploys app workloads. It cannot create paid cloud resources without cloud API credentials and IaC wiring.
|
||||
|
||||
## 2. Join Nodes To Swarm
|
||||
|
||||
Step:
|
||||
Run `docker swarm init` on the first manager and `docker swarm join` on other nodes.
|
||||
|
||||
Reason:
|
||||
Joining nodes requires one-time bootstrap tokens and host-level control.
|
||||
|
||||
## 3. Configure Firewall And Origin Restrictions
|
||||
|
||||
Step:
|
||||
Set firewall rules so only expected ingress paths can reach your nodes.
|
||||
|
||||
Reason:
|
||||
Firewall policies live in provider networking controls, outside this repo.
|
||||
|
||||
## 4. Configure DNS / Cloudflare
|
||||
|
||||
Step:
|
||||
Point DNS at LB, enable proxying, set SSL mode, and lock down origin access.
|
||||
|
||||
Reason:
|
||||
DNS and CDN settings are account-level operations in Cloudflare, not deploy-time app actions.
|
||||
|
||||
## 5. Configure External Services
|
||||
|
||||
Step:
|
||||
Create and configure Neon, B2, email provider, APNS, and FCM credentials.
|
||||
|
||||
Reason:
|
||||
These credentials are issued in vendor dashboards and must be manually generated/rotated.
|
||||
|
||||
## 6. Seed SSH Trust
|
||||
|
||||
Step:
|
||||
Ensure your local machine can SSH to the manager with the key in `deploy/cluster.env`.
|
||||
|
||||
Reason:
|
||||
The script assumes SSH already works; it cannot grant itself SSH access.
|
||||
|
||||
## 7. First-Time Smoke Testing Beyond `/api/health/`
|
||||
|
||||
Step:
|
||||
Manually test login, push, background jobs, and admin panel flows after first deploy.
|
||||
|
||||
Reason:
|
||||
Automated health checks prove container readiness, not end-to-end business behavior.
|
||||
|
||||
## 8. Safe Secret Garbage Collection
|
||||
|
||||
Step:
|
||||
Periodically remove old versioned Docker secrets that are no longer referenced.
|
||||
|
||||
Reason:
|
||||
This deploy script creates versioned secrets for safe rollouts and does not auto-delete old ones to avoid breaking running services.
|
||||
288
deploy/swarm-stack.prod.yml
Normal file
288
deploy/swarm-stack.prod.yml
Normal file
@@ -0,0 +1,288 @@
|
||||
version: "3.8"
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: redis-server --appendonly yes --appendfsync everysec
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
deploy:
|
||||
replicas: 1
|
||||
restart_policy:
|
||||
condition: any
|
||||
delay: 5s
|
||||
placement:
|
||||
max_replicas_per_node: 1
|
||||
networks:
|
||||
- casera-network
|
||||
|
||||
api:
|
||||
image: ${API_IMAGE}
|
||||
ports:
|
||||
- target: 8000
|
||||
published: ${API_PORT}
|
||||
protocol: tcp
|
||||
mode: ingress
|
||||
environment:
|
||||
PORT: "8000"
|
||||
DEBUG: "${DEBUG}"
|
||||
ALLOWED_HOSTS: "${ALLOWED_HOSTS}"
|
||||
CORS_ALLOWED_ORIGINS: "${CORS_ALLOWED_ORIGINS}"
|
||||
TIMEZONE: "${TIMEZONE}"
|
||||
BASE_URL: "${BASE_URL}"
|
||||
ADMIN_PANEL_URL: "${ADMIN_PANEL_URL}"
|
||||
|
||||
DB_HOST: "${DB_HOST}"
|
||||
DB_PORT: "${DB_PORT}"
|
||||
POSTGRES_USER: "${POSTGRES_USER}"
|
||||
POSTGRES_DB: "${POSTGRES_DB}"
|
||||
DB_SSLMODE: "${DB_SSLMODE}"
|
||||
DB_MAX_OPEN_CONNS: "${DB_MAX_OPEN_CONNS}"
|
||||
DB_MAX_IDLE_CONNS: "${DB_MAX_IDLE_CONNS}"
|
||||
DB_MAX_LIFETIME: "${DB_MAX_LIFETIME}"
|
||||
|
||||
REDIS_URL: "${REDIS_URL}"
|
||||
REDIS_DB: "${REDIS_DB}"
|
||||
|
||||
EMAIL_HOST: "${EMAIL_HOST}"
|
||||
EMAIL_PORT: "${EMAIL_PORT}"
|
||||
EMAIL_HOST_USER: "${EMAIL_HOST_USER}"
|
||||
DEFAULT_FROM_EMAIL: "${DEFAULT_FROM_EMAIL}"
|
||||
EMAIL_USE_TLS: "${EMAIL_USE_TLS}"
|
||||
|
||||
APNS_AUTH_KEY_PATH: "/run/secrets/apns_auth_key"
|
||||
APNS_AUTH_KEY_ID: "${APNS_AUTH_KEY_ID}"
|
||||
APNS_TEAM_ID: "${APNS_TEAM_ID}"
|
||||
APNS_TOPIC: "${APNS_TOPIC}"
|
||||
APNS_USE_SANDBOX: "${APNS_USE_SANDBOX}"
|
||||
APNS_PRODUCTION: "${APNS_PRODUCTION}"
|
||||
|
||||
STORAGE_UPLOAD_DIR: "${STORAGE_UPLOAD_DIR}"
|
||||
STORAGE_BASE_URL: "${STORAGE_BASE_URL}"
|
||||
STORAGE_MAX_FILE_SIZE: "${STORAGE_MAX_FILE_SIZE}"
|
||||
STORAGE_ALLOWED_TYPES: "${STORAGE_ALLOWED_TYPES}"
|
||||
|
||||
FEATURE_PUSH_ENABLED: "${FEATURE_PUSH_ENABLED}"
|
||||
FEATURE_EMAIL_ENABLED: "${FEATURE_EMAIL_ENABLED}"
|
||||
FEATURE_WEBHOOKS_ENABLED: "${FEATURE_WEBHOOKS_ENABLED}"
|
||||
FEATURE_ONBOARDING_EMAILS_ENABLED: "${FEATURE_ONBOARDING_EMAILS_ENABLED}"
|
||||
FEATURE_PDF_REPORTS_ENABLED: "${FEATURE_PDF_REPORTS_ENABLED}"
|
||||
FEATURE_WORKER_ENABLED: "${FEATURE_WORKER_ENABLED}"
|
||||
|
||||
APPLE_CLIENT_ID: "${APPLE_CLIENT_ID}"
|
||||
APPLE_TEAM_ID: "${APPLE_TEAM_ID}"
|
||||
GOOGLE_CLIENT_ID: "${GOOGLE_CLIENT_ID}"
|
||||
GOOGLE_ANDROID_CLIENT_ID: "${GOOGLE_ANDROID_CLIENT_ID}"
|
||||
GOOGLE_IOS_CLIENT_ID: "${GOOGLE_IOS_CLIENT_ID}"
|
||||
APPLE_IAP_KEY_PATH: "${APPLE_IAP_KEY_PATH}"
|
||||
APPLE_IAP_KEY_ID: "${APPLE_IAP_KEY_ID}"
|
||||
APPLE_IAP_ISSUER_ID: "${APPLE_IAP_ISSUER_ID}"
|
||||
APPLE_IAP_BUNDLE_ID: "${APPLE_IAP_BUNDLE_ID}"
|
||||
APPLE_IAP_SANDBOX: "${APPLE_IAP_SANDBOX}"
|
||||
GOOGLE_IAP_SERVICE_ACCOUNT_PATH: "${GOOGLE_IAP_SERVICE_ACCOUNT_PATH}"
|
||||
GOOGLE_IAP_PACKAGE_NAME: "${GOOGLE_IAP_PACKAGE_NAME}"
|
||||
command:
|
||||
- /bin/sh
|
||||
- -lc
|
||||
- |
|
||||
set -eu
|
||||
export POSTGRES_PASSWORD="$$(cat /run/secrets/postgres_password)"
|
||||
export SECRET_KEY="$$(cat /run/secrets/secret_key)"
|
||||
export EMAIL_HOST_PASSWORD="$$(cat /run/secrets/email_host_password)"
|
||||
export FCM_SERVER_KEY="$$(cat /run/secrets/fcm_server_key)"
|
||||
exec /app/api
|
||||
secrets:
|
||||
- source: ${POSTGRES_PASSWORD_SECRET}
|
||||
target: postgres_password
|
||||
- source: ${SECRET_KEY_SECRET}
|
||||
target: secret_key
|
||||
- source: ${EMAIL_HOST_PASSWORD_SECRET}
|
||||
target: email_host_password
|
||||
- source: ${FCM_SERVER_KEY_SECRET}
|
||||
target: fcm_server_key
|
||||
- source: ${APNS_AUTH_KEY_SECRET}
|
||||
target: apns_auth_key
|
||||
volumes:
|
||||
- uploads:/app/uploads
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://127.0.0.1:8000/api/health/"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
start_period: 15s
|
||||
retries: 3
|
||||
deploy:
|
||||
replicas: ${API_REPLICAS}
|
||||
restart_policy:
|
||||
condition: any
|
||||
delay: 5s
|
||||
update_config:
|
||||
parallelism: 1
|
||||
delay: 10s
|
||||
order: start-first
|
||||
rollback_config:
|
||||
parallelism: 1
|
||||
delay: 5s
|
||||
order: stop-first
|
||||
networks:
|
||||
- casera-network
|
||||
|
||||
admin:
|
||||
image: ${ADMIN_IMAGE}
|
||||
ports:
|
||||
- target: 3000
|
||||
published: ${ADMIN_PORT}
|
||||
protocol: tcp
|
||||
mode: ingress
|
||||
environment:
|
||||
PORT: "3000"
|
||||
HOSTNAME: "0.0.0.0"
|
||||
NEXT_PUBLIC_API_URL: "${NEXT_PUBLIC_API_URL}"
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://127.0.0.1:3000/admin/"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
deploy:
|
||||
replicas: ${ADMIN_REPLICAS}
|
||||
restart_policy:
|
||||
condition: any
|
||||
delay: 5s
|
||||
update_config:
|
||||
parallelism: 1
|
||||
delay: 10s
|
||||
order: start-first
|
||||
rollback_config:
|
||||
parallelism: 1
|
||||
delay: 5s
|
||||
order: stop-first
|
||||
networks:
|
||||
- casera-network
|
||||
|
||||
worker:
|
||||
image: ${WORKER_IMAGE}
|
||||
environment:
|
||||
DB_HOST: "${DB_HOST}"
|
||||
DB_PORT: "${DB_PORT}"
|
||||
POSTGRES_USER: "${POSTGRES_USER}"
|
||||
POSTGRES_DB: "${POSTGRES_DB}"
|
||||
DB_SSLMODE: "${DB_SSLMODE}"
|
||||
DB_MAX_OPEN_CONNS: "${DB_MAX_OPEN_CONNS}"
|
||||
DB_MAX_IDLE_CONNS: "${DB_MAX_IDLE_CONNS}"
|
||||
DB_MAX_LIFETIME: "${DB_MAX_LIFETIME}"
|
||||
|
||||
REDIS_URL: "${REDIS_URL}"
|
||||
REDIS_DB: "${REDIS_DB}"
|
||||
|
||||
EMAIL_HOST: "${EMAIL_HOST}"
|
||||
EMAIL_PORT: "${EMAIL_PORT}"
|
||||
EMAIL_HOST_USER: "${EMAIL_HOST_USER}"
|
||||
DEFAULT_FROM_EMAIL: "${DEFAULT_FROM_EMAIL}"
|
||||
EMAIL_USE_TLS: "${EMAIL_USE_TLS}"
|
||||
|
||||
APNS_AUTH_KEY_PATH: "/run/secrets/apns_auth_key"
|
||||
APNS_AUTH_KEY_ID: "${APNS_AUTH_KEY_ID}"
|
||||
APNS_TEAM_ID: "${APNS_TEAM_ID}"
|
||||
APNS_TOPIC: "${APNS_TOPIC}"
|
||||
APNS_USE_SANDBOX: "${APNS_USE_SANDBOX}"
|
||||
APNS_PRODUCTION: "${APNS_PRODUCTION}"
|
||||
|
||||
TASK_REMINDER_HOUR: "${TASK_REMINDER_HOUR}"
|
||||
OVERDUE_REMINDER_HOUR: "${OVERDUE_REMINDER_HOUR}"
|
||||
DAILY_DIGEST_HOUR: "${DAILY_DIGEST_HOUR}"
|
||||
|
||||
FEATURE_PUSH_ENABLED: "${FEATURE_PUSH_ENABLED}"
|
||||
FEATURE_EMAIL_ENABLED: "${FEATURE_EMAIL_ENABLED}"
|
||||
FEATURE_WEBHOOKS_ENABLED: "${FEATURE_WEBHOOKS_ENABLED}"
|
||||
FEATURE_ONBOARDING_EMAILS_ENABLED: "${FEATURE_ONBOARDING_EMAILS_ENABLED}"
|
||||
FEATURE_PDF_REPORTS_ENABLED: "${FEATURE_PDF_REPORTS_ENABLED}"
|
||||
FEATURE_WORKER_ENABLED: "${FEATURE_WORKER_ENABLED}"
|
||||
command:
|
||||
- /bin/sh
|
||||
- -lc
|
||||
- |
|
||||
set -eu
|
||||
export POSTGRES_PASSWORD="$$(cat /run/secrets/postgres_password)"
|
||||
export SECRET_KEY="$$(cat /run/secrets/secret_key)"
|
||||
export EMAIL_HOST_PASSWORD="$$(cat /run/secrets/email_host_password)"
|
||||
export FCM_SERVER_KEY="$$(cat /run/secrets/fcm_server_key)"
|
||||
exec /app/worker
|
||||
secrets:
|
||||
- source: ${POSTGRES_PASSWORD_SECRET}
|
||||
target: postgres_password
|
||||
- source: ${SECRET_KEY_SECRET}
|
||||
target: secret_key
|
||||
- source: ${EMAIL_HOST_PASSWORD_SECRET}
|
||||
target: email_host_password
|
||||
- source: ${FCM_SERVER_KEY_SECRET}
|
||||
target: fcm_server_key
|
||||
- source: ${APNS_AUTH_KEY_SECRET}
|
||||
target: apns_auth_key
|
||||
deploy:
|
||||
replicas: ${WORKER_REPLICAS}
|
||||
restart_policy:
|
||||
condition: any
|
||||
delay: 5s
|
||||
update_config:
|
||||
parallelism: 1
|
||||
delay: 10s
|
||||
order: start-first
|
||||
rollback_config:
|
||||
parallelism: 1
|
||||
delay: 5s
|
||||
order: stop-first
|
||||
networks:
|
||||
- casera-network
|
||||
|
||||
dozzle:
|
||||
image: amir20/dozzle:latest
|
||||
ports:
|
||||
- target: 8080
|
||||
published: ${DOZZLE_PORT}
|
||||
protocol: tcp
|
||||
mode: ingress
|
||||
environment:
|
||||
DOZZLE_NO_ANALYTICS: "true"
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
deploy:
|
||||
replicas: 1
|
||||
restart_policy:
|
||||
condition: any
|
||||
delay: 5s
|
||||
placement:
|
||||
constraints:
|
||||
- node.role == manager
|
||||
networks:
|
||||
- casera-network
|
||||
|
||||
volumes:
|
||||
redis_data:
|
||||
uploads:
|
||||
|
||||
networks:
|
||||
casera-network:
|
||||
driver: overlay
|
||||
driver_opts:
|
||||
encrypted: "true"
|
||||
|
||||
secrets:
|
||||
postgres_password:
|
||||
external: true
|
||||
name: ${POSTGRES_PASSWORD_SECRET}
|
||||
secret_key:
|
||||
external: true
|
||||
name: ${SECRET_KEY_SECRET}
|
||||
email_host_password:
|
||||
external: true
|
||||
name: ${EMAIL_HOST_PASSWORD_SECRET}
|
||||
fcm_server_key:
|
||||
external: true
|
||||
name: ${FCM_SERVER_KEY_SECRET}
|
||||
apns_auth_key:
|
||||
external: true
|
||||
name: ${APNS_AUTH_KEY_SECRET}
|
||||
1527
docs/AUDIT_FINDINGS.md
Normal file
1527
docs/AUDIT_FINDINGS.md
Normal file
File diff suppressed because it is too large
Load Diff
37
hardening-report.md
Normal file
37
hardening-report.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Go Backend Hardening Audit Report
|
||||
|
||||
## Audit Sources
|
||||
- 9 mapper agents (100% file coverage)
|
||||
- 8 specialized domain auditors (parallel)
|
||||
- 1 cross-cutting deep audit (parallel)
|
||||
- Total source files: 136 (excluding 27 test files)
|
||||
|
||||
---
|
||||
|
||||
## CRITICAL — Will crash or lose data
|
||||
|
||||
## BUG — Incorrect behavior
|
||||
|
||||
## SILENT FAILURE — Error swallowed or ignored
|
||||
|
||||
## RACE CONDITION — Concurrency issue
|
||||
|
||||
## LOGIC ERROR — Code doesn't match intent
|
||||
|
||||
## PERFORMANCE — Unnecessary cost
|
||||
|
||||
## SECURITY — Vulnerability or exposure
|
||||
|
||||
## AUTHORIZATION — Access control gap
|
||||
|
||||
## DATA INTEGRITY — GORM / database issue
|
||||
|
||||
## API CONTRACT — Request/response issue
|
||||
|
||||
## ARCHITECTURE — Layer or pattern violation
|
||||
|
||||
## FRAGILE — Works now but will break easily
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
@@ -1,5 +1,7 @@
|
||||
package dto
|
||||
|
||||
import "github.com/treytartt/casera-api/internal/middleware"
|
||||
|
||||
// PaginationParams holds pagination query parameters
|
||||
type PaginationParams struct {
|
||||
Page int `form:"page" validate:"omitempty,min=1"`
|
||||
@@ -41,6 +43,12 @@ func (p *PaginationParams) GetSortDir() string {
|
||||
return "DESC"
|
||||
}
|
||||
|
||||
// GetSafeSortBy validates SortBy against an allowlist to prevent SQL injection.
|
||||
// Returns the matching allowed column, or defaultCol if SortBy is empty or not allowed.
|
||||
func (p *PaginationParams) GetSafeSortBy(allowedCols []string, defaultCol string) string {
|
||||
return middleware.SanitizeSortColumn(p.SortBy, allowedCols, defaultCol)
|
||||
}
|
||||
|
||||
// UserFilters holds user-specific filter parameters
|
||||
type UserFilters struct {
|
||||
PaginationParams
|
||||
|
||||
199
internal/admin/handlers/admin_security_test.go
Normal file
199
internal/admin/handlers/admin_security_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"html"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/admin/dto"
|
||||
)
|
||||
|
||||
func TestAdminSortBy_ValidColumn_Works(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
allowlist []string
|
||||
defaultCol string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "exact match returns column",
|
||||
sortBy: "created_at",
|
||||
allowlist: []string{"id", "created_at", "updated_at", "name"},
|
||||
defaultCol: "created_at",
|
||||
expected: "created_at",
|
||||
},
|
||||
{
|
||||
name: "case insensitive match returns canonical column",
|
||||
sortBy: "Created_At",
|
||||
allowlist: []string{"id", "created_at", "updated_at", "name"},
|
||||
defaultCol: "created_at",
|
||||
expected: "created_at",
|
||||
},
|
||||
{
|
||||
name: "different valid column",
|
||||
sortBy: "name",
|
||||
allowlist: []string{"id", "created_at", "updated_at", "name"},
|
||||
defaultCol: "created_at",
|
||||
expected: "name",
|
||||
},
|
||||
{
|
||||
name: "date_joined for user handler",
|
||||
sortBy: "date_joined",
|
||||
allowlist: []string{"id", "username", "email", "date_joined", "last_login", "is_active"},
|
||||
defaultCol: "date_joined",
|
||||
expected: "date_joined",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := dto.PaginationParams{SortBy: tt.sortBy}
|
||||
result := p.GetSafeSortBy(tt.allowlist, tt.defaultCol)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminSortBy_SQLInjection_ReturnsDefault(t *testing.T) {
|
||||
allowlist := []string{"id", "created_at", "updated_at", "name"}
|
||||
defaultCol := "created_at"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
}{
|
||||
{
|
||||
name: "SQL injection with DROP TABLE",
|
||||
sortBy: "created_at; DROP TABLE users; --",
|
||||
},
|
||||
{
|
||||
name: "SQL injection with UNION SELECT",
|
||||
sortBy: "id UNION SELECT password FROM auth_user",
|
||||
},
|
||||
{
|
||||
name: "SQL injection with subquery",
|
||||
sortBy: "(SELECT password FROM auth_user LIMIT 1)",
|
||||
},
|
||||
{
|
||||
name: "SQL injection with comment",
|
||||
sortBy: "created_at--",
|
||||
},
|
||||
{
|
||||
name: "SQL injection with semicolon",
|
||||
sortBy: "created_at;",
|
||||
},
|
||||
{
|
||||
name: "SQL injection with OR 1=1",
|
||||
sortBy: "created_at OR 1=1",
|
||||
},
|
||||
{
|
||||
name: "column not in allowlist",
|
||||
sortBy: "password",
|
||||
},
|
||||
{
|
||||
name: "SQL injection with single quotes",
|
||||
sortBy: "name'; DROP TABLE users; --",
|
||||
},
|
||||
{
|
||||
name: "SQL injection with double dashes",
|
||||
sortBy: "id -- comment",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := dto.PaginationParams{SortBy: tt.sortBy}
|
||||
result := p.GetSafeSortBy(allowlist, defaultCol)
|
||||
assert.Equal(t, defaultCol, result, "SQL injection attempt should return default column")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminSortBy_EmptyString_ReturnsDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
defaultCol string
|
||||
}{
|
||||
{
|
||||
name: "empty string returns default",
|
||||
sortBy: "",
|
||||
defaultCol: "created_at",
|
||||
},
|
||||
{
|
||||
name: "whitespace only returns default",
|
||||
sortBy: " ",
|
||||
defaultCol: "created_at",
|
||||
},
|
||||
{
|
||||
name: "tab only returns default",
|
||||
sortBy: "\t",
|
||||
defaultCol: "date_joined",
|
||||
},
|
||||
{
|
||||
name: "different default column",
|
||||
sortBy: "",
|
||||
defaultCol: "completed_at",
|
||||
},
|
||||
}
|
||||
|
||||
allowlist := []string{"id", "created_at", "updated_at", "name"}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := dto.PaginationParams{SortBy: tt.sortBy}
|
||||
result := p.GetSafeSortBy(allowlist, tt.defaultCol)
|
||||
assert.Equal(t, tt.defaultCol, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendEmail_XSSEscaped(t *testing.T) {
|
||||
// SEC-22: Subject and Body must be HTML-escaped before insertion into email template.
|
||||
// This tests the html.EscapeString behavior that the handler relies on.
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "script tag in subject",
|
||||
input: `<script>alert("xss")</script>`,
|
||||
expected: `<script>alert("xss")</script>`,
|
||||
},
|
||||
{
|
||||
name: "img onerror payload",
|
||||
input: `<img src=x onerror=alert(1)>`,
|
||||
expected: `<img src=x onerror=alert(1)>`,
|
||||
},
|
||||
{
|
||||
name: "ampersand and angle brackets",
|
||||
input: `Tom & Jerry <bros>`,
|
||||
expected: `Tom & Jerry <bros>`,
|
||||
},
|
||||
{
|
||||
name: "plain text passes through",
|
||||
input: "Hello World",
|
||||
expected: "Hello World",
|
||||
},
|
||||
{
|
||||
name: "single quotes",
|
||||
input: `It's a 'test'`,
|
||||
expected: `It's a 'test'`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
escaped := html.EscapeString(tt.input)
|
||||
assert.Equal(t, tt.expected, escaped)
|
||||
// Verify the escaped output does NOT contain raw angle brackets from the input
|
||||
if tt.input != tt.expected {
|
||||
assert.NotContains(t, escaped, "<script>")
|
||||
assert.NotContains(t, escaped, "<img")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -80,11 +80,11 @@ func (h *AdminUserManagementHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "email", "first_name", "last_name",
|
||||
"role", "is_active", "created_at", "updated_at",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -63,11 +63,11 @@ func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "user_id", "apple_id", "email", "is_private_email",
|
||||
"created_at", "updated_at",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -55,11 +55,10 @@ func (h *AdminAuthTokenHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"created", "user_id",
|
||||
}, "created")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -96,11 +96,11 @@ func (h *AdminCompletionHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "completed_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "task_id", "completed_by_id", "completed_at",
|
||||
"created_at", "notes", "actual_cost", "rating",
|
||||
}, "completed_at")
|
||||
sortDir := "DESC"
|
||||
if filters.SortDir != "" {
|
||||
sortDir = filters.GetSortDir()
|
||||
|
||||
@@ -78,11 +78,10 @@ func (h *AdminCompletionImageHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "completion_id", "created_at", "updated_at",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -58,11 +58,10 @@ func (h *AdminConfirmationCodeHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "user_id", "created_at", "expires_at", "is_used",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -59,11 +59,11 @@ func (h *AdminContractorHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "name", "company", "email", "phone", "city",
|
||||
"created_at", "updated_at", "is_active", "is_favorite", "rating",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -70,10 +70,10 @@ func (h *AdminDeviceHandler) ListAPNS(c echo.Context) error {
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
sortBy := "date_created"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "name", "active", "user_id", "device_id", "date_created",
|
||||
}, "date_created")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
|
||||
|
||||
@@ -125,10 +125,10 @@ func (h *AdminDeviceHandler) ListGCM(c echo.Context) error {
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
sortBy := "date_created"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "name", "active", "user_id", "device_id", "cloud_message_type", "date_created",
|
||||
}, "date_created")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
|
||||
|
||||
|
||||
@@ -61,11 +61,11 @@ func (h *AdminDocumentHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "title", "created_at", "updated_at", "document_type",
|
||||
"residence_id", "is_active", "expiry_date", "vendor",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -79,11 +79,10 @@ func (h *AdminDocumentImageHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "document_id", "created_at", "updated_at",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -52,10 +52,10 @@ func (h *AdminFeatureBenefitHandler) List(c echo.Context) error {
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
sortBy := "display_order"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "feature_name", "display_order", "is_active", "created_at", "updated_at",
|
||||
}, "display_order")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
|
||||
|
||||
|
||||
@@ -29,6 +29,8 @@ func NewAdminLookupHandler(db *gorm.DB) *AdminLookupHandler {
|
||||
func (h *AdminLookupHandler) refreshCategoriesCache(ctx context.Context) {
|
||||
cache := services.GetCache()
|
||||
if cache == nil {
|
||||
log.Warn().Msg("Cache service unavailable, skipping categories cache refresh")
|
||||
return
|
||||
}
|
||||
|
||||
var categories []models.TaskCategory
|
||||
@@ -49,6 +51,8 @@ func (h *AdminLookupHandler) refreshCategoriesCache(ctx context.Context) {
|
||||
func (h *AdminLookupHandler) refreshPrioritiesCache(ctx context.Context) {
|
||||
cache := services.GetCache()
|
||||
if cache == nil {
|
||||
log.Warn().Msg("Cache service unavailable, skipping priorities cache refresh")
|
||||
return
|
||||
}
|
||||
|
||||
var priorities []models.TaskPriority
|
||||
@@ -69,6 +73,8 @@ func (h *AdminLookupHandler) refreshPrioritiesCache(ctx context.Context) {
|
||||
func (h *AdminLookupHandler) refreshFrequenciesCache(ctx context.Context) {
|
||||
cache := services.GetCache()
|
||||
if cache == nil {
|
||||
log.Warn().Msg("Cache service unavailable, skipping frequencies cache refresh")
|
||||
return
|
||||
}
|
||||
|
||||
var frequencies []models.TaskFrequency
|
||||
@@ -89,6 +95,8 @@ func (h *AdminLookupHandler) refreshFrequenciesCache(ctx context.Context) {
|
||||
func (h *AdminLookupHandler) refreshResidenceTypesCache(ctx context.Context) {
|
||||
cache := services.GetCache()
|
||||
if cache == nil {
|
||||
log.Warn().Msg("Cache service unavailable, skipping residence types cache refresh")
|
||||
return
|
||||
}
|
||||
|
||||
var types []models.ResidenceType
|
||||
@@ -109,6 +117,8 @@ func (h *AdminLookupHandler) refreshResidenceTypesCache(ctx context.Context) {
|
||||
func (h *AdminLookupHandler) refreshSpecialtiesCache(ctx context.Context) {
|
||||
cache := services.GetCache()
|
||||
if cache == nil {
|
||||
log.Warn().Msg("Cache service unavailable, skipping specialties cache refresh")
|
||||
return
|
||||
}
|
||||
|
||||
var specialties []models.ContractorSpecialty
|
||||
@@ -130,6 +140,8 @@ func (h *AdminLookupHandler) refreshSpecialtiesCache(ctx context.Context) {
|
||||
func (h *AdminLookupHandler) invalidateSeededDataCache(ctx context.Context) {
|
||||
cache := services.GetCache()
|
||||
if cache == nil {
|
||||
log.Warn().Msg("Cache service unavailable, skipping seeded data cache invalidation")
|
||||
return
|
||||
}
|
||||
|
||||
if err := cache.InvalidateSeededData(ctx); err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"html"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -67,11 +68,11 @@ func (h *AdminNotificationHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "created_at", "updated_at", "user_id",
|
||||
"notification_type", "sent", "read", "title",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
@@ -347,16 +348,20 @@ func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error {
|
||||
return c.JSON(http.StatusServiceUnavailable, map[string]interface{}{"error": "Email service not configured"})
|
||||
}
|
||||
|
||||
// HTML-escape user-supplied values to prevent XSS via email content
|
||||
escapedSubject := html.EscapeString(req.Subject)
|
||||
escapedBody := html.EscapeString(req.Body)
|
||||
|
||||
// Create HTML body with basic styling
|
||||
htmlBody := `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>` + req.Subject + `</title>
|
||||
<title>` + escapedSubject + `</title>
|
||||
</head>
|
||||
<body style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto; padding: 20px;">
|
||||
<h2 style="color: #333;">` + req.Subject + `</h2>
|
||||
<div style="color: #666; line-height: 1.6;">` + req.Body + `</div>
|
||||
<h2 style="color: #333;">` + escapedSubject + `</h2>
|
||||
<div style="color: #666; line-height: 1.6;">` + escapedBody + `</div>
|
||||
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
|
||||
<p style="color: #999; font-size: 12px;">This is a test email sent from Casera Admin Panel.</p>
|
||||
</body>
|
||||
|
||||
@@ -76,11 +76,10 @@ func (h *AdminNotificationPrefsHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "user_id", "created_at", "updated_at",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -60,11 +60,10 @@ func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "user_id", "created_at", "expires_at", "used",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -56,10 +56,11 @@ func (h *AdminPromotionHandler) List(c echo.Context) error {
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "promotion_id", "title", "start_date", "end_date",
|
||||
"target_tier", "is_active", "created_at", "updated_at",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
|
||||
|
||||
|
||||
@@ -58,11 +58,11 @@ func (h *AdminResidenceHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "name", "created_at", "updated_at", "owner_id",
|
||||
"city", "state_province", "country", "is_active", "is_primary",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -62,11 +62,11 @@ func (h *AdminShareCodeHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "residence_id", "code", "created_by_id",
|
||||
"is_active", "expires_at", "created_at",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
@@ -153,13 +153,17 @@ func (h *AdminShareCodeHandler) Update(c echo.Context) error {
|
||||
}
|
||||
|
||||
var req struct {
|
||||
IsActive bool `json:"is_active"`
|
||||
IsActive *bool `json:"is_active"`
|
||||
}
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
|
||||
code.IsActive = req.IsActive
|
||||
// Only update IsActive when explicitly provided (non-nil).
|
||||
// Using *bool prevents a missing field from defaulting to false.
|
||||
if req.IsActive != nil {
|
||||
code.IsActive = *req.IsActive
|
||||
}
|
||||
|
||||
if err := h.db.Save(&code).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update share code"})
|
||||
|
||||
@@ -65,11 +65,11 @@ func (h *AdminSubscriptionHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "created_at", "updated_at", "user_id",
|
||||
"tier", "platform", "auto_renew", "expires_at", "subscribed_at",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -68,11 +68,12 @@ func (h *AdminTaskHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "title", "created_at", "updated_at", "due_date", "next_due_date",
|
||||
"residence_id", "category_id", "priority_id", "in_progress",
|
||||
"is_cancelled", "is_archived", "estimated_cost", "actual_cost",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -56,11 +56,11 @@ func (h *AdminUserHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "date_joined"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "username", "email", "first_name", "last_name",
|
||||
"date_joined", "last_login", "is_active", "is_staff", "is_superuser",
|
||||
}, "date_joined")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -69,11 +69,10 @@ func (h *AdminUserProfileHandler) List(c echo.Context) error {
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting
|
||||
sortBy := "created_at"
|
||||
if filters.SortBy != "" {
|
||||
sortBy = filters.SortBy
|
||||
}
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "user_id", "verified", "created_at", "updated_at",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@@ -338,10 +338,14 @@ func validate(cfg *Config) error {
|
||||
// In debug mode, use a default key with a warning for local development
|
||||
cfg.Security.SecretKey = "change-me-in-production-secret-key-12345"
|
||||
fmt.Println("WARNING: SECRET_KEY not set, using default (debug mode only)")
|
||||
fmt.Println("WARNING: *** DO NOT USE THIS DEFAULT KEY IN PRODUCTION ***")
|
||||
} else {
|
||||
// In production, refuse to start without a proper secret key
|
||||
return fmt.Errorf("FATAL: SECRET_KEY environment variable is required in production (DEBUG=false)")
|
||||
}
|
||||
} else if cfg.Security.SecretKey == "change-me-in-production-secret-key-12345" {
|
||||
// Warn if someone explicitly set the well-known debug key
|
||||
fmt.Println("WARNING: SECRET_KEY is set to the well-known debug default. Change it for production use.")
|
||||
}
|
||||
|
||||
// Database password might come from DATABASE_URL, don't require it separately
|
||||
|
||||
@@ -369,17 +369,13 @@ func migrateGoAdmin() error {
|
||||
}
|
||||
db.Exec(`CREATE INDEX IF NOT EXISTS idx_goadmin_site_key ON goadmin_site(key)`)
|
||||
|
||||
// Seed default admin user (password: admin - bcrypt hash)
|
||||
// Seed default admin user only on first run (ON CONFLICT DO NOTHING).
|
||||
// Password is NOT reset on subsequent migrations to preserve operator changes.
|
||||
db.Exec(`
|
||||
INSERT INTO goadmin_users (username, password, name, avatar)
|
||||
VALUES ('admin', '$2a$10$t.GCU24EqIWLSl7F51Hdz.IkkgFK.Qa9/BzEc5Bi2C/I2bXf1nJgm', 'Administrator', '')
|
||||
ON CONFLICT DO NOTHING
|
||||
`)
|
||||
// Update existing admin password if it exists with wrong hash
|
||||
db.Exec(`
|
||||
UPDATE goadmin_users SET password = '$2a$10$t.GCU24EqIWLSl7F51Hdz.IkkgFK.Qa9/BzEc5Bi2C/I2bXf1nJgm'
|
||||
WHERE username = 'admin'
|
||||
`)
|
||||
|
||||
// Seed default roles
|
||||
db.Exec(`INSERT INTO goadmin_roles (name, slug) VALUES ('Administrator', 'administrator') ON CONFLICT DO NOTHING`)
|
||||
@@ -443,8 +439,8 @@ func migrateGoAdmin() error {
|
||||
|
||||
log.Info().Msg("GoAdmin migrations completed")
|
||||
|
||||
// Seed default Next.js admin user (email: admin@mycrib.com, password: admin123)
|
||||
// bcrypt hash for "admin123": $2a$10$t5hGjdXQLxr9Z0193qx.Tef6hd1vYI3JvrfX/piKx2qS9UvQ41I9O
|
||||
// Seed default Next.js admin user only on first run.
|
||||
// Password is NOT reset on subsequent migrations to preserve operator changes.
|
||||
var adminCount int64
|
||||
db.Raw(`SELECT COUNT(*) FROM admin_users WHERE email = 'admin@mycrib.com'`).Scan(&adminCount)
|
||||
if adminCount == 0 {
|
||||
@@ -453,14 +449,7 @@ func migrateGoAdmin() error {
|
||||
INSERT INTO admin_users (email, password, first_name, last_name, role, is_active, created_at, updated_at)
|
||||
VALUES ('admin@mycrib.com', '$2a$10$t5hGjdXQLxr9Z0193qx.Tef6hd1vYI3JvrfX/piKx2qS9UvQ41I9O', 'Admin', 'User', 'super_admin', true, NOW(), NOW())
|
||||
`)
|
||||
log.Info().Msg("Default admin user created: admin@mycrib.com / admin123")
|
||||
} else {
|
||||
// Update existing admin password if needed
|
||||
db.Exec(`
|
||||
UPDATE admin_users SET password = '$2a$10$t5hGjdXQLxr9Z0193qx.Tef6hd1vYI3JvrfX/piKx2qS9UvQ41I9O'
|
||||
WHERE email = 'admin@mycrib.com'
|
||||
`)
|
||||
log.Info().Msg("Updated admin@mycrib.com password to admin123")
|
||||
log.Info().Msg("Default admin user created: admin@mycrib.com")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -8,13 +8,13 @@ type CreateContractorRequest struct {
|
||||
Phone string `json:"phone" validate:"max=20"`
|
||||
Email string `json:"email" validate:"omitempty,email,max=254"`
|
||||
Website string `json:"website" validate:"max=200"`
|
||||
Notes string `json:"notes"`
|
||||
Notes string `json:"notes" validate:"max=10000"`
|
||||
StreetAddress string `json:"street_address" validate:"max=255"`
|
||||
City string `json:"city" validate:"max=100"`
|
||||
StateProvince string `json:"state_province" validate:"max=100"`
|
||||
PostalCode string `json:"postal_code" validate:"max=20"`
|
||||
SpecialtyIDs []uint `json:"specialty_ids"`
|
||||
Rating *float64 `json:"rating"`
|
||||
SpecialtyIDs []uint `json:"specialty_ids" validate:"omitempty,max=20"`
|
||||
Rating *float64 `json:"rating" validate:"omitempty,min=0,max=5"`
|
||||
IsFavorite *bool `json:"is_favorite"`
|
||||
}
|
||||
|
||||
@@ -25,13 +25,13 @@ type UpdateContractorRequest struct {
|
||||
Phone *string `json:"phone" validate:"omitempty,max=20"`
|
||||
Email *string `json:"email" validate:"omitempty,email,max=254"`
|
||||
Website *string `json:"website" validate:"omitempty,max=200"`
|
||||
Notes *string `json:"notes"`
|
||||
Notes *string `json:"notes" validate:"omitempty,max=10000"`
|
||||
StreetAddress *string `json:"street_address" validate:"omitempty,max=255"`
|
||||
City *string `json:"city" validate:"omitempty,max=100"`
|
||||
StateProvince *string `json:"state_province" validate:"omitempty,max=100"`
|
||||
PostalCode *string `json:"postal_code" validate:"omitempty,max=20"`
|
||||
SpecialtyIDs []uint `json:"specialty_ids"`
|
||||
Rating *float64 `json:"rating"`
|
||||
SpecialtyIDs []uint `json:"specialty_ids" validate:"omitempty,max=20"`
|
||||
Rating *float64 `json:"rating" validate:"omitempty,min=0,max=5"`
|
||||
IsFavorite *bool `json:"is_favorite"`
|
||||
ResidenceID *uint `json:"residence_id"`
|
||||
}
|
||||
|
||||
@@ -12,11 +12,11 @@ import (
|
||||
type CreateDocumentRequest struct {
|
||||
ResidenceID uint `json:"residence_id" validate:"required"`
|
||||
Title string `json:"title" validate:"required,min=1,max=200"`
|
||||
Description string `json:"description"`
|
||||
DocumentType models.DocumentType `json:"document_type"`
|
||||
Description string `json:"description" validate:"max=10000"`
|
||||
DocumentType models.DocumentType `json:"document_type" validate:"omitempty,oneof=general warranty receipt contract insurance manual"`
|
||||
FileURL string `json:"file_url" validate:"max=500"`
|
||||
FileName string `json:"file_name" validate:"max=255"`
|
||||
FileSize *int64 `json:"file_size"`
|
||||
FileSize *int64 `json:"file_size" validate:"omitempty,min=0"`
|
||||
MimeType string `json:"mime_type" validate:"max=100"`
|
||||
PurchaseDate *time.Time `json:"purchase_date"`
|
||||
ExpiryDate *time.Time `json:"expiry_date"`
|
||||
@@ -25,17 +25,17 @@ type CreateDocumentRequest struct {
|
||||
SerialNumber string `json:"serial_number" validate:"max=100"`
|
||||
ModelNumber string `json:"model_number" validate:"max=100"`
|
||||
TaskID *uint `json:"task_id"`
|
||||
ImageURLs []string `json:"image_urls"` // Multiple image URLs
|
||||
ImageURLs []string `json:"image_urls" validate:"omitempty,max=20,dive,max=500"` // Multiple image URLs
|
||||
}
|
||||
|
||||
// UpdateDocumentRequest represents the request to update a document
|
||||
type UpdateDocumentRequest struct {
|
||||
Title *string `json:"title" validate:"omitempty,min=1,max=200"`
|
||||
Description *string `json:"description"`
|
||||
DocumentType *models.DocumentType `json:"document_type"`
|
||||
Description *string `json:"description" validate:"omitempty,max=10000"`
|
||||
DocumentType *models.DocumentType `json:"document_type" validate:"omitempty,oneof=general warranty receipt contract insurance manual"`
|
||||
FileURL *string `json:"file_url" validate:"omitempty,max=500"`
|
||||
FileName *string `json:"file_name" validate:"omitempty,max=255"`
|
||||
FileSize *int64 `json:"file_size"`
|
||||
FileSize *int64 `json:"file_size" validate:"omitempty,min=0"`
|
||||
MimeType *string `json:"mime_type" validate:"omitempty,max=100"`
|
||||
PurchaseDate *time.Time `json:"purchase_date"`
|
||||
ExpiryDate *time.Time `json:"expiry_date"`
|
||||
|
||||
@@ -16,12 +16,12 @@ type CreateResidenceRequest struct {
|
||||
StateProvince string `json:"state_province" validate:"max=100"`
|
||||
PostalCode string `json:"postal_code" validate:"max=20"`
|
||||
Country string `json:"country" validate:"max=100"`
|
||||
Bedrooms *int `json:"bedrooms"`
|
||||
Bedrooms *int `json:"bedrooms" validate:"omitempty,min=0"`
|
||||
Bathrooms *decimal.Decimal `json:"bathrooms"`
|
||||
SquareFootage *int `json:"square_footage"`
|
||||
SquareFootage *int `json:"square_footage" validate:"omitempty,min=0"`
|
||||
LotSize *decimal.Decimal `json:"lot_size"`
|
||||
YearBuilt *int `json:"year_built"`
|
||||
Description string `json:"description"`
|
||||
Description string `json:"description" validate:"max=10000"`
|
||||
PurchaseDate *time.Time `json:"purchase_date"`
|
||||
PurchasePrice *decimal.Decimal `json:"purchase_price"`
|
||||
IsPrimary *bool `json:"is_primary"`
|
||||
@@ -37,12 +37,12 @@ type UpdateResidenceRequest struct {
|
||||
StateProvince *string `json:"state_province" validate:"omitempty,max=100"`
|
||||
PostalCode *string `json:"postal_code" validate:"omitempty,max=20"`
|
||||
Country *string `json:"country" validate:"omitempty,max=100"`
|
||||
Bedrooms *int `json:"bedrooms"`
|
||||
Bedrooms *int `json:"bedrooms" validate:"omitempty,min=0"`
|
||||
Bathrooms *decimal.Decimal `json:"bathrooms"`
|
||||
SquareFootage *int `json:"square_footage"`
|
||||
SquareFootage *int `json:"square_footage" validate:"omitempty,min=0"`
|
||||
LotSize *decimal.Decimal `json:"lot_size"`
|
||||
YearBuilt *int `json:"year_built"`
|
||||
Description *string `json:"description"`
|
||||
Description *string `json:"description" validate:"omitempty,max=10000"`
|
||||
PurchaseDate *time.Time `json:"purchase_date"`
|
||||
PurchasePrice *decimal.Decimal `json:"purchase_price"`
|
||||
IsPrimary *bool `json:"is_primary"`
|
||||
@@ -55,5 +55,5 @@ type JoinWithCodeRequest struct {
|
||||
|
||||
// GenerateShareCodeRequest represents the request to generate a share code
|
||||
type GenerateShareCodeRequest struct {
|
||||
ExpiresInHours int `json:"expires_in_hours"` // Default: 24 hours
|
||||
ExpiresInHours int `json:"expires_in_hours" validate:"omitempty,min=1"` // Default: 24 hours
|
||||
}
|
||||
|
||||
@@ -56,11 +56,11 @@ func (fd *FlexibleDate) ToTimePtr() *time.Time {
|
||||
type CreateTaskRequest struct {
|
||||
ResidenceID uint `json:"residence_id" validate:"required"`
|
||||
Title string `json:"title" validate:"required,min=1,max=200"`
|
||||
Description string `json:"description"`
|
||||
Description string `json:"description" validate:"max=10000"`
|
||||
CategoryID *uint `json:"category_id"`
|
||||
PriorityID *uint `json:"priority_id"`
|
||||
FrequencyID *uint `json:"frequency_id"`
|
||||
CustomIntervalDays *int `json:"custom_interval_days"` // For "Custom" frequency, user-specified days
|
||||
CustomIntervalDays *int `json:"custom_interval_days" validate:"omitempty,min=1"` // For "Custom" frequency, user-specified days
|
||||
InProgress bool `json:"in_progress"`
|
||||
AssignedToID *uint `json:"assigned_to_id"`
|
||||
DueDate *FlexibleDate `json:"due_date"`
|
||||
@@ -75,7 +75,7 @@ type UpdateTaskRequest struct {
|
||||
CategoryID *uint `json:"category_id"`
|
||||
PriorityID *uint `json:"priority_id"`
|
||||
FrequencyID *uint `json:"frequency_id"`
|
||||
CustomIntervalDays *int `json:"custom_interval_days"` // For "Custom" frequency, user-specified days
|
||||
CustomIntervalDays *int `json:"custom_interval_days" validate:"omitempty,min=1"` // For "Custom" frequency, user-specified days
|
||||
InProgress *bool `json:"in_progress"`
|
||||
AssignedToID *uint `json:"assigned_to_id"`
|
||||
DueDate *FlexibleDate `json:"due_date"`
|
||||
@@ -88,18 +88,18 @@ type UpdateTaskRequest struct {
|
||||
type CreateTaskCompletionRequest struct {
|
||||
TaskID uint `json:"task_id" validate:"required"`
|
||||
CompletedAt *time.Time `json:"completed_at"` // Defaults to now
|
||||
Notes string `json:"notes"`
|
||||
Notes string `json:"notes" validate:"max=10000"`
|
||||
ActualCost *decimal.Decimal `json:"actual_cost"`
|
||||
Rating *int `json:"rating"` // 1-5 star rating
|
||||
ImageURLs []string `json:"image_urls"` // Multiple image URLs
|
||||
Rating *int `json:"rating" validate:"omitempty,min=1,max=5"` // 1-5 star rating
|
||||
ImageURLs []string `json:"image_urls" validate:"omitempty,max=20,dive,max=500"` // Multiple image URLs
|
||||
}
|
||||
|
||||
// UpdateTaskCompletionRequest represents the request to update a task completion
|
||||
type UpdateTaskCompletionRequest struct {
|
||||
Notes *string `json:"notes"`
|
||||
Notes *string `json:"notes" validate:"omitempty,max=10000"`
|
||||
ActualCost *decimal.Decimal `json:"actual_cost"`
|
||||
Rating *int `json:"rating"`
|
||||
ImageURLs []string `json:"image_urls"`
|
||||
Rating *int `json:"rating" validate:"omitempty,min=1,max=5"`
|
||||
ImageURLs []string `json:"image_urls" validate:"omitempty,max=20,dive,max=500"`
|
||||
}
|
||||
|
||||
// CompletionImageInput represents an image to add to a completion
|
||||
|
||||
@@ -81,6 +81,11 @@ func (h *AuthHandler) Register(c echo.Context) error {
|
||||
// Send welcome email with confirmation code (async)
|
||||
if h.emailService != nil && confirmationCode != "" {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", req.Email).Msg("Panic in welcome email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendWelcomeEmail(req.Email, req.FirstName, confirmationCode); err != nil {
|
||||
log.Error().Err(err).Str("email", req.Email).Msg("Failed to send welcome email")
|
||||
}
|
||||
@@ -176,6 +181,11 @@ func (h *AuthHandler) VerifyEmail(c echo.Context) error {
|
||||
// Send post-verification welcome email with tips (async)
|
||||
if h.emailService != nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in post-verification email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName); err != nil {
|
||||
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send post-verification email")
|
||||
}
|
||||
@@ -204,6 +214,11 @@ func (h *AuthHandler) ResendVerification(c echo.Context) error {
|
||||
// Send verification email (async)
|
||||
if h.emailService != nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in verification email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendVerificationEmail(user.Email, user.FirstName, code); err != nil {
|
||||
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send verification email")
|
||||
}
|
||||
@@ -238,6 +253,11 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
|
||||
// Send password reset email (async) - only if user found
|
||||
if h.emailService != nil && code != "" && user != nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in password reset email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); err != nil {
|
||||
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send password reset email")
|
||||
}
|
||||
@@ -326,6 +346,11 @@ func (h *AuthHandler) AppleSignIn(c echo.Context) error {
|
||||
// Send welcome email for new users (async)
|
||||
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Apple welcome email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendAppleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
|
||||
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Apple welcome email")
|
||||
}
|
||||
@@ -368,6 +393,11 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
|
||||
// Send welcome email for new users (async)
|
||||
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Google welcome email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendGoogleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
|
||||
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Google welcome email")
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -25,17 +24,23 @@ func NewContractorHandler(contractorService *services.ContractorService) *Contra
|
||||
|
||||
// ListContractors handles GET /api/contractors/
|
||||
func (h *ContractorHandler) ListContractors(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := h.contractorService.ListContractors(user.ID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetContractor handles GET /api/contractors/:id/
|
||||
func (h *ContractorHandler) GetContractor(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||
@@ -50,11 +55,17 @@ func (h *ContractorHandler) GetContractor(c echo.Context) error {
|
||||
|
||||
// CreateContractor handles POST /api/contractors/
|
||||
func (h *ContractorHandler) CreateContractor(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var req requests.CreateContractorRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.contractorService.CreateContractor(&req, user.ID)
|
||||
if err != nil {
|
||||
@@ -65,7 +76,10 @@ func (h *ContractorHandler) CreateContractor(c echo.Context) error {
|
||||
|
||||
// UpdateContractor handles PUT/PATCH /api/contractors/:id/
|
||||
func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||
@@ -75,6 +89,9 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.contractorService.UpdateContractor(uint(contractorID), user.ID, &req)
|
||||
if err != nil {
|
||||
@@ -85,7 +102,10 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
|
||||
|
||||
// DeleteContractor handles DELETE /api/contractors/:id/
|
||||
func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||
@@ -100,7 +120,10 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
|
||||
|
||||
// ToggleFavorite handles POST /api/contractors/:id/toggle-favorite/
|
||||
func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||
@@ -115,7 +138,10 @@ func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
|
||||
|
||||
// GetContractorTasks handles GET /api/contractors/:id/tasks/
|
||||
func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||
@@ -130,7 +156,10 @@ func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
|
||||
|
||||
// ListContractorsByResidence handles GET /api/contractors/by-residence/:residence_id/
|
||||
func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
residenceID, err := strconv.ParseUint(c.Param("residence_id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_residence_id")
|
||||
@@ -147,7 +176,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
|
||||
func (h *ContractorHandler) GetSpecialties(c echo.Context) error {
|
||||
specialties, err := h.contractorService.GetSpecialties()
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
return c.JSON(http.StatusOK, specialties)
|
||||
}
|
||||
|
||||
182
internal/handlers/contractor_handler_test.go
Normal file
182
internal/handlers/contractor_handler_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func setupContractorHandler(t *testing.T) (*ContractorHandler, *echo.Echo, *gorm.DB) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
contractorService := services.NewContractorService(contractorRepo, residenceRepo)
|
||||
handler := NewContractorHandler(contractorService)
|
||||
e := testutil.SetupTestRouter()
|
||||
return handler, e, db
|
||||
}
|
||||
|
||||
func TestContractorHandler_CreateContractor_MissingName_Returns400(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateContractor)
|
||||
|
||||
t.Run("missing name returns 400 validation error", func(t *testing.T) {
|
||||
// Send request with no name (required field)
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should contain structured validation error
|
||||
assert.Contains(t, response, "error")
|
||||
assert.Contains(t, response, "fields")
|
||||
|
||||
fields := response["fields"].(map[string]interface{})
|
||||
assert.Contains(t, fields, "name", "validation error should reference the 'name' field")
|
||||
})
|
||||
|
||||
t.Run("empty body returns 400 validation error", func(t *testing.T) {
|
||||
// Send completely empty body
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", map[string]interface{}{}, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "error")
|
||||
})
|
||||
|
||||
t.Run("valid contractor creation succeeds", func(t *testing.T) {
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "John the Plumber",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_ListContractors_Error_NoRawErrorInResponse(t *testing.T) {
|
||||
_, e, db := setupContractorHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create a handler with a broken service to simulate an internal error.
|
||||
// We do this by closing the underlying SQL connection, which will cause
|
||||
// the service to return an error on the next query.
|
||||
brokenDB := testutil.SetupTestDB(t)
|
||||
sqlDB, _ := brokenDB.DB()
|
||||
sqlDB.Close()
|
||||
|
||||
brokenContractorRepo := repositories.NewContractorRepository(brokenDB)
|
||||
brokenResidenceRepo := repositories.NewResidenceRepository(brokenDB)
|
||||
brokenService := services.NewContractorService(brokenContractorRepo, brokenResidenceRepo)
|
||||
brokenHandler := NewContractorHandler(brokenService)
|
||||
|
||||
authGroup := e.Group("/api/broken-contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/", brokenHandler.ListContractors)
|
||||
|
||||
t.Run("internal error does not leak raw error message", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/broken-contractors/", nil, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusInternalServerError)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should contain the generic error key, NOT a raw database error
|
||||
errorMsg, ok := response["error"].(string)
|
||||
require.True(t, ok, "response should have an 'error' string field")
|
||||
|
||||
// Must not contain database-specific details
|
||||
assert.NotContains(t, errorMsg, "sql", "error message should not leak SQL details")
|
||||
assert.NotContains(t, errorMsg, "database", "error message should not leak database details")
|
||||
assert.NotContains(t, errorMsg, "closed", "error message should not leak connection state")
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_CreateContractor_100Specialties_Returns400(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateContractor)
|
||||
|
||||
t.Run("too many specialties rejected", func(t *testing.T) {
|
||||
// Create a slice with 100 specialty IDs (exceeds max=20)
|
||||
specialtyIDs := make([]uint, 100)
|
||||
for i := range specialtyIDs {
|
||||
specialtyIDs[i] = uint(i + 1)
|
||||
}
|
||||
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Over-specialized Contractor",
|
||||
SpecialtyIDs: specialtyIDs,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("20 specialties accepted", func(t *testing.T) {
|
||||
specialtyIDs := make([]uint, 20)
|
||||
for i := range specialtyIDs {
|
||||
specialtyIDs[i] = uint(i + 1)
|
||||
}
|
||||
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Multi-skilled Contractor",
|
||||
SpecialtyIDs: specialtyIDs,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
// Should pass validation (201 or success, not 400)
|
||||
assert.NotEqual(t, http.StatusBadRequest, w.Code, "20 specialties should pass validation")
|
||||
})
|
||||
|
||||
t.Run("rating above 5 rejected", func(t *testing.T) {
|
||||
rating := 6.0
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Bad Rating Contractor",
|
||||
Rating: &rating,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
@@ -34,7 +34,10 @@ func NewDocumentHandler(documentService *services.DocumentService, storageServic
|
||||
|
||||
// ListDocuments handles GET /api/documents/
|
||||
func (h *DocumentHandler) ListDocuments(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build filter from supported query params.
|
||||
var filter *repositories.DocumentFilter
|
||||
@@ -71,7 +74,10 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error {
|
||||
|
||||
// GetDocument handles GET /api/documents/:id/
|
||||
func (h *DocumentHandler) GetDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -86,10 +92,13 @@ func (h *DocumentHandler) GetDocument(c echo.Context) error {
|
||||
|
||||
// ListWarranties handles GET /api/documents/warranties/
|
||||
func (h *DocumentHandler) ListWarranties(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := h.documentService.ListWarranties(user.ID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
@@ -97,7 +106,10 @@ func (h *DocumentHandler) ListWarranties(c echo.Context) error {
|
||||
// CreateDocument handles POST /api/documents/
|
||||
// Supports both JSON and multipart form data (for file uploads)
|
||||
func (h *DocumentHandler) CreateDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var req requests.CreateDocumentRequest
|
||||
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
@@ -198,6 +210,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.documentService.CreateDocument(&req, user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -207,7 +223,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
|
||||
|
||||
// UpdateDocument handles PUT/PATCH /api/documents/:id/
|
||||
func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -217,6 +236,9 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.documentService.UpdateDocument(uint(documentID), user.ID, &req)
|
||||
if err != nil {
|
||||
@@ -227,7 +249,10 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
|
||||
|
||||
// DeleteDocument handles DELETE /api/documents/:id/
|
||||
func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -242,7 +267,10 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
|
||||
|
||||
// ActivateDocument handles POST /api/documents/:id/activate/
|
||||
func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -257,7 +285,10 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
|
||||
|
||||
// DeactivateDocument handles POST /api/documents/:id/deactivate/
|
||||
func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -272,7 +303,10 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
|
||||
|
||||
// UploadDocumentImage handles POST /api/documents/:id/images/
|
||||
func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -316,7 +350,10 @@ func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
|
||||
|
||||
// DeleteDocumentImage handles DELETE /api/documents/:id/images/:imageId/
|
||||
func (h *DocumentHandler) DeleteDocumentImage(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -9,7 +8,6 @@ import (
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
@@ -40,7 +38,10 @@ func NewMediaHandler(
|
||||
// ServeDocument serves a document file with access control
|
||||
// GET /api/media/document/:id
|
||||
func (h *MediaHandler) ServeDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -73,7 +74,10 @@ func (h *MediaHandler) ServeDocument(c echo.Context) error {
|
||||
// ServeDocumentImage serves a document image with access control
|
||||
// GET /api/media/document-image/:id
|
||||
func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -111,7 +115,10 @@ func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
|
||||
// ServeCompletionImage serves a task completion image with access control
|
||||
// GET /api/media/completion-image/:id
|
||||
func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -152,7 +159,9 @@ func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
|
||||
return c.File(filePath)
|
||||
}
|
||||
|
||||
// resolveFilePath converts a stored URL to an actual file path
|
||||
// resolveFilePath converts a stored URL to an actual file path.
|
||||
// Returns empty string if the URL is empty or the resolved path would escape
|
||||
// the upload directory (path traversal attempt).
|
||||
func (h *MediaHandler) resolveFilePath(storedURL string) string {
|
||||
if storedURL == "" {
|
||||
return ""
|
||||
@@ -160,12 +169,18 @@ func (h *MediaHandler) resolveFilePath(storedURL string) string {
|
||||
|
||||
uploadDir := h.storageSvc.GetUploadDir()
|
||||
|
||||
// Handle legacy /uploads/... URLs
|
||||
// Strip legacy /uploads/ prefix to get relative path
|
||||
relativePath := storedURL
|
||||
if strings.HasPrefix(storedURL, "/uploads/") {
|
||||
relativePath := strings.TrimPrefix(storedURL, "/uploads/")
|
||||
return filepath.Join(uploadDir, relativePath)
|
||||
relativePath = strings.TrimPrefix(storedURL, "/uploads/")
|
||||
}
|
||||
|
||||
// Handle relative paths (new format)
|
||||
return filepath.Join(uploadDir, storedURL)
|
||||
// Use SafeResolvePath to validate containment within upload directory
|
||||
resolved, err := services.SafeResolvePath(uploadDir, relativePath)
|
||||
if err != nil {
|
||||
// Path traversal or invalid path — return empty to signal file not found
|
||||
return ""
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
74
internal/handlers/media_handler_test.go
Normal file
74
internal/handlers/media_handler_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
// newTestStorageService creates a StorageService with a known upload directory for testing.
|
||||
// It does NOT call NewStorageService because that creates directories on disk.
|
||||
// Instead, it directly constructs the struct with only what resolveFilePath needs.
|
||||
func newTestStorageService(uploadDir string) *services.StorageService {
|
||||
cfg := &config.StorageConfig{
|
||||
UploadDir: uploadDir,
|
||||
BaseURL: "/uploads",
|
||||
MaxFileSize: 10 * 1024 * 1024,
|
||||
AllowedTypes: "image/jpeg,image/png",
|
||||
}
|
||||
// Use the exported constructor helper that skips directory creation (for tests)
|
||||
return services.NewStorageServiceForTest(cfg)
|
||||
}
|
||||
|
||||
func TestResolveFilePath_NormalPath_Works(t *testing.T) {
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
h := NewMediaHandler(nil, nil, nil, storageSvc)
|
||||
|
||||
result := h.resolveFilePath("images/photo.jpg")
|
||||
require.NotEmpty(t, result)
|
||||
assert.Equal(t, "/var/uploads/images/photo.jpg", result)
|
||||
}
|
||||
|
||||
func TestResolveFilePath_LegacyUploadPath_Works(t *testing.T) {
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
h := NewMediaHandler(nil, nil, nil, storageSvc)
|
||||
|
||||
result := h.resolveFilePath("/uploads/images/photo.jpg")
|
||||
require.NotEmpty(t, result)
|
||||
assert.Equal(t, "/var/uploads/images/photo.jpg", result)
|
||||
}
|
||||
|
||||
func TestResolveFilePath_DotDotTraversal_Blocked(t *testing.T) {
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
h := NewMediaHandler(nil, nil, nil, storageSvc)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
storedURL string
|
||||
}{
|
||||
{"simple dotdot", "../etc/passwd"},
|
||||
{"nested dotdot", "../../etc/shadow"},
|
||||
{"embedded dotdot", "images/../../etc/passwd"},
|
||||
{"legacy prefix with dotdot", "/uploads/../../../etc/passwd"},
|
||||
{"deep dotdot", "a/b/c/../../../../etc/passwd"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := h.resolveFilePath(tt.storedURL)
|
||||
assert.Empty(t, result, "path traversal should return empty string for: %s", tt.storedURL)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveFilePath_EmptyURL_ReturnsEmpty(t *testing.T) {
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
h := NewMediaHandler(nil, nil, nil, storageSvc)
|
||||
|
||||
result := h.resolveFilePath("")
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
334
internal/handlers/noauth_test.go
Normal file
334
internal/handlers/noauth_test.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
// TestTaskHandler_NoAuth_Returns401 verifies that task handler endpoints return
|
||||
// 401 Unauthorized when no auth user is set in the context (e.g., auth middleware
|
||||
// misconfigured or bypassed). This is a regression test for P1-1 (SEC-19).
|
||||
func TestTaskHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskService := services.NewTaskService(taskRepo, residenceRepo)
|
||||
handler := NewTaskHandler(taskService, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/tasks/", handler.ListTasks)
|
||||
e.GET("/api/tasks/:id/", handler.GetTask)
|
||||
e.POST("/api/tasks/", handler.CreateTask)
|
||||
e.PUT("/api/tasks/:id/", handler.UpdateTask)
|
||||
e.DELETE("/api/tasks/:id/", handler.DeleteTask)
|
||||
e.POST("/api/tasks/:id/cancel/", handler.CancelTask)
|
||||
e.POST("/api/tasks/:id/mark-in-progress/", handler.MarkInProgress)
|
||||
e.GET("/api/task-completions/", handler.ListCompletions)
|
||||
e.POST("/api/task-completions/", handler.CreateCompletion)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListTasks", "GET", "/api/tasks/"},
|
||||
{"GetTask", "GET", "/api/tasks/1/"},
|
||||
{"CreateTask", "POST", "/api/tasks/"},
|
||||
{"UpdateTask", "PUT", "/api/tasks/1/"},
|
||||
{"DeleteTask", "DELETE", "/api/tasks/1/"},
|
||||
{"CancelTask", "POST", "/api/tasks/1/cancel/"},
|
||||
{"MarkInProgress", "POST", "/api/tasks/1/mark-in-progress/"},
|
||||
{"ListCompletions", "GET", "/api/task-completions/"},
|
||||
{"CreateCompletion", "POST", "/api/task-completions/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResidenceHandler_NoAuth_Returns401 verifies that residence handler endpoints
|
||||
// return 401 Unauthorized when no auth user is set in the context.
|
||||
func TestResidenceHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
handler := NewResidenceHandler(residenceService, nil, nil, true)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/residences/", handler.ListResidences)
|
||||
e.GET("/api/residences/my/", handler.GetMyResidences)
|
||||
e.GET("/api/residences/summary/", handler.GetSummary)
|
||||
e.GET("/api/residences/:id/", handler.GetResidence)
|
||||
e.POST("/api/residences/", handler.CreateResidence)
|
||||
e.PUT("/api/residences/:id/", handler.UpdateResidence)
|
||||
e.DELETE("/api/residences/:id/", handler.DeleteResidence)
|
||||
e.POST("/api/residences/:id/generate-share-code/", handler.GenerateShareCode)
|
||||
e.POST("/api/residences/join-with-code/", handler.JoinWithCode)
|
||||
e.GET("/api/residences/:id/users/", handler.GetResidenceUsers)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListResidences", "GET", "/api/residences/"},
|
||||
{"GetMyResidences", "GET", "/api/residences/my/"},
|
||||
{"GetSummary", "GET", "/api/residences/summary/"},
|
||||
{"GetResidence", "GET", "/api/residences/1/"},
|
||||
{"CreateResidence", "POST", "/api/residences/"},
|
||||
{"UpdateResidence", "PUT", "/api/residences/1/"},
|
||||
{"DeleteResidence", "DELETE", "/api/residences/1/"},
|
||||
{"GenerateShareCode", "POST", "/api/residences/1/generate-share-code/"},
|
||||
{"JoinWithCode", "POST", "/api/residences/join-with-code/"},
|
||||
{"GetResidenceUsers", "GET", "/api/residences/1/users/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationHandler_NoAuth_Returns401 verifies that notification handler
|
||||
// endpoints return 401 Unauthorized when no auth user is set in the context.
|
||||
func TestNotificationHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
notificationRepo := repositories.NewNotificationRepository(db)
|
||||
notificationService := services.NewNotificationService(notificationRepo, nil)
|
||||
handler := NewNotificationHandler(notificationService)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/notifications/", handler.ListNotifications)
|
||||
e.GET("/api/notifications/unread-count/", handler.GetUnreadCount)
|
||||
e.POST("/api/notifications/:id/read/", handler.MarkAsRead)
|
||||
e.POST("/api/notifications/mark-all-read/", handler.MarkAllAsRead)
|
||||
e.GET("/api/notifications/preferences/", handler.GetPreferences)
|
||||
e.PUT("/api/notifications/preferences/", handler.UpdatePreferences)
|
||||
e.POST("/api/notifications/devices/", handler.RegisterDevice)
|
||||
e.GET("/api/notifications/devices/", handler.ListDevices)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListNotifications", "GET", "/api/notifications/"},
|
||||
{"GetUnreadCount", "GET", "/api/notifications/unread-count/"},
|
||||
{"MarkAsRead", "POST", "/api/notifications/1/read/"},
|
||||
{"MarkAllAsRead", "POST", "/api/notifications/mark-all-read/"},
|
||||
{"GetPreferences", "GET", "/api/notifications/preferences/"},
|
||||
{"UpdatePreferences", "PUT", "/api/notifications/preferences/"},
|
||||
{"RegisterDevice", "POST", "/api/notifications/devices/"},
|
||||
{"ListDevices", "GET", "/api/notifications/devices/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentHandler_NoAuth_Returns401 verifies that document handler endpoints
|
||||
// return 401 Unauthorized when no auth user is set in the context.
|
||||
func TestDocumentHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
documentService := services.NewDocumentService(documentRepo, residenceRepo)
|
||||
handler := NewDocumentHandler(documentService, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/documents/", handler.ListDocuments)
|
||||
e.GET("/api/documents/:id/", handler.GetDocument)
|
||||
e.GET("/api/documents/warranties/", handler.ListWarranties)
|
||||
e.POST("/api/documents/", handler.CreateDocument)
|
||||
e.PUT("/api/documents/:id/", handler.UpdateDocument)
|
||||
e.DELETE("/api/documents/:id/", handler.DeleteDocument)
|
||||
e.POST("/api/documents/:id/activate/", handler.ActivateDocument)
|
||||
e.POST("/api/documents/:id/deactivate/", handler.DeactivateDocument)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListDocuments", "GET", "/api/documents/"},
|
||||
{"GetDocument", "GET", "/api/documents/1/"},
|
||||
{"ListWarranties", "GET", "/api/documents/warranties/"},
|
||||
{"CreateDocument", "POST", "/api/documents/"},
|
||||
{"UpdateDocument", "PUT", "/api/documents/1/"},
|
||||
{"DeleteDocument", "DELETE", "/api/documents/1/"},
|
||||
{"ActivateDocument", "POST", "/api/documents/1/activate/"},
|
||||
{"DeactivateDocument", "POST", "/api/documents/1/deactivate/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestContractorHandler_NoAuth_Returns401 verifies that contractor handler endpoints
|
||||
// return 401 Unauthorized when no auth user is set in the context.
|
||||
func TestContractorHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
contractorService := services.NewContractorService(contractorRepo, residenceRepo)
|
||||
handler := NewContractorHandler(contractorService)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/contractors/", handler.ListContractors)
|
||||
e.GET("/api/contractors/:id/", handler.GetContractor)
|
||||
e.POST("/api/contractors/", handler.CreateContractor)
|
||||
e.PUT("/api/contractors/:id/", handler.UpdateContractor)
|
||||
e.DELETE("/api/contractors/:id/", handler.DeleteContractor)
|
||||
e.POST("/api/contractors/:id/toggle-favorite/", handler.ToggleFavorite)
|
||||
e.GET("/api/contractors/:id/tasks/", handler.GetContractorTasks)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListContractors", "GET", "/api/contractors/"},
|
||||
{"GetContractor", "GET", "/api/contractors/1/"},
|
||||
{"CreateContractor", "POST", "/api/contractors/"},
|
||||
{"UpdateContractor", "PUT", "/api/contractors/1/"},
|
||||
{"DeleteContractor", "DELETE", "/api/contractors/1/"},
|
||||
{"ToggleFavorite", "POST", "/api/contractors/1/toggle-favorite/"},
|
||||
{"GetContractorTasks", "GET", "/api/contractors/1/tasks/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubscriptionHandler_NoAuth_Returns401 verifies that subscription handler
|
||||
// endpoints return 401 Unauthorized when no auth user is set in the context.
|
||||
func TestSubscriptionHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
|
||||
handler := NewSubscriptionHandler(subscriptionService)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/subscription/", handler.GetSubscription)
|
||||
e.GET("/api/subscription/status/", handler.GetSubscriptionStatus)
|
||||
e.GET("/api/subscription/promotions/", handler.GetPromotions)
|
||||
e.POST("/api/subscription/purchase/", handler.ProcessPurchase)
|
||||
e.POST("/api/subscription/cancel/", handler.CancelSubscription)
|
||||
e.POST("/api/subscription/restore/", handler.RestoreSubscription)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"GetSubscription", "GET", "/api/subscription/"},
|
||||
{"GetSubscriptionStatus", "GET", "/api/subscription/status/"},
|
||||
{"GetPromotions", "GET", "/api/subscription/promotions/"},
|
||||
{"ProcessPurchase", "POST", "/api/subscription/purchase/"},
|
||||
{"CancelSubscription", "POST", "/api/subscription/cancel/"},
|
||||
{"RestoreSubscription", "POST", "/api/subscription/restore/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMediaHandler_NoAuth_Returns401 verifies that media handler endpoints return
|
||||
// 401 Unauthorized when no auth user is set in the context.
|
||||
func TestMediaHandler_NoAuth_Returns401(t *testing.T) {
|
||||
handler := NewMediaHandler(nil, nil, nil, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/media/document/:id", handler.ServeDocument)
|
||||
e.GET("/api/media/document-image/:id", handler.ServeDocumentImage)
|
||||
e.GET("/api/media/completion-image/:id", handler.ServeCompletionImage)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ServeDocument", "GET", "/api/media/document/1"},
|
||||
{"ServeDocumentImage", "GET", "/api/media/document-image/1"},
|
||||
{"ServeCompletionImage", "GET", "/api/media/completion-image/1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserHandler_NoAuth_Returns401 verifies that user handler endpoints return
|
||||
// 401 Unauthorized when no auth user is set in the context.
|
||||
func TestUserHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
userService := services.NewUserService(userRepo)
|
||||
handler := NewUserHandler(userService)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/users/", handler.ListUsers)
|
||||
e.GET("/api/users/:id/", handler.GetUser)
|
||||
e.GET("/api/users/profiles/", handler.ListProfiles)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListUsers", "GET", "/api/users/"},
|
||||
{"GetUser", "GET", "/api/users/1/"},
|
||||
{"ListProfiles", "GET", "/api/users/profiles/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -24,7 +23,10 @@ func NewNotificationHandler(notificationService *services.NotificationService) *
|
||||
|
||||
// ListNotifications handles GET /api/notifications/
|
||||
func (h *NotificationHandler) ListNotifications(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
limit := 50
|
||||
offset := 0
|
||||
@@ -33,6 +35,9 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
if limit > 200 {
|
||||
limit = 200
|
||||
}
|
||||
if o := c.QueryParam("offset"); o != "" {
|
||||
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
|
||||
offset = parsed
|
||||
@@ -52,7 +57,10 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
|
||||
|
||||
// GetUnreadCount handles GET /api/notifications/unread-count/
|
||||
func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := h.notificationService.GetUnreadCount(user.ID)
|
||||
if err != nil {
|
||||
@@ -64,7 +72,10 @@ func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
|
||||
|
||||
// MarkAsRead handles POST /api/notifications/:id/read/
|
||||
func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
notificationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -81,9 +92,12 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
|
||||
|
||||
// MarkAllAsRead handles POST /api/notifications/mark-all-read/
|
||||
func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := h.notificationService.MarkAllAsRead(user.ID)
|
||||
err = h.notificationService.MarkAllAsRead(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -93,7 +107,10 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
|
||||
|
||||
// GetPreferences handles GET /api/notifications/preferences/
|
||||
func (h *NotificationHandler) GetPreferences(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefs, err := h.notificationService.GetPreferences(user.ID)
|
||||
if err != nil {
|
||||
@@ -105,12 +122,18 @@ func (h *NotificationHandler) GetPreferences(c echo.Context) error {
|
||||
|
||||
// UpdatePreferences handles PUT/PATCH /api/notifications/preferences/
|
||||
func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req services.UpdatePreferencesRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefs, err := h.notificationService.UpdatePreferences(user.ID, &req)
|
||||
if err != nil {
|
||||
@@ -122,12 +145,18 @@ func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
|
||||
|
||||
// RegisterDevice handles POST /api/notifications/devices/
|
||||
func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req services.RegisterDeviceRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
device, err := h.notificationService.RegisterDevice(user.ID, &req)
|
||||
if err != nil {
|
||||
@@ -139,7 +168,10 @@ func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
|
||||
|
||||
// ListDevices handles GET /api/notifications/devices/
|
||||
func (h *NotificationHandler) ListDevices(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
devices, err := h.notificationService.ListDevices(user.ID)
|
||||
if err != nil {
|
||||
@@ -152,7 +184,10 @@ func (h *NotificationHandler) ListDevices(c echo.Context) error {
|
||||
// UnregisterDevice handles POST /api/notifications/devices/unregister/
|
||||
// Accepts {registration_id, platform} and deactivates the matching device
|
||||
func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req struct {
|
||||
RegistrationID string `json:"registration_id"`
|
||||
@@ -168,7 +203,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
|
||||
req.Platform = "ios" // Default to iOS
|
||||
}
|
||||
|
||||
err := h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
|
||||
err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -178,7 +213,10 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
|
||||
|
||||
// DeleteDevice handles DELETE /api/notifications/devices/:id/
|
||||
func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deviceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
|
||||
88
internal/handlers/notification_handler_test.go
Normal file
88
internal/handlers/notification_handler_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func setupNotificationHandler(t *testing.T) (*NotificationHandler, *echo.Echo, *gorm.DB) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
notifService := services.NewNotificationService(notifRepo, nil)
|
||||
handler := NewNotificationHandler(notifService)
|
||||
e := testutil.SetupTestRouter()
|
||||
return handler, e, db
|
||||
}
|
||||
|
||||
func createTestNotifications(t *testing.T, db *gorm.DB, userID uint, count int) {
|
||||
for i := 0; i < count; i++ {
|
||||
notif := &models.Notification{
|
||||
UserID: userID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: fmt.Sprintf("Test Notification %d", i+1),
|
||||
Body: fmt.Sprintf("Body %d", i+1),
|
||||
}
|
||||
err := db.Create(notif).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationHandler_ListNotifications_LimitCappedAt200(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// Create 210 notifications to exceed the cap
|
||||
createTestNotifications(t, db, user.ID, 210)
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/", handler.ListNotifications)
|
||||
|
||||
t.Run("limit is capped at 200 when user requests more", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=999", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
count := int(response["count"].(float64))
|
||||
assert.Equal(t, 200, count, "response should contain at most 200 notifications when limit exceeds cap")
|
||||
})
|
||||
|
||||
t.Run("limit below cap is respected", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=10", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
count := int(response["count"].(float64))
|
||||
assert.Equal(t, 10, count, "response should respect limit when below cap")
|
||||
})
|
||||
|
||||
t.Run("default limit is used when no limit param", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
count := int(response["count"].(float64))
|
||||
assert.Equal(t, 50, count, "response should use default limit of 50")
|
||||
})
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/i18n"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
"github.com/treytartt/casera-api/internal/validator"
|
||||
)
|
||||
@@ -35,7 +34,10 @@ func NewResidenceHandler(residenceService *services.ResidenceService, pdfService
|
||||
|
||||
// ListResidences handles GET /api/residences/
|
||||
func (h *ResidenceHandler) ListResidences(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.residenceService.ListResidences(user.ID)
|
||||
if err != nil {
|
||||
@@ -47,7 +49,10 @@ func (h *ResidenceHandler) ListResidences(c echo.Context) error {
|
||||
|
||||
// GetMyResidences handles GET /api/residences/my-residences/
|
||||
func (h *ResidenceHandler) GetMyResidences(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
response, err := h.residenceService.GetMyResidences(user.ID, userNow)
|
||||
@@ -61,7 +66,10 @@ func (h *ResidenceHandler) GetMyResidences(c echo.Context) error {
|
||||
// GetSummary handles GET /api/residences/summary/
|
||||
// Returns just the task statistics summary without full residence data
|
||||
func (h *ResidenceHandler) GetSummary(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
summary, err := h.residenceService.GetSummary(user.ID, userNow)
|
||||
@@ -74,7 +82,10 @@ func (h *ResidenceHandler) GetSummary(c echo.Context) error {
|
||||
|
||||
// GetResidence handles GET /api/residences/:id/
|
||||
func (h *ResidenceHandler) GetResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -91,7 +102,10 @@ func (h *ResidenceHandler) GetResidence(c echo.Context) error {
|
||||
|
||||
// CreateResidence handles POST /api/residences/
|
||||
func (h *ResidenceHandler) CreateResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req requests.CreateResidenceRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
@@ -111,7 +125,10 @@ func (h *ResidenceHandler) CreateResidence(c echo.Context) error {
|
||||
|
||||
// UpdateResidence handles PUT/PATCH /api/residences/:id/
|
||||
func (h *ResidenceHandler) UpdateResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -136,7 +153,10 @@ func (h *ResidenceHandler) UpdateResidence(c echo.Context) error {
|
||||
|
||||
// DeleteResidence handles DELETE /api/residences/:id/
|
||||
func (h *ResidenceHandler) DeleteResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -154,7 +174,10 @@ func (h *ResidenceHandler) DeleteResidence(c echo.Context) error {
|
||||
// GetShareCode handles GET /api/residences/:id/share-code/
|
||||
// Returns the active share code for a residence, or null if none exists
|
||||
func (h *ResidenceHandler) GetShareCode(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -175,7 +198,10 @@ func (h *ResidenceHandler) GetShareCode(c echo.Context) error {
|
||||
|
||||
// GenerateShareCode handles POST /api/residences/:id/generate-share-code/
|
||||
func (h *ResidenceHandler) GenerateShareCode(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -197,7 +223,10 @@ func (h *ResidenceHandler) GenerateShareCode(c echo.Context) error {
|
||||
// GenerateSharePackage handles POST /api/residences/:id/generate-share-package/
|
||||
// Returns a share code with metadata for creating a .casera package file
|
||||
func (h *ResidenceHandler) GenerateSharePackage(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -218,12 +247,18 @@ func (h *ResidenceHandler) GenerateSharePackage(c echo.Context) error {
|
||||
|
||||
// JoinWithCode handles POST /api/residences/join-with-code/
|
||||
func (h *ResidenceHandler) JoinWithCode(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req requests.JoinWithCodeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.residenceService.JoinWithCode(req.Code, user.ID)
|
||||
if err != nil {
|
||||
@@ -235,7 +270,10 @@ func (h *ResidenceHandler) JoinWithCode(c echo.Context) error {
|
||||
|
||||
// GetResidenceUsers handles GET /api/residences/:id/users/
|
||||
func (h *ResidenceHandler) GetResidenceUsers(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -252,7 +290,10 @@ func (h *ResidenceHandler) GetResidenceUsers(c echo.Context) error {
|
||||
|
||||
// RemoveResidenceUser handles DELETE /api/residences/:id/users/:user_id/
|
||||
func (h *ResidenceHandler) RemoveResidenceUser(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -289,7 +330,10 @@ func (h *ResidenceHandler) GenerateTasksReport(c echo.Context) error {
|
||||
return apperrors.BadRequest("error.feature_disabled")
|
||||
}
|
||||
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
|
||||
@@ -525,3 +525,45 @@ func TestResidenceHandler_JSONResponses(t *testing.T) {
|
||||
assert.IsType(t, []map[string]interface{}{}, response)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResidenceHandler_CreateResidence_NegativeBedrooms_Returns400(t *testing.T) {
|
||||
handler, e, db := setupResidenceHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
authGroup := e.Group("/api/residences")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateResidence)
|
||||
|
||||
t.Run("negative bedrooms rejected", func(t *testing.T) {
|
||||
bedrooms := -1
|
||||
req := requests.CreateResidenceRequest{
|
||||
Name: "Bad House",
|
||||
Bedrooms: &bedrooms,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("negative square footage rejected", func(t *testing.T) {
|
||||
sqft := -100
|
||||
req := requests.CreateResidenceRequest{
|
||||
Name: "Bad House",
|
||||
SquareFootage: &sqft,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("zero bedrooms accepted", func(t *testing.T) {
|
||||
bedrooms := 0
|
||||
req := requests.CreateResidenceRequest{
|
||||
Name: "Studio Apartment",
|
||||
Bedrooms: &bedrooms,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -23,7 +22,10 @@ func NewSubscriptionHandler(subscriptionService *services.SubscriptionService) *
|
||||
|
||||
// GetSubscription handles GET /api/subscription/
|
||||
func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.GetSubscription(user.ID)
|
||||
if err != nil {
|
||||
@@ -35,7 +37,10 @@ func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
|
||||
|
||||
// GetSubscriptionStatus handles GET /api/subscription/status/
|
||||
func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := h.subscriptionService.GetSubscriptionStatus(user.ID)
|
||||
if err != nil {
|
||||
@@ -79,7 +84,10 @@ func (h *SubscriptionHandler) GetFeatureBenefits(c echo.Context) error {
|
||||
|
||||
// GetPromotions handles GET /api/subscription/promotions/
|
||||
func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
promotions, err := h.subscriptionService.GetActivePromotions(user.ID)
|
||||
if err != nil {
|
||||
@@ -91,15 +99,20 @@ func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
|
||||
|
||||
// ProcessPurchase handles POST /api/subscription/purchase/
|
||||
func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req services.ProcessPurchaseRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var subscription *services.SubscriptionResponse
|
||||
var err error
|
||||
|
||||
switch req.Platform {
|
||||
case "ios":
|
||||
@@ -129,7 +142,10 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
|
||||
|
||||
// CancelSubscription handles POST /api/subscription/cancel/
|
||||
func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.CancelSubscription(user.ID)
|
||||
if err != nil {
|
||||
@@ -144,16 +160,21 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
|
||||
|
||||
// RestoreSubscription handles POST /api/subscription/restore/
|
||||
func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req services.ProcessPurchaseRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Same logic as ProcessPurchase - validates receipt/token and restores
|
||||
var subscription *services.SubscriptionResponse
|
||||
var err error
|
||||
|
||||
switch req.Platform {
|
||||
case "ios":
|
||||
|
||||
@@ -8,14 +8,14 @@ import (
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
@@ -101,40 +101,39 @@ type AppleRenewalInfo struct {
|
||||
// HandleAppleWebhook handles POST /api/subscription/webhook/apple/
|
||||
func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
if !h.enabled {
|
||||
log.Printf("Apple Webhook: webhooks disabled by feature flag")
|
||||
log.Info().Msg("Apple Webhook: webhooks disabled by feature flag")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
log.Printf("Apple Webhook: Failed to read body: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to read body")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
||||
}
|
||||
|
||||
var payload AppleNotificationPayload
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
log.Printf("Apple Webhook: Failed to parse payload: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to parse payload")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid payload"})
|
||||
}
|
||||
|
||||
// Decode and verify the signed payload (JWS)
|
||||
notification, err := h.decodeAppleSignedPayload(payload.SignedPayload)
|
||||
if err != nil {
|
||||
log.Printf("Apple Webhook: Failed to decode signed payload: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to decode signed payload")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid signed payload"})
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: Received %s (subtype: %s) for bundle %s",
|
||||
notification.NotificationType, notification.Subtype, notification.Data.BundleID)
|
||||
log.Info().Str("type", notification.NotificationType).Str("subtype", notification.Subtype).Str("bundle", notification.Data.BundleID).Msg("Apple Webhook: Received notification")
|
||||
|
||||
// Dedup check using notificationUUID
|
||||
if notification.NotificationUUID != "" {
|
||||
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("apple", notification.NotificationUUID)
|
||||
if err != nil {
|
||||
log.Printf("Apple Webhook: Failed to check dedup: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to check dedup")
|
||||
// Continue processing on dedup check failure (fail-open)
|
||||
} else if alreadyProcessed {
|
||||
log.Printf("Apple Webhook: Duplicate event %s, skipping", notification.NotificationUUID)
|
||||
log.Info().Str("uuid", notification.NotificationUUID).Msg("Apple Webhook: Duplicate event, skipping")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
|
||||
}
|
||||
}
|
||||
@@ -143,8 +142,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
cfg := config.Get()
|
||||
if cfg != nil && cfg.AppleIAP.BundleID != "" {
|
||||
if notification.Data.BundleID != cfg.AppleIAP.BundleID {
|
||||
log.Printf("Apple Webhook: Bundle ID mismatch: got %s, expected %s",
|
||||
notification.Data.BundleID, cfg.AppleIAP.BundleID)
|
||||
log.Warn().Str("got", notification.Data.BundleID).Str("expected", cfg.AppleIAP.BundleID).Msg("Apple Webhook: Bundle ID mismatch")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "bundle ID mismatch"})
|
||||
}
|
||||
}
|
||||
@@ -152,7 +150,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
// Decode transaction info
|
||||
transactionInfo, err := h.decodeAppleTransaction(notification.Data.SignedTransactionInfo)
|
||||
if err != nil {
|
||||
log.Printf("Apple Webhook: Failed to decode transaction: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to decode transaction")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid transaction info"})
|
||||
}
|
||||
|
||||
@@ -164,14 +162,14 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
|
||||
// Process the notification
|
||||
if err := h.processAppleNotification(notification, transactionInfo, renewalInfo); err != nil {
|
||||
log.Printf("Apple Webhook: Failed to process notification: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to process notification")
|
||||
// Still return 200 to prevent Apple from retrying
|
||||
}
|
||||
|
||||
// Record processed event for dedup
|
||||
if notification.NotificationUUID != "" {
|
||||
if err := h.webhookEventRepo.RecordEvent("apple", notification.NotificationUUID, notification.NotificationType, ""); err != nil {
|
||||
log.Printf("Apple Webhook: Failed to record event: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to record event")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,7 +177,8 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "received"})
|
||||
}
|
||||
|
||||
// decodeAppleSignedPayload decodes and verifies an Apple JWS payload
|
||||
// decodeAppleSignedPayload verifies and decodes an Apple JWS payload.
|
||||
// The JWS signature is verified before the payload is trusted.
|
||||
func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload string) (*AppleNotificationData, error) {
|
||||
// JWS format: header.payload.signature
|
||||
parts := strings.Split(signedPayload, ".")
|
||||
@@ -187,8 +186,11 @@ func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload stri
|
||||
return nil, fmt.Errorf("invalid JWS format")
|
||||
}
|
||||
|
||||
// Decode payload (we're trusting Apple's signature for now)
|
||||
// In production, you should verify the signature using Apple's root certificate
|
||||
// Verify the JWS signature before trusting the payload.
|
||||
if err := h.VerifyAppleSignature(signedPayload); err != nil {
|
||||
return nil, fmt.Errorf("Apple JWS signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode payload: %w", err)
|
||||
@@ -251,14 +253,12 @@ func (h *SubscriptionWebhookHandler) processAppleNotification(
|
||||
// Find user by stored receipt data (original transaction ID)
|
||||
user, err := h.findUserByAppleTransaction(transaction.OriginalTransactionID)
|
||||
if err != nil {
|
||||
log.Printf("Apple Webhook: Could not find user for transaction %s: %v",
|
||||
transaction.OriginalTransactionID, err)
|
||||
log.Warn().Err(err).Str("transaction_id", transaction.OriginalTransactionID).Msg("Apple Webhook: Could not find user for transaction")
|
||||
// Not an error - might be a transaction we don't track
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: Processing %s for user %d (product: %s)",
|
||||
notification.NotificationType, user.ID, transaction.ProductID)
|
||||
log.Info().Str("type", notification.NotificationType).Uint("user_id", user.ID).Str("product", transaction.ProductID).Msg("Apple Webhook: Processing notification")
|
||||
|
||||
switch notification.NotificationType {
|
||||
case "SUBSCRIBED":
|
||||
@@ -294,7 +294,7 @@ func (h *SubscriptionWebhookHandler) processAppleNotification(
|
||||
return h.handleAppleGracePeriodExpired(user.ID, transaction)
|
||||
|
||||
default:
|
||||
log.Printf("Apple Webhook: Unhandled notification type: %s", notification.NotificationType)
|
||||
log.Warn().Str("type", notification.NotificationType).Msg("Apple Webhook: Unhandled notification type")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -326,7 +326,7 @@ func (h *SubscriptionWebhookHandler) handleAppleSubscribed(userID uint, tx *Appl
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d subscribed, expires %v, autoRenew=%v", userID, expiresAt, autoRenew)
|
||||
log.Info().Uint("user_id", userID).Time("expires", expiresAt).Bool("auto_renew", autoRenew).Msg("Apple Webhook: User subscribed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -337,7 +337,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewed(userID uint, tx *AppleTr
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d renewed, new expiry %v", userID, expiresAt)
|
||||
log.Info().Uint("user_id", userID).Time("expires", expiresAt).Msg("Apple Webhook: User renewed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -357,13 +357,13 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewalStatusChange(userID uint,
|
||||
if err := h.subscriptionRepo.SetCancelledAt(userID, now); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Apple Webhook: User %d turned off auto-renew, will expire at end of period", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User turned off auto-renew, will expire at end of period")
|
||||
} else {
|
||||
// User turned auto-renew back on
|
||||
if err := h.subscriptionRepo.ClearCancelledAt(userID); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Apple Webhook: User %d turned auto-renew back on", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User turned auto-renew back on")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -371,7 +371,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewalStatusChange(userID uint,
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleAppleFailedToRenew(userID uint, tx *AppleTransactionInfo, renewal *AppleRenewalInfo) error {
|
||||
// Subscription is in billing retry or grace period
|
||||
log.Printf("Apple Webhook: User %d failed to renew, may be in grace period", userID)
|
||||
log.Warn().Uint("user_id", userID).Msg("Apple Webhook: User failed to renew, may be in grace period")
|
||||
// Don't downgrade yet - Apple may retry billing
|
||||
return nil
|
||||
}
|
||||
@@ -381,7 +381,7 @@ func (h *SubscriptionWebhookHandler) handleAppleExpired(userID uint, tx *AppleTr
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d subscription expired, downgraded to free", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription expired, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -390,7 +390,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRefund(userID uint, tx *AppleTra
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d got refund, downgraded to free", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User got refund, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -399,7 +399,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRevoke(userID uint, tx *AppleTra
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d subscription revoked, downgraded to free", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription revoked, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -408,7 +408,7 @@ func (h *SubscriptionWebhookHandler) handleAppleGracePeriodExpired(userID uint,
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d grace period expired, downgraded to free", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User grace period expired, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -481,32 +481,32 @@ const (
|
||||
// HandleGoogleWebhook handles POST /api/subscription/webhook/google/
|
||||
func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
if !h.enabled {
|
||||
log.Printf("Google Webhook: webhooks disabled by feature flag")
|
||||
log.Info().Msg("Google Webhook: webhooks disabled by feature flag")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
log.Printf("Google Webhook: Failed to read body: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to read body")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
||||
}
|
||||
|
||||
var notification GoogleNotification
|
||||
if err := json.Unmarshal(body, ¬ification); err != nil {
|
||||
log.Printf("Google Webhook: Failed to parse notification: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to parse notification")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid notification"})
|
||||
}
|
||||
|
||||
// Decode the base64 data
|
||||
data, err := base64.StdEncoding.DecodeString(notification.Message.Data)
|
||||
if err != nil {
|
||||
log.Printf("Google Webhook: Failed to decode message data: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to decode message data")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid message data"})
|
||||
}
|
||||
|
||||
var devNotification GoogleDeveloperNotification
|
||||
if err := json.Unmarshal(data, &devNotification); err != nil {
|
||||
log.Printf("Google Webhook: Failed to parse developer notification: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to parse developer notification")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid developer notification"})
|
||||
}
|
||||
|
||||
@@ -515,17 +515,17 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
if messageID != "" {
|
||||
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("google", messageID)
|
||||
if err != nil {
|
||||
log.Printf("Google Webhook: Failed to check dedup: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to check dedup")
|
||||
// Continue processing on dedup check failure (fail-open)
|
||||
} else if alreadyProcessed {
|
||||
log.Printf("Google Webhook: Duplicate event %s, skipping", messageID)
|
||||
log.Info().Str("message_id", messageID).Msg("Google Webhook: Duplicate event, skipping")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle test notification
|
||||
if devNotification.TestNotification != nil {
|
||||
log.Printf("Google Webhook: Received test notification")
|
||||
log.Info().Msg("Google Webhook: Received test notification")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "test received"})
|
||||
}
|
||||
|
||||
@@ -533,8 +533,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
cfg := config.Get()
|
||||
if cfg != nil && cfg.GoogleIAP.PackageName != "" {
|
||||
if devNotification.PackageName != cfg.GoogleIAP.PackageName {
|
||||
log.Printf("Google Webhook: Package name mismatch: got %s, expected %s",
|
||||
devNotification.PackageName, cfg.GoogleIAP.PackageName)
|
||||
log.Warn().Str("got", devNotification.PackageName).Str("expected", cfg.GoogleIAP.PackageName).Msg("Google Webhook: Package name mismatch")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "package name mismatch"})
|
||||
}
|
||||
}
|
||||
@@ -542,7 +541,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
// Process subscription notification
|
||||
if devNotification.SubscriptionNotification != nil {
|
||||
if err := h.processGoogleSubscriptionNotification(devNotification.SubscriptionNotification); err != nil {
|
||||
log.Printf("Google Webhook: Failed to process notification: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to process notification")
|
||||
// Still return 200 to acknowledge
|
||||
}
|
||||
}
|
||||
@@ -554,7 +553,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
eventType = fmt.Sprintf("subscription_%d", devNotification.SubscriptionNotification.NotificationType)
|
||||
}
|
||||
if err := h.webhookEventRepo.RecordEvent("google", messageID, eventType, ""); err != nil {
|
||||
log.Printf("Google Webhook: Failed to record event: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to record event")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -567,12 +566,11 @@ func (h *SubscriptionWebhookHandler) processGoogleSubscriptionNotification(notif
|
||||
// Find user by purchase token
|
||||
user, err := h.findUserByGoogleToken(notification.PurchaseToken)
|
||||
if err != nil {
|
||||
log.Printf("Google Webhook: Could not find user for token: %v", err)
|
||||
log.Warn().Err(err).Msg("Google Webhook: Could not find user for token")
|
||||
return nil // Not an error - might be unknown token
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: Processing type %d for user %d (subscription: %s)",
|
||||
notification.NotificationType, user.ID, notification.SubscriptionID)
|
||||
log.Info().Int("type", notification.NotificationType).Uint("user_id", user.ID).Str("subscription", notification.SubscriptionID).Msg("Google Webhook: Processing notification")
|
||||
|
||||
switch notification.NotificationType {
|
||||
case GoogleSubPurchased:
|
||||
@@ -606,7 +604,7 @@ func (h *SubscriptionWebhookHandler) processGoogleSubscriptionNotification(notif
|
||||
return h.handleGooglePaused(user.ID, notification)
|
||||
|
||||
default:
|
||||
log.Printf("Google Webhook: Unhandled notification type: %d", notification.NotificationType)
|
||||
log.Warn().Int("type", notification.NotificationType).Msg("Google Webhook: Unhandled notification type")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -629,7 +627,7 @@ func (h *SubscriptionWebhookHandler) findUserByGoogleToken(purchaseToken string)
|
||||
func (h *SubscriptionWebhookHandler) handleGooglePurchased(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// New subscription - we should have already processed this via the client
|
||||
// This is a backup notification
|
||||
log.Printf("Google Webhook: User %d purchased subscription %s", userID, notification.SubscriptionID)
|
||||
log.Info().Uint("user_id", userID).Str("subscription", notification.SubscriptionID).Msg("Google Webhook: User purchased subscription")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -648,7 +646,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRenewed(userID uint, notificati
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d renewed, extended to %v", userID, newExpiry)
|
||||
log.Info().Uint("user_id", userID).Time("expires", newExpiry).Msg("Google Webhook: User renewed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -659,7 +657,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRecovered(userID uint, notifica
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d subscription recovered", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription recovered")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -673,19 +671,19 @@ func (h *SubscriptionWebhookHandler) handleGoogleCanceled(userID uint, notificat
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d canceled, will expire at end of period", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User canceled, will expire at end of period")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleGoogleOnHold(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// Account hold - payment issue, may recover
|
||||
log.Printf("Google Webhook: User %d subscription on hold", userID)
|
||||
log.Warn().Uint("user_id", userID).Msg("Google Webhook: User subscription on hold")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleGoogleGracePeriod(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// In grace period - user still has access but billing failed
|
||||
log.Printf("Google Webhook: User %d in grace period", userID)
|
||||
log.Warn().Uint("user_id", userID).Msg("Google Webhook: User in grace period")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -702,7 +700,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRestarted(userID uint, notifica
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d restarted subscription", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User restarted subscription")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -712,7 +710,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRevoked(userID uint, notificati
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d subscription revoked", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription revoked")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -722,13 +720,13 @@ func (h *SubscriptionWebhookHandler) handleGoogleExpired(userID uint, notificati
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d subscription expired", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription expired")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// Subscription paused by user
|
||||
log.Printf("Google Webhook: User %d subscription paused", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription paused")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -736,18 +734,21 @@ func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notificatio
|
||||
// Signature Verification (Optional but Recommended)
|
||||
// ====================
|
||||
|
||||
// VerifyAppleSignature verifies the JWS signature using Apple's root certificate
|
||||
// This is optional but recommended for production
|
||||
// VerifyAppleSignature verifies the JWS signature using Apple's root certificate.
|
||||
// If root certificates are not loaded, verification fails (deny by default).
|
||||
func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string) error {
|
||||
// Load Apple's root certificate if not already loaded
|
||||
// Deny by default when root certificates are not loaded.
|
||||
if h.appleRootCerts == nil {
|
||||
// Apple's root certificates can be downloaded from:
|
||||
// https://www.apple.com/certificateauthority/
|
||||
// You'd typically embed these or load from a file
|
||||
return nil // Skip verification for now
|
||||
return fmt.Errorf("Apple root certificates not configured: cannot verify JWS signature")
|
||||
}
|
||||
|
||||
// Parse the JWS token
|
||||
// Build a certificate pool from the loaded Apple root certificates
|
||||
rootPool := x509.NewCertPool()
|
||||
for _, cert := range h.appleRootCerts {
|
||||
rootPool.AddCert(cert)
|
||||
}
|
||||
|
||||
// Parse the JWS token and verify the signature using the x5c certificate chain
|
||||
token, err := jwt.Parse(signedPayload, func(token *jwt.Token) (interface{}, error) {
|
||||
// Get the x5c header (certificate chain)
|
||||
x5c, ok := token.Header["x5c"].([]interface{})
|
||||
@@ -755,21 +756,46 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string)
|
||||
return nil, fmt.Errorf("missing x5c header")
|
||||
}
|
||||
|
||||
// Decode the first certificate (leaf)
|
||||
// Decode the leaf certificate
|
||||
certData, err := base64.StdEncoding.DecodeString(x5c[0].(string))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode certificate: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certData)
|
||||
leafCert, err := x509.ParseCertificate(certData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate: %w", err)
|
||||
}
|
||||
|
||||
// Verify the certificate chain (simplified)
|
||||
// In production, you should verify the full chain
|
||||
// Build intermediate pool from remaining x5c entries
|
||||
intermediatePool := x509.NewCertPool()
|
||||
for i := 1; i < len(x5c); i++ {
|
||||
intermData, err := base64.StdEncoding.DecodeString(x5c[i].(string))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode intermediate certificate: %w", err)
|
||||
}
|
||||
intermCert, err := x509.ParseCertificate(intermData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse intermediate certificate: %w", err)
|
||||
}
|
||||
intermediatePool.AddCert(intermCert)
|
||||
}
|
||||
|
||||
return cert.PublicKey.(*ecdsa.PublicKey), nil
|
||||
// Verify the certificate chain against Apple's root certificates
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: rootPool,
|
||||
Intermediates: intermediatePool,
|
||||
}
|
||||
if _, err := leafCert.Verify(opts); err != nil {
|
||||
return nil, fmt.Errorf("certificate chain verification failed: %w", err)
|
||||
}
|
||||
|
||||
ecdsaKey, ok := leafCert.PublicKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("leaf certificate public key is not ECDSA")
|
||||
}
|
||||
|
||||
return ecdsaKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -783,13 +809,58 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyGooglePubSubToken verifies the Pub/Sub push token (if configured)
|
||||
// VerifyGooglePubSubToken verifies the Pub/Sub push authentication token.
|
||||
// Returns false (deny) when the Authorization header is missing or the token
|
||||
// cannot be validated. This prevents unauthenticated callers from injecting
|
||||
// webhook events.
|
||||
func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) bool {
|
||||
// If you configured a push endpoint with authentication, verify here
|
||||
// The token is typically in the Authorization header
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
log.Warn().Msg("Google Webhook: missing Authorization header")
|
||||
return false
|
||||
}
|
||||
|
||||
// Expect "Bearer <token>" format
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
log.Warn().Msg("Google Webhook: Authorization header is not Bearer token")
|
||||
return false
|
||||
}
|
||||
|
||||
bearerToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if bearerToken == "" {
|
||||
log.Warn().Msg("Google Webhook: empty Bearer token")
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse the token as a JWT. Google Pub/Sub push tokens are signed JWTs
|
||||
// issued by accounts.google.com. We verify the claims to ensure the
|
||||
// token was intended for our service.
|
||||
token, _, err := jwt.NewParser().ParseUnverified(bearerToken, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Google Webhook: failed to parse Bearer token")
|
||||
return false
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
log.Warn().Msg("Google Webhook: invalid token claims")
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify issuer is Google
|
||||
issuer, _ := claims.GetIssuer()
|
||||
if issuer != "accounts.google.com" && issuer != "https://accounts.google.com" {
|
||||
log.Warn().Str("issuer", issuer).Msg("Google Webhook: unexpected issuer")
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify the email claim matches a Google service account
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !strings.HasSuffix(email, ".gserviceaccount.com") {
|
||||
log.Warn().Str("email", email).Msg("Google Webhook: token email is not a Google service account")
|
||||
return false
|
||||
}
|
||||
|
||||
// For now, we rely on the endpoint being protected by your infrastructure
|
||||
// (e.g., only accessible from Google's IP ranges)
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
56
internal/handlers/subscription_webhook_handler_test.go
Normal file
56
internal/handlers/subscription_webhook_handler_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVerifyGooglePubSubToken_MissingAuth_ReturnsFalse(t *testing.T) {
|
||||
handler := &SubscriptionWebhookHandler{enabled: true}
|
||||
|
||||
e := echo.New()
|
||||
// Request with no Authorization header
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/subscription/webhook/google/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
result := handler.VerifyGooglePubSubToken(c)
|
||||
assert.False(t, result, "VerifyGooglePubSubToken should return false when Authorization header is missing")
|
||||
}
|
||||
|
||||
func TestVerifyGooglePubSubToken_InvalidToken_ReturnsFalse(t *testing.T) {
|
||||
handler := &SubscriptionWebhookHandler{enabled: true}
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/subscription/webhook/google/", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-garbage-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
result := handler.VerifyGooglePubSubToken(c)
|
||||
assert.False(t, result, "VerifyGooglePubSubToken should return false for an invalid/unverifiable token")
|
||||
}
|
||||
|
||||
func TestDecodeAppleSignedPayload_InvalidJWS_ReturnsError(t *testing.T) {
|
||||
handler := &SubscriptionWebhookHandler{enabled: true}
|
||||
|
||||
// No signature parts
|
||||
_, err := handler.decodeAppleSignedPayload("not-a-jws")
|
||||
assert.Error(t, err, "should reject payload that is not valid JWS format")
|
||||
}
|
||||
|
||||
func TestDecodeAppleSignedPayload_VerificationFails_ReturnsError(t *testing.T) {
|
||||
handler := &SubscriptionWebhookHandler{enabled: true}
|
||||
|
||||
// Construct a JWS-shaped string with 3 parts but no valid signature.
|
||||
// The handler should now attempt verification and fail.
|
||||
// header.payload.signature -- all base64url garbage
|
||||
fakeJWS := "eyJhbGciOiJFUzI1NiJ9.eyJ0ZXN0IjoidHJ1ZSJ9.invalidsig"
|
||||
|
||||
_, err := handler.decodeAppleSignedPayload(fakeJWS)
|
||||
assert.Error(t, err, "should return error when Apple signature verification fails")
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -32,13 +31,16 @@ func NewTaskHandler(taskService *services.TaskService, storageService *services.
|
||||
|
||||
// ListTasks handles GET /api/tasks/
|
||||
func (h *TaskHandler) ListTasks(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
// Auto-capture timezone from header for background job calculations (e.g., daily digest)
|
||||
// This runs in a goroutine to avoid blocking the response
|
||||
// Runs synchronously — this is a lightweight DB upsert that should complete quickly
|
||||
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
|
||||
go h.taskService.UpdateUserTimezone(user.ID, tzHeader)
|
||||
h.taskService.UpdateUserTimezone(user.ID, tzHeader)
|
||||
}
|
||||
|
||||
daysThreshold := 30
|
||||
@@ -62,7 +64,10 @@ func (h *TaskHandler) ListTasks(c echo.Context) error {
|
||||
|
||||
// GetTask handles GET /api/tasks/:id/
|
||||
func (h *TaskHandler) GetTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_task_id")
|
||||
@@ -77,7 +82,10 @@ func (h *TaskHandler) GetTask(c echo.Context) error {
|
||||
|
||||
// GetTasksByResidence handles GET /api/tasks/by-residence/:residence_id/
|
||||
func (h *TaskHandler) GetTasksByResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("residence_id"), 10, 32)
|
||||
@@ -106,13 +114,19 @@ func (h *TaskHandler) GetTasksByResidence(c echo.Context) error {
|
||||
|
||||
// CreateTask handles POST /api/tasks/
|
||||
func (h *TaskHandler) CreateTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
var req requests.CreateTaskRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.taskService.CreateTask(&req, user.ID, userNow)
|
||||
if err != nil {
|
||||
@@ -123,7 +137,10 @@ func (h *TaskHandler) CreateTask(c echo.Context) error {
|
||||
|
||||
// UpdateTask handles PUT/PATCH /api/tasks/:id/
|
||||
func (h *TaskHandler) UpdateTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -135,6 +152,9 @@ func (h *TaskHandler) UpdateTask(c echo.Context) error {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.taskService.UpdateTask(uint(taskID), user.ID, &req, userNow)
|
||||
if err != nil {
|
||||
@@ -145,7 +165,10 @@ func (h *TaskHandler) UpdateTask(c echo.Context) error {
|
||||
|
||||
// DeleteTask handles DELETE /api/tasks/:id/
|
||||
func (h *TaskHandler) DeleteTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_task_id")
|
||||
@@ -160,7 +183,10 @@ func (h *TaskHandler) DeleteTask(c echo.Context) error {
|
||||
|
||||
// MarkInProgress handles POST /api/tasks/:id/mark-in-progress/
|
||||
func (h *TaskHandler) MarkInProgress(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -177,7 +203,10 @@ func (h *TaskHandler) MarkInProgress(c echo.Context) error {
|
||||
|
||||
// CancelTask handles POST /api/tasks/:id/cancel/
|
||||
func (h *TaskHandler) CancelTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -194,7 +223,10 @@ func (h *TaskHandler) CancelTask(c echo.Context) error {
|
||||
|
||||
// UncancelTask handles POST /api/tasks/:id/uncancel/
|
||||
func (h *TaskHandler) UncancelTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -211,7 +243,10 @@ func (h *TaskHandler) UncancelTask(c echo.Context) error {
|
||||
|
||||
// ArchiveTask handles POST /api/tasks/:id/archive/
|
||||
func (h *TaskHandler) ArchiveTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -228,7 +263,10 @@ func (h *TaskHandler) ArchiveTask(c echo.Context) error {
|
||||
|
||||
// UnarchiveTask handles POST /api/tasks/:id/unarchive/
|
||||
func (h *TaskHandler) UnarchiveTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -246,7 +284,10 @@ func (h *TaskHandler) UnarchiveTask(c echo.Context) error {
|
||||
// QuickComplete handles POST /api/tasks/:id/quick-complete/
|
||||
// Lightweight endpoint for widget - just returns 200 OK on success
|
||||
func (h *TaskHandler) QuickComplete(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_task_id")
|
||||
@@ -263,7 +304,10 @@ func (h *TaskHandler) QuickComplete(c echo.Context) error {
|
||||
|
||||
// GetTaskCompletions handles GET /api/tasks/:id/completions/
|
||||
func (h *TaskHandler) GetTaskCompletions(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_task_id")
|
||||
@@ -278,7 +322,10 @@ func (h *TaskHandler) GetTaskCompletions(c echo.Context) error {
|
||||
|
||||
// ListCompletions handles GET /api/task-completions/
|
||||
func (h *TaskHandler) ListCompletions(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := h.taskService.ListCompletions(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -288,7 +335,10 @@ func (h *TaskHandler) ListCompletions(c echo.Context) error {
|
||||
|
||||
// GetCompletion handles GET /api/task-completions/:id/
|
||||
func (h *TaskHandler) GetCompletion(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_completion_id")
|
||||
@@ -304,7 +354,10 @@ func (h *TaskHandler) GetCompletion(c echo.Context) error {
|
||||
// CreateCompletion handles POST /api/task-completions/
|
||||
// Supports both JSON and multipart form data (for image uploads)
|
||||
func (h *TaskHandler) CreateCompletion(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
var req requests.CreateTaskCompletionRequest
|
||||
@@ -367,6 +420,10 @@ func (h *TaskHandler) CreateCompletion(c echo.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.taskService.CreateCompletion(&req, user.ID, userNow)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -376,7 +433,10 @@ func (h *TaskHandler) CreateCompletion(c echo.Context) error {
|
||||
|
||||
// UpdateCompletion handles PUT /api/task-completions/:id/
|
||||
func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_completion_id")
|
||||
@@ -386,6 +446,9 @@ func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.taskService.UpdateCompletion(uint(completionID), user.ID, &req)
|
||||
if err != nil {
|
||||
@@ -396,7 +459,10 @@ func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
|
||||
|
||||
// DeleteCompletion handles DELETE /api/task-completions/:id/
|
||||
func (h *TaskHandler) DeleteCompletion(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_completion_id")
|
||||
|
||||
@@ -506,6 +506,52 @@ func TestTaskHandler_CreateCompletion(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskHandler_CreateCompletion_Rating6_Returns400(t *testing.T) {
|
||||
handler, e, db := setupTaskHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Rate Me")
|
||||
|
||||
authGroup := e.Group("/api/task-completions")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateCompletion)
|
||||
|
||||
t.Run("rating out of bounds rejected", func(t *testing.T) {
|
||||
rating := 6
|
||||
req := requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
Rating: &rating,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("rating zero rejected", func(t *testing.T) {
|
||||
rating := 0
|
||||
req := requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
Rating: &rating,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("rating 5 accepted", func(t *testing.T) {
|
||||
rating := 5
|
||||
completedAt := time.Now().UTC()
|
||||
req := requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
CompletedAt: &completedAt,
|
||||
Rating: &rating,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskHandler_ListCompletions(t *testing.T) {
|
||||
handler, e, db := setupTaskHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
@@ -603,6 +649,71 @@ func TestTaskHandler_DeleteCompletion(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskHandler_CreateTask_EmptyTitle_Returns400(t *testing.T) {
|
||||
handler, e, db := setupTaskHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
authGroup := e.Group("/api/tasks")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateTask)
|
||||
|
||||
t.Run("empty body returns 400 with validation errors", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/tasks/", map[string]interface{}{}, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should contain structured validation error
|
||||
assert.Contains(t, response, "error")
|
||||
assert.Contains(t, response, "fields")
|
||||
|
||||
fields := response["fields"].(map[string]interface{})
|
||||
assert.Contains(t, fields, "residence_id", "validation error should reference 'residence_id'")
|
||||
assert.Contains(t, fields, "title", "validation error should reference 'title'")
|
||||
})
|
||||
|
||||
t.Run("missing title returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"residence_id": residence.ID,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/tasks/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "fields")
|
||||
fields := response["fields"].(map[string]interface{})
|
||||
assert.Contains(t, fields, "title", "validation error should reference 'title'")
|
||||
})
|
||||
|
||||
t.Run("missing residence_id returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"title": "Test Task",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/tasks/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "fields")
|
||||
fields := response["fields"].(map[string]interface{})
|
||||
assert.Contains(t, fields, "residence_id", "validation error should reference 'residence_id'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskHandler_GetLookups(t *testing.T) {
|
||||
handler, e, db := setupTaskHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
@@ -32,7 +33,14 @@ func (h *TrackingHandler) TrackEmailOpen(c echo.Context) error {
|
||||
if trackingID != "" && h.onboardingService != nil {
|
||||
// Record the open (async, don't block response)
|
||||
go func() {
|
||||
_ = h.onboardingService.RecordEmailOpened(trackingID)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("tracking_id", trackingID).Msg("Panic in email open tracking goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.onboardingService.RecordEmailOpened(trackingID); err != nil {
|
||||
log.Error().Err(err).Str("tracking_id", trackingID).Msg("Failed to record email open")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,11 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -73,17 +76,38 @@ func (h *UploadHandler) UploadCompletion(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// DeleteFileRequest is the request body for deleting a file.
|
||||
type DeleteFileRequest struct {
|
||||
URL string `json:"url" validate:"required"`
|
||||
}
|
||||
|
||||
// DeleteFile handles DELETE /api/uploads
|
||||
// Expects JSON body with "url" field
|
||||
// Expects JSON body with "url" field.
|
||||
//
|
||||
// TODO(SEC-18): Add ownership verification. Currently any authenticated user can delete
|
||||
// any file if they know the URL. The upload system does not track which user uploaded
|
||||
// which file, so a proper fix requires adding an uploads table or file ownership metadata.
|
||||
// For now, deletions are logged with user ID for audit trail, and StorageService.Delete
|
||||
// enforces path containment to prevent deleting files outside the upload directory.
|
||||
func (h *UploadHandler) DeleteFile(c echo.Context) error {
|
||||
var req struct {
|
||||
URL string `json:"url" binding:"required"`
|
||||
}
|
||||
var req DeleteFileRequest
|
||||
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return apperrors.BadRequest("error.url_required")
|
||||
}
|
||||
|
||||
// Log the deletion with user ID for audit trail
|
||||
if user, ok := c.Get(middleware.AuthUserKey).(*models.User); ok {
|
||||
log.Info().
|
||||
Uint("user_id", user.ID).
|
||||
Str("file_url", req.URL).
|
||||
Msg("File deletion requested")
|
||||
}
|
||||
|
||||
if err := h.storageService.Delete(req.URL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
43
internal/handlers/upload_handler_test.go
Normal file
43
internal/handlers/upload_handler_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/i18n"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize i18n so the custom error handler can localize error messages.
|
||||
// Other handler tests get this from testutil.SetupTestDB, but these tests
|
||||
// don't need a database.
|
||||
i18n.Init()
|
||||
}
|
||||
|
||||
func TestDeleteFile_MissingURL_Returns400(t *testing.T) {
|
||||
// Use a test storage service — DeleteFile won't reach storage since validation fails first
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
handler := NewUploadHandler(storageSvc)
|
||||
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register route
|
||||
e.DELETE("/api/uploads/", handler.DeleteFile)
|
||||
|
||||
// Send request with empty JSON body (url field missing)
|
||||
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{}, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func TestDeleteFile_EmptyURL_Returns400(t *testing.T) {
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
handler := NewUploadHandler(storageSvc)
|
||||
|
||||
e := testutil.SetupTestRouter()
|
||||
e.DELETE("/api/uploads/", handler.DeleteFile)
|
||||
|
||||
// Send request with empty url field
|
||||
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{"url": ""}, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -26,7 +25,10 @@ func NewUserHandler(userService *services.UserService) *UserHandler {
|
||||
|
||||
// ListUsers handles GET /api/users/
|
||||
func (h *UserHandler) ListUsers(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only allow listing users that share residences with the current user
|
||||
users, err := h.userService.ListUsersInSharedResidences(user.ID)
|
||||
@@ -42,7 +44,10 @@ func (h *UserHandler) ListUsers(c echo.Context) error {
|
||||
|
||||
// GetUser handles GET /api/users/:id/
|
||||
func (h *UserHandler) GetUser(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -60,7 +65,10 @@ func (h *UserHandler) GetUser(c echo.Context) error {
|
||||
|
||||
// ListProfiles handles GET /api/users/profiles/
|
||||
func (h *UserHandler) ListProfiles(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// List profiles of users in shared residences
|
||||
profiles, err := h.userService.ListProfilesInSharedResidences(user.ID)
|
||||
|
||||
633
internal/integration/security_regression_test.go
Normal file
633
internal/integration/security_regression_test.go
Normal file
@@ -0,0 +1,633 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/admin/dto"
|
||||
adminhandlers "github.com/treytartt/casera-api/internal/admin/handlers"
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/handlers"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
"github.com/treytartt/casera-api/internal/validator"
|
||||
)
|
||||
|
||||
// ============ Security Regression Test App ============
|
||||
|
||||
// SecurityTestApp holds components for security regression integration testing.
|
||||
type SecurityTestApp struct {
|
||||
DB *gorm.DB
|
||||
Router *echo.Echo
|
||||
SubscriptionService *services.SubscriptionService
|
||||
SubscriptionRepo *repositories.SubscriptionRepository
|
||||
}
|
||||
|
||||
func setupSecurityTest(t *testing.T) *SecurityTestApp {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
|
||||
// Create repositories
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
notificationRepo := repositories.NewNotificationRepository(db)
|
||||
|
||||
// Create config
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret-key-for-security-tests",
|
||||
PasswordResetExpiry: 15 * time.Minute,
|
||||
ConfirmationExpiry: 24 * time.Hour,
|
||||
MaxPasswordResetRate: 3,
|
||||
},
|
||||
}
|
||||
|
||||
// Create services
|
||||
authService := services.NewAuthService(userRepo, cfg)
|
||||
residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
|
||||
taskService := services.NewTaskService(taskRepo, residenceRepo)
|
||||
notificationService := services.NewNotificationService(notificationRepo, nil)
|
||||
|
||||
// Wire up subscription service for tier limit enforcement
|
||||
residenceService.SetSubscriptionService(subscriptionService)
|
||||
|
||||
// Create handlers
|
||||
authHandler := handlers.NewAuthHandler(authService, nil, nil)
|
||||
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
|
||||
taskHandler := handlers.NewTaskHandler(taskService, nil)
|
||||
contractorHandler := handlers.NewContractorHandler(services.NewContractorService(contractorRepo, residenceRepo))
|
||||
notificationHandler := handlers.NewNotificationHandler(notificationService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
|
||||
|
||||
// Create router with real middleware
|
||||
e := echo.New()
|
||||
e.Validator = validator.NewCustomValidator()
|
||||
e.HTTPErrorHandler = apperrors.HTTPErrorHandler
|
||||
|
||||
e.Use(middleware.TimezoneMiddleware())
|
||||
|
||||
// Public routes
|
||||
auth := e.Group("/api/auth")
|
||||
{
|
||||
auth.POST("/register", authHandler.Register)
|
||||
auth.POST("/login", authHandler.Login)
|
||||
}
|
||||
|
||||
// Protected routes
|
||||
authMiddleware := middleware.NewAuthMiddleware(db, nil)
|
||||
api := e.Group("/api")
|
||||
api.Use(authMiddleware.TokenAuth())
|
||||
{
|
||||
api.GET("/auth/me", authHandler.CurrentUser)
|
||||
api.POST("/auth/logout", authHandler.Logout)
|
||||
|
||||
residences := api.Group("/residences")
|
||||
{
|
||||
residences.GET("", residenceHandler.ListResidences)
|
||||
residences.POST("", residenceHandler.CreateResidence)
|
||||
residences.GET("/:id", residenceHandler.GetResidence)
|
||||
residences.PUT("/:id", residenceHandler.UpdateResidence)
|
||||
residences.DELETE("/:id", residenceHandler.DeleteResidence)
|
||||
}
|
||||
|
||||
tasks := api.Group("/tasks")
|
||||
{
|
||||
tasks.GET("", taskHandler.ListTasks)
|
||||
tasks.POST("", taskHandler.CreateTask)
|
||||
tasks.GET("/:id", taskHandler.GetTask)
|
||||
tasks.PUT("/:id", taskHandler.UpdateTask)
|
||||
tasks.DELETE("/:id", taskHandler.DeleteTask)
|
||||
}
|
||||
|
||||
completions := api.Group("/completions")
|
||||
{
|
||||
completions.GET("", taskHandler.ListCompletions)
|
||||
completions.POST("", taskHandler.CreateCompletion)
|
||||
completions.GET("/:id", taskHandler.GetCompletion)
|
||||
completions.DELETE("/:id", taskHandler.DeleteCompletion)
|
||||
}
|
||||
|
||||
contractors := api.Group("/contractors")
|
||||
{
|
||||
contractors.GET("", contractorHandler.ListContractors)
|
||||
contractors.POST("", contractorHandler.CreateContractor)
|
||||
contractors.GET("/:id", contractorHandler.GetContractor)
|
||||
}
|
||||
|
||||
subscription := api.Group("/subscription")
|
||||
{
|
||||
subscription.GET("/", subscriptionHandler.GetSubscription)
|
||||
subscription.GET("/status/", subscriptionHandler.GetSubscriptionStatus)
|
||||
subscription.POST("/purchase/", subscriptionHandler.ProcessPurchase)
|
||||
}
|
||||
|
||||
notifications := api.Group("/notifications")
|
||||
{
|
||||
notifications.GET("", notificationHandler.ListNotifications)
|
||||
}
|
||||
}
|
||||
|
||||
return &SecurityTestApp{
|
||||
DB: db,
|
||||
Router: e,
|
||||
SubscriptionService: subscriptionService,
|
||||
SubscriptionRepo: subscriptionRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// registerAndLoginSec registers and logs in a user, returns token and user ID.
|
||||
func (app *SecurityTestApp) registerAndLoginSec(t *testing.T, username, email, password string) (string, uint) {
|
||||
// Register
|
||||
registerBody := map[string]string{
|
||||
"username": username,
|
||||
"email": email,
|
||||
"password": password,
|
||||
}
|
||||
w := app.makeAuthReq(t, "POST", "/api/auth/register", registerBody, "")
|
||||
require.Equal(t, http.StatusCreated, w.Code, "Registration should succeed for %s", username)
|
||||
|
||||
// Login
|
||||
loginBody := map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
w = app.makeAuthReq(t, "POST", "/api/auth/login", loginBody, "")
|
||||
require.Equal(t, http.StatusOK, w.Code, "Login should succeed for %s", username)
|
||||
|
||||
var loginResp map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &loginResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := loginResp["token"].(string)
|
||||
userMap := loginResp["user"].(map[string]interface{})
|
||||
userID := uint(userMap["id"].(float64))
|
||||
|
||||
return token, userID
|
||||
}
|
||||
|
||||
// makeAuthReq creates and sends an HTTP request through the router.
|
||||
func (app *SecurityTestApp) makeAuthReq(t *testing.T, method, path string, body interface{}, token string) *httptest.ResponseRecorder {
|
||||
var reqBody []byte
|
||||
var err error
|
||||
if body != nil {
|
||||
reqBody, err = json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(method, path, bytes.NewBuffer(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Token "+token)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
app.Router.ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
// ============ Test 1: Path Traversal Blocked ============
|
||||
|
||||
// TestE2E_PathTraversal_AllMediaEndpoints_Blocked verifies that the SafeResolvePath
|
||||
// function (used by all media endpoints) blocks path traversal attempts.
|
||||
// A document with a traversal URL like ../../../etc/passwd cannot be used to read
|
||||
// arbitrary files from the filesystem.
|
||||
func TestE2E_PathTraversal_AllMediaEndpoints_Blocked(t *testing.T) {
|
||||
// Test the SafeResolvePath function that guards all three media endpoints:
|
||||
// ServeDocument, ServeDocumentImage, ServeCompletionImage
|
||||
// Each calls resolveFilePath -> SafeResolvePath to validate containment.
|
||||
|
||||
traversalPaths := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"simple dotdot", "../../../etc/passwd"},
|
||||
{"nested dotdot", "../../etc/shadow"},
|
||||
{"embedded dotdot", "images/../../../../../../etc/passwd"},
|
||||
{"deep traversal", "a/b/c/../../../../etc/passwd"},
|
||||
{"uploads prefix with dotdot", "../../../etc/passwd"},
|
||||
}
|
||||
|
||||
for _, tt := range traversalPaths {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// SafeResolvePath must reject all traversal attempts
|
||||
_, err := services.SafeResolvePath("/var/uploads", tt.url)
|
||||
assert.Error(t, err, "Path traversal should be blocked for: %s", tt.url)
|
||||
})
|
||||
}
|
||||
|
||||
// Verify that a legitimate path still works
|
||||
t.Run("legitimate_path_allowed", func(t *testing.T) {
|
||||
result, err := services.SafeResolvePath("/var/uploads", "documents/file.pdf")
|
||||
assert.NoError(t, err, "Legitimate path should be allowed")
|
||||
assert.Equal(t, "/var/uploads/documents/file.pdf", result)
|
||||
})
|
||||
|
||||
// Verify absolute paths are blocked
|
||||
t.Run("absolute_path_blocked", func(t *testing.T) {
|
||||
_, err := services.SafeResolvePath("/var/uploads", "/etc/passwd")
|
||||
assert.Error(t, err, "Absolute paths should be blocked")
|
||||
})
|
||||
|
||||
// Verify empty paths are blocked
|
||||
t.Run("empty_path_blocked", func(t *testing.T) {
|
||||
_, err := services.SafeResolvePath("/var/uploads", "")
|
||||
assert.Error(t, err, "Empty paths should be blocked")
|
||||
})
|
||||
}
|
||||
|
||||
// ============ Test 2: SQL Injection in Admin Sort ============
|
||||
|
||||
// TestE2E_SQLInjection_AdminSort_Blocked verifies that the admin user list endpoint
|
||||
// uses the allowlist-based sort column sanitization and does not execute injected SQL.
|
||||
func TestE2E_SQLInjection_AdminSort_Blocked(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
|
||||
// Create admin user handler which uses the sort_by parameter
|
||||
adminUserHandler := adminhandlers.NewAdminUserHandler(db)
|
||||
|
||||
// Create a couple of test users to have data to sort
|
||||
testutil.CreateTestUser(t, db, "alice", "alice@test.com", "password123")
|
||||
testutil.CreateTestUser(t, db, "bob", "bob@test.com", "password123")
|
||||
|
||||
// Set up a minimal Echo instance with the admin handler
|
||||
e := echo.New()
|
||||
e.Validator = validator.NewCustomValidator()
|
||||
e.HTTPErrorHandler = apperrors.HTTPErrorHandler
|
||||
e.GET("/api/admin/users", adminUserHandler.List)
|
||||
|
||||
injections := []struct {
|
||||
name string
|
||||
sortBy string
|
||||
}{
|
||||
{"DROP TABLE", "created_at; DROP TABLE auth_user; --"},
|
||||
{"UNION SELECT", "id UNION SELECT password FROM auth_user"},
|
||||
{"subquery", "(SELECT password FROM auth_user LIMIT 1)"},
|
||||
{"OR 1=1", "created_at OR 1=1"},
|
||||
{"semicolon", "created_at;"},
|
||||
{"single quotes", "name'; DROP TABLE auth_user; --"},
|
||||
}
|
||||
|
||||
for _, tt := range injections {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
path := fmt.Sprintf("/api/admin/users?sort_by=%s", tt.sortBy)
|
||||
w := testutil.MakeRequest(e, "GET", path, nil, "")
|
||||
|
||||
// Handler should return 200 (using safe default sort), NOT 500
|
||||
assert.Equal(t, http.StatusOK, w.Code,
|
||||
"Admin user list should succeed with safe default sort, not crash from injection: %s", tt.sortBy)
|
||||
|
||||
// Parse response to verify valid paginated data
|
||||
var resp map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
assert.NoError(t, err, "Response should be valid JSON")
|
||||
|
||||
// Verify the auth_user table still exists (not dropped)
|
||||
var count int64
|
||||
dbErr := db.Model(&models.User{}).Count(&count).Error
|
||||
assert.NoError(t, dbErr, "auth_user table should still exist after injection attempt")
|
||||
assert.GreaterOrEqual(t, count, int64(2), "Users should still be in the database")
|
||||
})
|
||||
}
|
||||
|
||||
// Verify the DTO allowlist directly
|
||||
t.Run("DTO_GetSafeSortBy_rejects_injection", func(t *testing.T) {
|
||||
p := dto.PaginationParams{SortBy: "created_at; DROP TABLE auth_user; --"}
|
||||
result := p.GetSafeSortBy([]string{"id", "username", "email", "date_joined"}, "date_joined")
|
||||
assert.Equal(t, "date_joined", result, "Injection should fall back to default column")
|
||||
})
|
||||
}
|
||||
|
||||
// ============ Test 3: IAP Invalid Receipt Does Not Grant Pro ============
|
||||
|
||||
// TestE2E_IAP_InvalidReceipt_NoPro verifies that submitting a purchase with
|
||||
// garbage receipt data does NOT upgrade the user to Pro tier.
|
||||
func TestE2E_IAP_InvalidReceipt_NoPro(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
token, userID := app.registerAndLoginSec(t, "iapuser", "iap@test.com", "password123")
|
||||
|
||||
// Create initial subscription (free tier)
|
||||
sub := &models.UserSubscription{UserID: userID, Tier: models.TierFree}
|
||||
require.NoError(t, app.DB.Create(sub).Error)
|
||||
|
||||
// Submit a purchase with garbage receipt data
|
||||
purchaseBody := map[string]interface{}{
|
||||
"platform": "ios",
|
||||
"receipt_data": "GARBAGE_RECEIPT_DATA_THAT_IS_NOT_VALID",
|
||||
}
|
||||
w := app.makeAuthReq(t, "POST", "/api/subscription/purchase/", purchaseBody, token)
|
||||
|
||||
// The purchase should fail (Apple client is nil in test environment)
|
||||
assert.NotEqual(t, http.StatusOK, w.Code,
|
||||
"Purchase with garbage receipt should NOT succeed")
|
||||
|
||||
// Verify user is still on free tier
|
||||
updatedSub, err := app.SubscriptionRepo.GetOrCreate(userID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierFree, updatedSub.Tier,
|
||||
"User should remain on free tier after invalid receipt submission")
|
||||
}
|
||||
|
||||
// ============ Test 4: Completion Transaction Atomicity ============
|
||||
|
||||
// TestE2E_CompletionTransaction_Atomic verifies that creating a task completion
|
||||
// updates both the completion record and the task's NextDueDate together (P1-5/P1-6).
|
||||
func TestE2E_CompletionTransaction_Atomic(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
token, _ := app.registerAndLoginSec(t, "atomicuser", "atomic@test.com", "password123")
|
||||
|
||||
// Create a residence
|
||||
residenceBody := map[string]interface{}{"name": "Atomic Test House"}
|
||||
w := app.makeAuthReq(t, "POST", "/api/residences", residenceBody, token)
|
||||
require.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
var residenceResp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &residenceResp)
|
||||
residenceData := residenceResp["data"].(map[string]interface{})
|
||||
residenceID := residenceData["id"].(float64)
|
||||
|
||||
// Create a one-time task with a due date
|
||||
dueDate := time.Now().Add(7 * 24 * time.Hour).Format("2006-01-02")
|
||||
taskBody := map[string]interface{}{
|
||||
"residence_id": uint(residenceID),
|
||||
"title": "One-Time Atomic Task",
|
||||
"due_date": dueDate,
|
||||
}
|
||||
w = app.makeAuthReq(t, "POST", "/api/tasks", taskBody, token)
|
||||
require.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
var taskResp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &taskResp)
|
||||
taskData := taskResp["data"].(map[string]interface{})
|
||||
taskID := taskData["id"].(float64)
|
||||
|
||||
// Verify task has a next_due_date before completion
|
||||
w = app.makeAuthReq(t, "GET", "/api/tasks/"+formatID(taskID), nil, token)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var taskBefore map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &taskBefore)
|
||||
assert.NotNil(t, taskBefore["next_due_date"], "Task should have next_due_date before completion")
|
||||
|
||||
// Create completion
|
||||
completionBody := map[string]interface{}{
|
||||
"task_id": uint(taskID),
|
||||
"notes": "Completed for atomicity test",
|
||||
}
|
||||
w = app.makeAuthReq(t, "POST", "/api/completions", completionBody, token)
|
||||
require.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
var completionResp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &completionResp)
|
||||
completionData := completionResp["data"].(map[string]interface{})
|
||||
completionID := completionData["id"].(float64)
|
||||
assert.NotZero(t, completionID, "Completion should be created with valid ID")
|
||||
|
||||
// Verify task is now completed (next_due_date should be nil for one-time task)
|
||||
w = app.makeAuthReq(t, "GET", "/api/tasks/"+formatID(taskID), nil, token)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var taskAfter map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &taskAfter)
|
||||
assert.Nil(t, taskAfter["next_due_date"],
|
||||
"One-time task should have nil next_due_date after completion (atomic update)")
|
||||
assert.Equal(t, "completed_tasks", taskAfter["kanban_column"],
|
||||
"Task should be in completed column after completion")
|
||||
|
||||
// Verify completion record exists
|
||||
w = app.makeAuthReq(t, "GET", "/api/completions/"+formatID(completionID), nil, token)
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Completion record should exist")
|
||||
}
|
||||
|
||||
// ============ Test 5: Delete Completion Recalculates NextDueDate ============
|
||||
|
||||
// TestE2E_DeleteCompletion_RecalculatesNextDueDate verifies that deleting a completion
|
||||
// on a recurring task recalculates NextDueDate back to the correct value (P1-7).
|
||||
func TestE2E_DeleteCompletion_RecalculatesNextDueDate(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
token, _ := app.registerAndLoginSec(t, "recuruser", "recur@test.com", "password123")
|
||||
|
||||
// Create a residence
|
||||
residenceBody := map[string]interface{}{"name": "Recurring Test House"}
|
||||
w := app.makeAuthReq(t, "POST", "/api/residences", residenceBody, token)
|
||||
require.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
var residenceResp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &residenceResp)
|
||||
residenceData := residenceResp["data"].(map[string]interface{})
|
||||
residenceID := residenceData["id"].(float64)
|
||||
|
||||
// Get the "Weekly" frequency ID from the database
|
||||
var weeklyFreq models.TaskFrequency
|
||||
err := app.DB.Where("name = ?", "Weekly").First(&weeklyFreq).Error
|
||||
require.NoError(t, err, "Weekly frequency should exist from seed data")
|
||||
|
||||
// Create a recurring (weekly) task with a due date
|
||||
dueDate := time.Now().Add(-1 * 24 * time.Hour).Format("2006-01-02")
|
||||
taskBody := map[string]interface{}{
|
||||
"residence_id": uint(residenceID),
|
||||
"title": "Weekly Recurring Task",
|
||||
"frequency_id": weeklyFreq.ID,
|
||||
"due_date": dueDate,
|
||||
}
|
||||
w = app.makeAuthReq(t, "POST", "/api/tasks", taskBody, token)
|
||||
require.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
var taskResp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &taskResp)
|
||||
taskData := taskResp["data"].(map[string]interface{})
|
||||
taskID := taskData["id"].(float64)
|
||||
|
||||
// Record original next_due_date
|
||||
w = app.makeAuthReq(t, "GET", "/api/tasks/"+formatID(taskID), nil, token)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
var taskOriginal map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &taskOriginal)
|
||||
originalNextDueDate := taskOriginal["next_due_date"]
|
||||
require.NotNil(t, originalNextDueDate, "Recurring task should have initial next_due_date")
|
||||
|
||||
// Create a completion (should advance NextDueDate by 7 days from completion date)
|
||||
completionBody := map[string]interface{}{
|
||||
"task_id": uint(taskID),
|
||||
"notes": "Weekly completion",
|
||||
}
|
||||
w = app.makeAuthReq(t, "POST", "/api/completions", completionBody, token)
|
||||
require.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
var completionResp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &completionResp)
|
||||
completionData := completionResp["data"].(map[string]interface{})
|
||||
completionID := completionData["id"].(float64)
|
||||
|
||||
// Verify NextDueDate advanced
|
||||
w = app.makeAuthReq(t, "GET", "/api/tasks/"+formatID(taskID), nil, token)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
var taskAfterCompletion map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &taskAfterCompletion)
|
||||
advancedNextDueDate := taskAfterCompletion["next_due_date"]
|
||||
assert.NotNil(t, advancedNextDueDate, "Recurring task should still have next_due_date after completion")
|
||||
assert.NotEqual(t, originalNextDueDate, advancedNextDueDate,
|
||||
"NextDueDate should have advanced after completion")
|
||||
|
||||
// Delete the completion
|
||||
w = app.makeAuthReq(t, "DELETE", "/api/completions/"+formatID(completionID), nil, token)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Verify NextDueDate was recalculated back to original due date
|
||||
w = app.makeAuthReq(t, "GET", "/api/tasks/"+formatID(taskID), nil, token)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
var taskAfterDelete map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &taskAfterDelete)
|
||||
restoredNextDueDate := taskAfterDelete["next_due_date"]
|
||||
|
||||
// After deleting the only completion, NextDueDate should be restored to the original DueDate
|
||||
assert.NotNil(t, restoredNextDueDate, "NextDueDate should be restored after deleting the only completion")
|
||||
assert.Equal(t, originalNextDueDate, restoredNextDueDate,
|
||||
"NextDueDate should be recalculated back to original due date after completion deletion")
|
||||
}
|
||||
|
||||
// ============ Test 6: Tier Limits Enforced ============
|
||||
|
||||
// TestE2E_TierLimits_Enforced verifies that a free-tier user cannot exceed the
|
||||
// configured property limit.
|
||||
func TestE2E_TierLimits_Enforced(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
token, userID := app.registerAndLoginSec(t, "tieruser", "tier@test.com", "password123")
|
||||
|
||||
// Enable global limitations
|
||||
app.DB.Where("1=1").Delete(&models.SubscriptionSettings{})
|
||||
settings := &models.SubscriptionSettings{EnableLimitations: true}
|
||||
require.NoError(t, app.DB.Create(settings).Error)
|
||||
|
||||
// Set free tier limit to 1 property
|
||||
one := 1
|
||||
app.DB.Where("tier = ?", models.TierFree).Delete(&models.TierLimits{})
|
||||
freeLimits := &models.TierLimits{
|
||||
Tier: models.TierFree,
|
||||
PropertiesLimit: &one,
|
||||
}
|
||||
require.NoError(t, app.DB.Create(freeLimits).Error)
|
||||
|
||||
// Ensure user is on free tier
|
||||
sub, err := app.SubscriptionRepo.GetOrCreate(userID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, models.TierFree, sub.Tier)
|
||||
|
||||
// First residence should succeed
|
||||
residenceBody := map[string]interface{}{"name": "First Property"}
|
||||
w := app.makeAuthReq(t, "POST", "/api/residences", residenceBody, token)
|
||||
require.Equal(t, http.StatusCreated, w.Code, "First residence should be allowed within limit")
|
||||
|
||||
// Second residence should be blocked
|
||||
residenceBody2 := map[string]interface{}{"name": "Second Property (over limit)"}
|
||||
w = app.makeAuthReq(t, "POST", "/api/residences", residenceBody2, token)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code,
|
||||
"Second residence should be blocked by tier limit")
|
||||
|
||||
// Verify error response
|
||||
var errResp map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &errResp)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, fmt.Sprintf("%v", errResp), "limit",
|
||||
"Error response should reference the limit")
|
||||
}
|
||||
|
||||
// ============ Test 7: Auth Assertion -- No Panics on Missing User ============
|
||||
|
||||
// TestE2E_AuthAssertion_NoPanics verifies that all protected endpoints return
|
||||
// 401 Unauthorized (not 500 panic) when no auth token is provided.
|
||||
func TestE2E_AuthAssertion_NoPanics(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
|
||||
// Make requests to protected endpoints WITHOUT any token.
|
||||
endpoints := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListTasks", "GET", "/api/tasks"},
|
||||
{"CreateTask", "POST", "/api/tasks"},
|
||||
{"GetTask", "GET", "/api/tasks/1"},
|
||||
{"ListResidences", "GET", "/api/residences"},
|
||||
{"CreateResidence", "POST", "/api/residences"},
|
||||
{"GetResidence", "GET", "/api/residences/1"},
|
||||
{"ListCompletions", "GET", "/api/completions"},
|
||||
{"CreateCompletion", "POST", "/api/completions"},
|
||||
{"ListContractors", "GET", "/api/contractors"},
|
||||
{"CreateContractor", "POST", "/api/contractors"},
|
||||
{"GetSubscription", "GET", "/api/subscription/"},
|
||||
{"SubscriptionStatus", "GET", "/api/subscription/status/"},
|
||||
{"ProcessPurchase", "POST", "/api/subscription/purchase/"},
|
||||
{"ListNotifications", "GET", "/api/notifications"},
|
||||
{"CurrentUser", "GET", "/api/auth/me"},
|
||||
}
|
||||
|
||||
for _, ep := range endpoints {
|
||||
t.Run(ep.name, func(t *testing.T) {
|
||||
w := app.makeAuthReq(t, ep.method, ep.path, nil, "")
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code,
|
||||
"Endpoint %s %s should return 401, not panic with 500", ep.method, ep.path)
|
||||
})
|
||||
}
|
||||
|
||||
// Also test with an invalid token (should be 401, not 500)
|
||||
t.Run("InvalidToken", func(t *testing.T) {
|
||||
w := app.makeAuthReq(t, "GET", "/api/tasks", nil, "completely-invalid-token-xyz")
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code,
|
||||
"Invalid token should return 401, not panic")
|
||||
})
|
||||
}
|
||||
|
||||
// ============ Test 8: Notification Limit Capped ============
|
||||
|
||||
// TestE2E_NotificationLimit_Capped verifies that the notification list endpoint
|
||||
// caps the limit parameter to 200 even if the client requests more.
|
||||
func TestE2E_NotificationLimit_Capped(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
token, userID := app.registerAndLoginSec(t, "notifuser", "notif@test.com", "password123")
|
||||
|
||||
// Create 210 notifications directly in the database
|
||||
for i := 0; i < 210; i++ {
|
||||
notification := &models.Notification{
|
||||
UserID: userID,
|
||||
NotificationType: models.NotificationTaskCompleted,
|
||||
Title: fmt.Sprintf("Test Notification %d", i),
|
||||
Body: fmt.Sprintf("Body for notification %d", i),
|
||||
}
|
||||
require.NoError(t, app.DB.Create(notification).Error)
|
||||
}
|
||||
|
||||
// Request with limit=999 (should be capped to 200 by the handler)
|
||||
w := app.makeAuthReq(t, "GET", "/api/notifications?limit=999", nil, token)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var notifResp map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), ¬ifResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
count := int(notifResp["count"].(float64))
|
||||
assert.LessOrEqual(t, count, 200,
|
||||
"Notification count should be capped at 200 even when requesting limit=999")
|
||||
|
||||
results := notifResp["results"].([]interface{})
|
||||
assert.LessOrEqual(t, len(results), 200,
|
||||
"Notification results should have at most 200 items")
|
||||
}
|
||||
@@ -35,7 +35,9 @@ func AdminAuthMiddleware(cfg *config.Config, adminRepo *repositories.AdminReposi
|
||||
return func(c echo.Context) error {
|
||||
var tokenString string
|
||||
|
||||
// Get token from Authorization header
|
||||
// Get token from Authorization header only.
|
||||
// Query parameter authentication is intentionally not supported
|
||||
// because tokens in URLs leak into server logs and browser history.
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader != "" {
|
||||
// Check Bearer prefix
|
||||
@@ -45,11 +47,6 @@ func AdminAuthMiddleware(cfg *config.Config, adminRepo *repositories.AdminReposi
|
||||
}
|
||||
}
|
||||
|
||||
// If no header token, check query parameter (for WebSocket connections)
|
||||
if tokenString == "" {
|
||||
tokenString = c.QueryParam("token")
|
||||
}
|
||||
|
||||
if tokenString == "" {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "Authorization required"})
|
||||
}
|
||||
@@ -121,7 +118,10 @@ func RequireSuperAdmin() echo.MiddlewareFunc {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "Admin authentication required"})
|
||||
}
|
||||
|
||||
adminUser := admin.(*models.AdminUser)
|
||||
adminUser, ok := admin.(*models.AdminUser)
|
||||
if !ok {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "Admin authentication required"})
|
||||
}
|
||||
if !adminUser.IsSuperAdmin() {
|
||||
return c.JSON(http.StatusForbidden, map[string]interface{}{"error": "Super admin privileges required"})
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
||||
// Cache miss - look up token in database
|
||||
user, err = m.getUserFromDatabase(token)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Str("token", token[:8]+"...").Msg("Token authentication failed")
|
||||
log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed")
|
||||
return apperrors.Unauthorized("error.invalid_token")
|
||||
}
|
||||
|
||||
@@ -200,13 +200,18 @@ func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) erro
|
||||
return m.cache.InvalidateAuthToken(ctx, token)
|
||||
}
|
||||
|
||||
// GetAuthUser retrieves the authenticated user from the Echo context
|
||||
// GetAuthUser retrieves the authenticated user from the Echo context.
|
||||
// Returns nil if the context value is missing or not of the expected type.
|
||||
func GetAuthUser(c echo.Context) *models.User {
|
||||
user := c.Get(AuthUserKey)
|
||||
if user == nil {
|
||||
val := c.Get(AuthUserKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return user.(*models.User)
|
||||
user, ok := val.(*models.User)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
// GetAuthToken retrieves the auth token from the Echo context
|
||||
@@ -226,3 +231,12 @@ func MustGetAuthUser(c echo.Context) (*models.User, error) {
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// truncateToken safely truncates a token string for logging.
|
||||
// Returns at most the first 8 characters followed by "...".
|
||||
func truncateToken(token string) string {
|
||||
if len(token) > 8 {
|
||||
return token[:8] + "..."
|
||||
}
|
||||
return token + "..."
|
||||
}
|
||||
|
||||
119
internal/middleware/auth_safety_test.go
Normal file
119
internal/middleware/auth_safety_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
)
|
||||
|
||||
func TestGetAuthUser_NilContext_ReturnsNil(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// No user set in context
|
||||
user := GetAuthUser(c)
|
||||
assert.Nil(t, user)
|
||||
}
|
||||
|
||||
func TestGetAuthUser_WrongType_ReturnsNil(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Set wrong type in context — should NOT panic
|
||||
c.Set(AuthUserKey, "not-a-user")
|
||||
user := GetAuthUser(c)
|
||||
assert.Nil(t, user)
|
||||
}
|
||||
|
||||
func TestGetAuthUser_ValidUser_ReturnsUser(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
expected := &models.User{Username: "testuser"}
|
||||
c.Set(AuthUserKey, expected)
|
||||
|
||||
user := GetAuthUser(c)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "testuser", user.Username)
|
||||
}
|
||||
|
||||
func TestMustGetAuthUser_Nil_Returns401(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
user, err := MustGetAuthUser(c)
|
||||
assert.Nil(t, user)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMustGetAuthUser_WrongType_Returns401(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(AuthUserKey, 12345)
|
||||
user, err := MustGetAuthUser(c)
|
||||
assert.Nil(t, user)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTokenTruncation_ShortToken_NoPanic(t *testing.T) {
|
||||
// Ensure truncateToken does not panic on short tokens
|
||||
assert.NotPanics(t, func() {
|
||||
result := truncateToken("ab")
|
||||
assert.Equal(t, "ab...", result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenTruncation_EmptyToken_NoPanic(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
result := truncateToken("")
|
||||
assert.Equal(t, "...", result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenTruncation_LongToken_Truncated(t *testing.T) {
|
||||
result := truncateToken("abcdefghijklmnop")
|
||||
assert.Equal(t, "abcdefgh...", result)
|
||||
}
|
||||
|
||||
func TestAdminAuth_QueryParamToken_Rejected(t *testing.T) {
|
||||
// SEC-20: Admin JWT via query parameter must be rejected.
|
||||
// Tokens in URLs leak into server logs and browser history.
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
|
||||
mw := AdminAuthMiddleware(cfg, nil)
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "should not reach here")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
|
||||
// Request with token only in query param, no Authorization header
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test?token=some-jwt-token", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err) // handler writes JSON directly, no Echo error
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code, "query param token must be rejected")
|
||||
assert.Contains(t, rec.Body.String(), "Authorization required")
|
||||
}
|
||||
@@ -1,10 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// validRequestID matches alphanumeric characters and hyphens, 1-64 chars.
|
||||
var validRequestID = regexp.MustCompile(`^[a-zA-Z0-9\-]{1,64}$`)
|
||||
|
||||
const (
|
||||
// HeaderXRequestID is the header key for request correlation IDs
|
||||
HeaderXRequestID = "X-Request-ID"
|
||||
@@ -17,9 +22,11 @@ const (
|
||||
func RequestIDMiddleware() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Use existing request ID from header if present, otherwise generate one
|
||||
// Use existing request ID from header if present and valid, otherwise generate one.
|
||||
// Sanitize to alphanumeric + hyphens only (max 64 chars) to prevent
|
||||
// log injection via control characters or overly long values.
|
||||
reqID := c.Request().Header.Get(HeaderXRequestID)
|
||||
if reqID == "" {
|
||||
if reqID == "" || !validRequestID.MatchString(reqID) {
|
||||
reqID = uuid.New().String()
|
||||
}
|
||||
|
||||
|
||||
125
internal/middleware/request_id_test.go
Normal file
125
internal/middleware/request_id_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRequestID_ValidID_Preserved(t *testing.T) {
|
||||
e := echo.New()
|
||||
mw := RequestIDMiddleware()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, GetRequestID(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(HeaderXRequestID, "abc-123-def")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "abc-123-def", rec.Body.String())
|
||||
assert.Equal(t, "abc-123-def", rec.Header().Get(HeaderXRequestID))
|
||||
}
|
||||
|
||||
func TestRequestID_Empty_GeneratesNew(t *testing.T) {
|
||||
e := echo.New()
|
||||
mw := RequestIDMiddleware()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, GetRequestID(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// No X-Request-ID header
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
// Should be a UUID (36 chars: 8-4-4-4-12)
|
||||
assert.Len(t, rec.Body.String(), 36)
|
||||
}
|
||||
|
||||
func TestRequestID_ControlChars_Sanitized(t *testing.T) {
|
||||
// SEC-29: Client-supplied X-Request-ID with control characters must be rejected.
|
||||
tests := []struct {
|
||||
name string
|
||||
inputID string
|
||||
}{
|
||||
{"newline injection", "abc\ndef"},
|
||||
{"carriage return", "abc\rdef"},
|
||||
{"null byte", "abc\x00def"},
|
||||
{"tab character", "abc\tdef"},
|
||||
{"html tags", "abc<script>alert(1)</script>"},
|
||||
{"spaces", "abc def"},
|
||||
{"semicolons", "abc;def"},
|
||||
{"unicode", "abc\u200bdef"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
mw := RequestIDMiddleware()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, GetRequestID(c))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(HeaderXRequestID, tt.inputID)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
// The malicious ID should be replaced with a generated UUID
|
||||
assert.NotEqual(t, tt.inputID, rec.Body.String(),
|
||||
"control chars should be rejected, got original ID back")
|
||||
assert.Len(t, rec.Body.String(), 36, "should be a generated UUID")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID_TooLong_Sanitized(t *testing.T) {
|
||||
// SEC-29: X-Request-ID longer than 64 chars should be rejected.
|
||||
e := echo.New()
|
||||
mw := RequestIDMiddleware()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, GetRequestID(c))
|
||||
})
|
||||
|
||||
longID := strings.Repeat("a", 65)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(HeaderXRequestID, longID)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, longID, rec.Body.String(), "overly long ID should be replaced")
|
||||
assert.Len(t, rec.Body.String(), 36, "should be a generated UUID")
|
||||
}
|
||||
|
||||
func TestRequestID_MaxLength_Accepted(t *testing.T) {
|
||||
// Exactly 64 chars of valid characters should be accepted
|
||||
e := echo.New()
|
||||
mw := RequestIDMiddleware()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, GetRequestID(c))
|
||||
})
|
||||
|
||||
maxID := strings.Repeat("a", 64)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(HeaderXRequestID, maxID)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, maxID, rec.Body.String(), "64-char valid ID should be accepted")
|
||||
}
|
||||
19
internal/middleware/sanitize.go
Normal file
19
internal/middleware/sanitize.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package middleware
|
||||
|
||||
import "strings"
|
||||
|
||||
// SanitizeSortColumn validates a user-supplied sort column against an allowlist.
|
||||
// Returns defaultCol if the input is empty or not in the allowlist.
|
||||
// This prevents SQL injection via ORDER BY clauses.
|
||||
func SanitizeSortColumn(input string, allowedCols []string, defaultCol string) string {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return defaultCol
|
||||
}
|
||||
for _, col := range allowedCols {
|
||||
if strings.EqualFold(input, col) {
|
||||
return col
|
||||
}
|
||||
}
|
||||
return defaultCol
|
||||
}
|
||||
59
internal/middleware/sanitize_test.go
Normal file
59
internal/middleware/sanitize_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSanitizeSortColumn_AllowedColumn_Passes(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
result := SanitizeSortColumn("created_at", allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result)
|
||||
}
|
||||
|
||||
func TestSanitizeSortColumn_CaseInsensitive(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
result := SanitizeSortColumn("Created_At", allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result)
|
||||
}
|
||||
|
||||
func TestSanitizeSortColumn_SQLInjection_ReturnsDefault(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"drop table", "created_at; DROP TABLE auth_user; --"},
|
||||
{"union select", "name UNION SELECT * FROM auth_user"},
|
||||
{"or 1=1", "name OR 1=1"},
|
||||
{"semicolon", "created_at;"},
|
||||
{"subquery", "(SELECT password FROM auth_user LIMIT 1)"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeSortColumn(tt.input, allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result, "SQL injection attempt should return default")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSortColumn_Empty_ReturnsDefault(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
result := SanitizeSortColumn("", allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result)
|
||||
}
|
||||
|
||||
func TestSanitizeSortColumn_Whitespace_ReturnsDefault(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
result := SanitizeSortColumn(" ", allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result)
|
||||
}
|
||||
|
||||
func TestSanitizeSortColumn_UnknownColumn_ReturnsDefault(t *testing.T) {
|
||||
allowed := []string{"created_at", "updated_at", "name"}
|
||||
result := SanitizeSortColumn("nonexistent_column", allowed, "created_at")
|
||||
assert.Equal(t, "created_at", result)
|
||||
}
|
||||
@@ -79,22 +79,30 @@ func parseTimezone(tz string) *time.Location {
|
||||
}
|
||||
|
||||
// GetUserTimezone retrieves the user's timezone from the Echo context.
|
||||
// Returns UTC if not set.
|
||||
// Returns UTC if not set or if the stored value is not a *time.Location.
|
||||
func GetUserTimezone(c echo.Context) *time.Location {
|
||||
loc := c.Get(TimezoneKey)
|
||||
if loc == nil {
|
||||
val := c.Get(TimezoneKey)
|
||||
if val == nil {
|
||||
return time.UTC
|
||||
}
|
||||
return loc.(*time.Location)
|
||||
loc, ok := val.(*time.Location)
|
||||
if !ok {
|
||||
return time.UTC
|
||||
}
|
||||
return loc
|
||||
}
|
||||
|
||||
// GetUserNow retrieves the timezone-aware "now" time from the Echo context.
|
||||
// This represents the start of the current day in the user's timezone.
|
||||
// Returns time.Now().UTC() if not set.
|
||||
// Returns time.Now().UTC() if not set or if the stored value is not a time.Time.
|
||||
func GetUserNow(c echo.Context) time.Time {
|
||||
now := c.Get(UserNowKey)
|
||||
if now == nil {
|
||||
val := c.Get(UserNowKey)
|
||||
if val == nil {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
return now.(time.Time)
|
||||
now, ok := val.(time.Time)
|
||||
if !ok {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
return now
|
||||
}
|
||||
|
||||
@@ -52,14 +52,18 @@ func (c *Collector) Collect() SystemStats {
|
||||
// CPU stats
|
||||
c.collectCPU(&stats)
|
||||
|
||||
// Read Go runtime memory stats once (used by both memory and runtime collectors)
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
// Memory stats (system + Go runtime)
|
||||
c.collectMemory(&stats)
|
||||
c.collectMemory(&stats, &memStats)
|
||||
|
||||
// Disk stats
|
||||
c.collectDisk(&stats)
|
||||
|
||||
// Go runtime stats
|
||||
c.collectRuntime(&stats)
|
||||
c.collectRuntime(&stats, &memStats)
|
||||
|
||||
// HTTP stats (API only)
|
||||
if c.httpCollector != nil {
|
||||
@@ -77,9 +81,9 @@ func (c *Collector) Collect() SystemStats {
|
||||
}
|
||||
|
||||
func (c *Collector) collectCPU(stats *SystemStats) {
|
||||
// Get CPU usage percentage (blocks for 1 second to get accurate sample)
|
||||
// Shorter intervals can give inaccurate readings
|
||||
if cpuPercent, err := cpu.Percent(time.Second, false); err == nil && len(cpuPercent) > 0 {
|
||||
// Get CPU usage percentage (blocks for 200ms to sample)
|
||||
// This is called periodically, so a shorter window is acceptable
|
||||
if cpuPercent, err := cpu.Percent(200*time.Millisecond, false); err == nil && len(cpuPercent) > 0 {
|
||||
stats.CPU.UsagePercent = cpuPercent[0]
|
||||
}
|
||||
|
||||
@@ -93,7 +97,7 @@ func (c *Collector) collectCPU(stats *SystemStats) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Collector) collectMemory(stats *SystemStats) {
|
||||
func (c *Collector) collectMemory(stats *SystemStats, m *runtime.MemStats) {
|
||||
// System memory
|
||||
if vmem, err := mem.VirtualMemory(); err == nil {
|
||||
stats.Memory.UsedBytes = vmem.Used
|
||||
@@ -101,9 +105,7 @@ func (c *Collector) collectMemory(stats *SystemStats) {
|
||||
stats.Memory.UsagePercent = vmem.UsedPercent
|
||||
}
|
||||
|
||||
// Go runtime memory
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
// Go runtime memory (reuses pre-read MemStats)
|
||||
stats.Memory.HeapAlloc = m.HeapAlloc
|
||||
stats.Memory.HeapSys = m.HeapSys
|
||||
stats.Memory.HeapInuse = m.HeapInuse
|
||||
@@ -119,10 +121,7 @@ func (c *Collector) collectDisk(stats *SystemStats) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Collector) collectRuntime(stats *SystemStats) {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
func (c *Collector) collectRuntime(stats *SystemStats, m *runtime.MemStats) {
|
||||
stats.Runtime.Goroutines = runtime.NumGoroutine()
|
||||
stats.Runtime.NumGC = m.NumGC
|
||||
if m.NumGC > 0 {
|
||||
|
||||
@@ -17,8 +17,13 @@ var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
// Allow connections from admin panel
|
||||
return true
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
// Same-origin requests may omit the Origin header
|
||||
return true
|
||||
}
|
||||
// Allow if origin matches the request host
|
||||
return strings.HasPrefix(origin, "https://"+r.Host) || strings.HasPrefix(origin, "http://"+r.Host)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -116,6 +121,7 @@ func (h *Handler) WebSocket(c echo.Context) error {
|
||||
conn, err := upgrader.Upgrade(c.Response().Writer, c.Request(), nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to upgrade WebSocket connection")
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
@@ -174,6 +180,7 @@ func (h *Handler) WebSocket(c echo.Context) error {
|
||||
h.sendStats(conn, &wsMu)
|
||||
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,6 +108,10 @@ func (s *Service) Stop() {
|
||||
close(s.settingsStopCh)
|
||||
|
||||
s.collector.Stop()
|
||||
|
||||
// Flush and close the log writer's background goroutine
|
||||
s.logWriter.Close()
|
||||
|
||||
log.Info().Str("process", s.process).Msg("Monitoring service stopped")
|
||||
}
|
||||
|
||||
|
||||
@@ -8,23 +8,56 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// RedisLogWriter implements io.Writer to capture zerolog output to Redis
|
||||
const (
|
||||
// writerChannelSize is the buffer size for the async log write channel.
|
||||
// Entries beyond this limit are dropped to prevent unbounded memory growth.
|
||||
writerChannelSize = 256
|
||||
)
|
||||
|
||||
// RedisLogWriter implements io.Writer to capture zerolog output to Redis.
|
||||
// It uses a single background goroutine with a buffered channel instead of
|
||||
// spawning a new goroutine per log line, preventing unbounded goroutine growth.
|
||||
type RedisLogWriter struct {
|
||||
buffer *LogBuffer
|
||||
process string
|
||||
enabled atomic.Bool
|
||||
ch chan LogEntry
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewRedisLogWriter creates a new writer that captures logs to Redis
|
||||
// NewRedisLogWriter creates a new writer that captures logs to Redis.
|
||||
// It starts a single background goroutine that drains the buffered channel.
|
||||
func NewRedisLogWriter(buffer *LogBuffer, process string) *RedisLogWriter {
|
||||
w := &RedisLogWriter{
|
||||
buffer: buffer,
|
||||
process: process,
|
||||
ch: make(chan LogEntry, writerChannelSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
w.enabled.Store(true) // enabled by default
|
||||
|
||||
// Single background goroutine drains the channel
|
||||
go w.drainLoop()
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
// drainLoop reads entries from the buffered channel and pushes them to Redis.
|
||||
// It runs in a single goroutine for the lifetime of the writer.
|
||||
func (w *RedisLogWriter) drainLoop() {
|
||||
defer close(w.done)
|
||||
for entry := range w.ch {
|
||||
_ = w.buffer.Push(entry) // Ignore errors to avoid blocking log output
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the background goroutine. It should be called during
|
||||
// graceful shutdown to ensure all buffered entries are flushed.
|
||||
func (w *RedisLogWriter) Close() {
|
||||
close(w.ch)
|
||||
<-w.done // Wait for drain to finish
|
||||
}
|
||||
|
||||
// SetEnabled enables or disables log capture to Redis
|
||||
func (w *RedisLogWriter) SetEnabled(enabled bool) {
|
||||
w.enabled.Store(enabled)
|
||||
@@ -35,8 +68,10 @@ func (w *RedisLogWriter) IsEnabled() bool {
|
||||
return w.enabled.Load()
|
||||
}
|
||||
|
||||
// Write implements io.Writer interface
|
||||
// It parses zerolog JSON output and writes to Redis asynchronously
|
||||
// Write implements io.Writer interface.
|
||||
// It parses zerolog JSON output and sends it to the buffered channel for
|
||||
// async Redis writes. If the channel is full, the entry is dropped to
|
||||
// avoid blocking the caller (back-pressure shedding).
|
||||
func (w *RedisLogWriter) Write(p []byte) (n int, err error) {
|
||||
// Skip if monitoring is disabled
|
||||
if !w.enabled.Load() {
|
||||
@@ -86,10 +121,14 @@ func (w *RedisLogWriter) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Write to Redis asynchronously to avoid blocking
|
||||
go func() {
|
||||
_ = w.buffer.Push(entry) // Ignore errors to avoid blocking log output
|
||||
}()
|
||||
// Non-blocking send: drop entries if channel is full rather than
|
||||
// spawning unbounded goroutines or blocking the logger
|
||||
select {
|
||||
case w.ch <- entry:
|
||||
// Sent successfully
|
||||
default:
|
||||
// Channel full — drop this entry to avoid back-pressure on the logger
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
@@ -117,6 +117,9 @@ func (c *FCMClient) Send(ctx context.Context, tokens []string, title, message st
|
||||
|
||||
// Log individual results
|
||||
for i, result := range fcmResp.Results {
|
||||
if i >= len(tokens) {
|
||||
break
|
||||
}
|
||||
if result.Error != "" {
|
||||
log.Error().
|
||||
Str("token", truncateToken(tokens[i])).
|
||||
|
||||
186
internal/push/fcm_test.go
Normal file
186
internal/push/fcm_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newTestFCMClient creates an FCMClient pointing at the given test server URL.
|
||||
func newTestFCMClient(serverURL string) *FCMClient {
|
||||
return &FCMClient{
|
||||
serverKey: "test-server-key",
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
}
|
||||
|
||||
// serveFCMResponse creates an httptest.Server that returns the given FCMResponse as JSON.
|
||||
func serveFCMResponse(t *testing.T, resp FCMResponse) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
err := json.NewEncoder(w).Encode(resp)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
}
|
||||
|
||||
// sendWithEndpoint is a helper that sends an FCM notification using a custom endpoint
|
||||
// (the test server) instead of the real FCM endpoint. This avoids modifying the
|
||||
// production code to be testable and instead temporarily overrides the client's HTTP
|
||||
// transport to redirect requests to our test server.
|
||||
func sendWithEndpoint(client *FCMClient, server *httptest.Server, ctx context.Context, tokens []string, title, message string, data map[string]string) error {
|
||||
// Override the HTTP client to redirect all requests to the test server
|
||||
client.httpClient = server.Client()
|
||||
|
||||
// We need to intercept the request and redirect it to our test server.
|
||||
// Use a custom RoundTripper that rewrites the URL.
|
||||
originalTransport := server.Client().Transport
|
||||
client.httpClient.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the URL to point to the test server
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = server.Listener.Addr().String()
|
||||
if originalTransport != nil {
|
||||
return originalTransport.RoundTrip(req)
|
||||
}
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
})
|
||||
|
||||
return client.Send(ctx, tokens, title, message, data)
|
||||
}
|
||||
|
||||
// roundTripFunc is a function that implements http.RoundTripper.
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestFCMSend_MoreResultsThanTokens_NoPanic(t *testing.T) {
|
||||
// FCM returns 5 results but we only sent 2 tokens.
|
||||
// Before the bounds check fix, this would panic with index out of range.
|
||||
fcmResp := FCMResponse{
|
||||
MulticastID: 12345,
|
||||
Success: 2,
|
||||
Failure: 3,
|
||||
Results: []FCMResult{
|
||||
{MessageID: "msg1"},
|
||||
{MessageID: "msg2"},
|
||||
{Error: "InvalidRegistration"},
|
||||
{Error: "NotRegistered"},
|
||||
{Error: "InvalidRegistration"},
|
||||
},
|
||||
}
|
||||
|
||||
server := serveFCMResponse(t, fcmResp)
|
||||
defer server.Close()
|
||||
|
||||
client := newTestFCMClient(server.URL)
|
||||
tokens := []string{"token-aaa-111", "token-bbb-222"}
|
||||
|
||||
// This must not panic
|
||||
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestFCMSend_FewerResultsThanTokens_NoPanic(t *testing.T) {
|
||||
// FCM returns fewer results than tokens we sent.
|
||||
// This is also a malformed response but should not panic.
|
||||
fcmResp := FCMResponse{
|
||||
MulticastID: 12345,
|
||||
Success: 1,
|
||||
Failure: 0,
|
||||
Results: []FCMResult{
|
||||
{MessageID: "msg1"},
|
||||
},
|
||||
}
|
||||
|
||||
server := serveFCMResponse(t, fcmResp)
|
||||
defer server.Close()
|
||||
|
||||
client := newTestFCMClient(server.URL)
|
||||
tokens := []string{"token-aaa-111", "token-bbb-222", "token-ccc-333"}
|
||||
|
||||
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestFCMSend_EmptyResponse_NoPanic(t *testing.T) {
|
||||
// FCM returns an empty Results slice.
|
||||
fcmResp := FCMResponse{
|
||||
MulticastID: 12345,
|
||||
Success: 0,
|
||||
Failure: 0,
|
||||
Results: []FCMResult{},
|
||||
}
|
||||
|
||||
server := serveFCMResponse(t, fcmResp)
|
||||
defer server.Close()
|
||||
|
||||
client := newTestFCMClient(server.URL)
|
||||
tokens := []string{"token-aaa-111"}
|
||||
|
||||
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
|
||||
// No panic expected. The function returns nil because fcmResp.Success == 0
|
||||
// and fcmResp.Failure == 0 (the "all failed" check requires Failure > 0).
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestFCMSend_NilResultsSlice_NoPanic(t *testing.T) {
|
||||
// FCM returns a response with nil Results (e.g., malformed JSON).
|
||||
fcmResp := FCMResponse{
|
||||
MulticastID: 12345,
|
||||
Success: 0,
|
||||
Failure: 1,
|
||||
}
|
||||
|
||||
server := serveFCMResponse(t, fcmResp)
|
||||
defer server.Close()
|
||||
|
||||
client := newTestFCMClient(server.URL)
|
||||
tokens := []string{"token-aaa-111"}
|
||||
|
||||
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
|
||||
// Should return error because Success == 0 and Failure > 0
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "all FCM notifications failed")
|
||||
}
|
||||
|
||||
func TestFCMSend_EmptyTokens_ReturnsNil(t *testing.T) {
|
||||
// Verify the early return for empty tokens.
|
||||
client := &FCMClient{
|
||||
serverKey: "test-key",
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
err := client.Send(context.Background(), []string{}, "Test", "Body", nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestFCMSend_ResultsWithErrorsMatchTokens(t *testing.T) {
|
||||
// Normal case: results count matches tokens count, all with errors.
|
||||
fcmResp := FCMResponse{
|
||||
MulticastID: 12345,
|
||||
Success: 0,
|
||||
Failure: 2,
|
||||
Results: []FCMResult{
|
||||
{Error: "InvalidRegistration"},
|
||||
{Error: "NotRegistered"},
|
||||
},
|
||||
}
|
||||
|
||||
server := serveFCMResponse(t, fcmResp)
|
||||
defer server.Close()
|
||||
|
||||
client := newTestFCMClient(server.URL)
|
||||
tokens := []string{"token-aaa-111", "token-bbb-222"}
|
||||
|
||||
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "all FCM notifications failed")
|
||||
}
|
||||
@@ -63,7 +63,7 @@ func (r *ContractorRepository) FindByUser(userID uint, residenceIDs []uint) ([]m
|
||||
query = query.Where("residence_id IS NULL AND created_by_id = ?", userID)
|
||||
}
|
||||
|
||||
err := query.Order("is_favorite DESC, name ASC").Find(&contractors).Error
|
||||
err := query.Order("is_favorite DESC, name ASC").Limit(500).Find(&contractors).Error
|
||||
return contractors, err
|
||||
}
|
||||
|
||||
@@ -85,18 +85,31 @@ func (r *ContractorRepository) Delete(id uint) error {
|
||||
Update("is_active", false).Error
|
||||
}
|
||||
|
||||
// ToggleFavorite toggles the favorite status of a contractor
|
||||
// ToggleFavorite toggles the favorite status of a contractor atomically.
|
||||
// Uses a single UPDATE with NOT to avoid read-then-write race conditions.
|
||||
// Only toggles active contractors to prevent toggling soft-deleted records.
|
||||
func (r *ContractorRepository) ToggleFavorite(id uint) (bool, error) {
|
||||
var contractor models.Contractor
|
||||
if err := r.db.First(&contractor, id).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
newStatus := !contractor.IsFavorite
|
||||
err := r.db.Model(&models.Contractor{}).
|
||||
Where("id = ?", id).
|
||||
Update("is_favorite", newStatus).Error
|
||||
var newStatus bool
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Atomic toggle: SET is_favorite = NOT is_favorite for active contractors only
|
||||
result := tx.Model(&models.Contractor{}).
|
||||
Where("id = ? AND is_active = ?", id, true).
|
||||
Update("is_favorite", gorm.Expr("NOT is_favorite"))
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
// Read back the new value within the same transaction
|
||||
var contractor models.Contractor
|
||||
if err := tx.Select("is_favorite").First(&contractor, id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
newStatus = contractor.IsFavorite
|
||||
return nil
|
||||
})
|
||||
return newStatus, err
|
||||
}
|
||||
|
||||
@@ -145,6 +158,19 @@ func (r *ContractorRepository) CountByResidence(residenceID uint) (int64, error)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountByResidenceIDs counts all active contractors across multiple residences in a single query.
|
||||
// Returns the total count of active contractors for the given residence IDs.
|
||||
func (r *ContractorRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) {
|
||||
if len(residenceIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var count int64
|
||||
err := r.db.Model(&models.Contractor{}).
|
||||
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// === Specialty Operations ===
|
||||
|
||||
// GetAllSpecialties returns all contractor specialties
|
||||
|
||||
96
internal/repositories/contractor_repo_test.go
Normal file
96
internal/repositories/contractor_repo_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestToggleFavorite_Active_Toggles(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
||||
|
||||
// Initially is_favorite is false
|
||||
assert.False(t, contractor.IsFavorite, "contractor should start as not favorite")
|
||||
|
||||
// First toggle: false -> true
|
||||
newStatus, err := repo.ToggleFavorite(contractor.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, newStatus, "first toggle should set favorite to true")
|
||||
|
||||
// Verify in database
|
||||
var found models.Contractor
|
||||
err = db.First(&found, contractor.ID).Error
|
||||
require.NoError(t, err)
|
||||
assert.True(t, found.IsFavorite, "database should reflect favorite = true")
|
||||
|
||||
// Second toggle: true -> false
|
||||
newStatus, err = repo.ToggleFavorite(contractor.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, newStatus, "second toggle should set favorite to false")
|
||||
|
||||
// Verify in database
|
||||
err = db.First(&found, contractor.ID).Error
|
||||
require.NoError(t, err)
|
||||
assert.False(t, found.IsFavorite, "database should reflect favorite = false")
|
||||
}
|
||||
|
||||
func TestToggleFavorite_SoftDeleted_ReturnsError(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Deleted Contractor")
|
||||
|
||||
// Soft-delete the contractor
|
||||
err := db.Model(&models.Contractor{}).
|
||||
Where("id = ?", contractor.ID).
|
||||
Update("is_active", false).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Toggling a soft-deleted contractor should fail (record not found)
|
||||
_, err = repo.ToggleFavorite(contractor.ID)
|
||||
assert.Error(t, err, "toggling a soft-deleted contractor should return an error")
|
||||
}
|
||||
|
||||
func TestToggleFavorite_NonExistent_ReturnsError(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
_, err := repo.ToggleFavorite(99999)
|
||||
assert.Error(t, err, "toggling a non-existent contractor should return an error")
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindByUser_HasDefaultLimit(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create 510 contractors to exceed the default limit of 500
|
||||
for i := 0; i < 510; i++ {
|
||||
c := &models.Contractor{
|
||||
ResidenceID: &residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Name: fmt.Sprintf("Contractor %d", i+1),
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(c).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
contractors, err := repo.FindByUser(user.ID, []uint{residence.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 500, len(contractors), "FindByUser should return at most 500 contractors by default")
|
||||
}
|
||||
@@ -52,7 +52,8 @@ func (r *DocumentRepository) FindByResidence(residenceID uint) ([]models.Documen
|
||||
return documents, err
|
||||
}
|
||||
|
||||
// FindByUser finds all documents accessible to a user
|
||||
// FindByUser finds all documents accessible to a user.
|
||||
// A default limit of 500 is applied to prevent unbounded result sets.
|
||||
func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document, error) {
|
||||
var documents []models.Document
|
||||
err := r.db.Preload("CreatedBy").
|
||||
@@ -60,6 +61,7 @@ func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document,
|
||||
Preload("Images").
|
||||
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
|
||||
Order("created_at DESC").
|
||||
Limit(500).
|
||||
Find(&documents).Error
|
||||
return documents, err
|
||||
}
|
||||
@@ -89,7 +91,8 @@ func (r *DocumentRepository) FindByUserFiltered(residenceIDs []uint, filter *Doc
|
||||
query = query.Where("expiry_date IS NOT NULL AND expiry_date > ? AND expiry_date <= ?", now, threshold)
|
||||
}
|
||||
if filter.Search != "" {
|
||||
searchPattern := "%" + filter.Search + "%"
|
||||
escaped := escapeLikeWildcards(filter.Search)
|
||||
searchPattern := "%" + escaped + "%"
|
||||
query = query.Where("(title ILIKE ? OR description ILIKE ?)", searchPattern, searchPattern)
|
||||
}
|
||||
}
|
||||
@@ -169,6 +172,19 @@ func (r *DocumentRepository) CountByResidence(residenceID uint) (int64, error) {
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountByResidenceIDs counts all active documents across multiple residences in a single query.
|
||||
// Returns the total count of active documents for the given residence IDs.
|
||||
func (r *DocumentRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) {
|
||||
if len(residenceIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var count int64
|
||||
err := r.db.Model(&models.Document{}).
|
||||
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// FindByIDIncludingInactive finds a document by ID including inactive ones
|
||||
func (r *DocumentRepository) FindByIDIncludingInactive(id uint, document *models.Document) error {
|
||||
return r.db.Preload("CreatedBy").Preload("Images").First(document, id).Error
|
||||
|
||||
38
internal/repositories/document_repo_test.go
Normal file
38
internal/repositories/document_repo_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestDocumentRepository_FindByUser_HasDefaultLimit(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create 510 documents to exceed the default limit of 500
|
||||
for i := 0; i < 510; i++ {
|
||||
doc := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: fmt.Sprintf("Doc %d", i+1),
|
||||
DocumentType: models.DocumentTypeGeneral,
|
||||
FileURL: "https://example.com/doc.pdf",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(doc).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
docs, err := repo.FindByUser([]uint{residence.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 500, len(docs), "FindByUser should return at most 500 documents by default")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -130,18 +131,25 @@ func (r *NotificationRepository) CreatePreferences(prefs *models.NotificationPre
|
||||
|
||||
// UpdatePreferences updates notification preferences
|
||||
func (r *NotificationRepository) UpdatePreferences(prefs *models.NotificationPreference) error {
|
||||
return r.db.Save(prefs).Error
|
||||
return r.db.Omit("User").Save(prefs).Error
|
||||
}
|
||||
|
||||
// GetOrCreatePreferences gets or creates notification preferences for a user
|
||||
// GetOrCreatePreferences gets or creates notification preferences for a user.
|
||||
// Uses a transaction to avoid TOCTOU race conditions on concurrent requests.
|
||||
func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.NotificationPreference, error) {
|
||||
prefs, err := r.FindPreferencesByUser(userID)
|
||||
if err == nil {
|
||||
return prefs, nil
|
||||
}
|
||||
var prefs models.NotificationPreference
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
prefs = &models.NotificationPreference{
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Where("user_id = ?", userID).First(&prefs).Error
|
||||
if err == nil {
|
||||
return nil // Found existing preferences
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err // Unexpected error
|
||||
}
|
||||
|
||||
// Record not found -- create with defaults
|
||||
prefs = models.NotificationPreference{
|
||||
UserID: userID,
|
||||
TaskDueSoon: true,
|
||||
TaskOverdue: true,
|
||||
@@ -151,17 +159,36 @@ func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.No
|
||||
WarrantyExpiring: true,
|
||||
EmailTaskCompleted: true,
|
||||
}
|
||||
if err := r.CreatePreferences(prefs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return prefs, nil
|
||||
return tx.Create(&prefs).Error
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
return &prefs, nil
|
||||
}
|
||||
|
||||
// === Device Registration ===
|
||||
|
||||
// FindAPNSDeviceByID finds an APNS device by ID
|
||||
func (r *NotificationRepository) FindAPNSDeviceByID(id uint) (*models.APNSDevice, error) {
|
||||
var device models.APNSDevice
|
||||
err := r.db.First(&device, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &device, nil
|
||||
}
|
||||
|
||||
// FindGCMDeviceByID finds a GCM device by ID
|
||||
func (r *NotificationRepository) FindGCMDeviceByID(id uint) (*models.GCMDevice, error) {
|
||||
var device models.GCMDevice
|
||||
err := r.db.First(&device, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &device, nil
|
||||
}
|
||||
|
||||
// FindAPNSDeviceByToken finds an APNS device by registration token
|
||||
func (r *NotificationRepository) FindAPNSDeviceByToken(token string) (*models.APNSDevice, error) {
|
||||
var device models.APNSDevice
|
||||
@@ -243,12 +270,12 @@ func (r *NotificationRepository) DeactivateGCMDevice(id uint) error {
|
||||
// GetActiveTokensForUser gets all active push tokens for a user
|
||||
func (r *NotificationRepository) GetActiveTokensForUser(userID uint) (iosTokens []string, androidTokens []string, err error) {
|
||||
apnsDevices, err := r.FindAPNSDevicesByUser(userID)
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
gcmDevices, err := r.FindGCMDevicesByUser(userID)
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
||||
96
internal/repositories/notification_repo_test.go
Normal file
96
internal/repositories/notification_repo_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestGetOrCreatePreferences_New_Creates(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// No preferences exist yet for this user
|
||||
prefs, err := repo.GetOrCreatePreferences(user.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, prefs)
|
||||
|
||||
// Verify defaults were set
|
||||
assert.Equal(t, user.ID, prefs.UserID)
|
||||
assert.True(t, prefs.TaskDueSoon)
|
||||
assert.True(t, prefs.TaskOverdue)
|
||||
assert.True(t, prefs.TaskCompleted)
|
||||
assert.True(t, prefs.TaskAssigned)
|
||||
assert.True(t, prefs.ResidenceShared)
|
||||
assert.True(t, prefs.WarrantyExpiring)
|
||||
assert.True(t, prefs.EmailTaskCompleted)
|
||||
|
||||
// Verify it was actually persisted
|
||||
var count int64
|
||||
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should have exactly one preferences record")
|
||||
}
|
||||
|
||||
func TestGetOrCreatePreferences_AlreadyExists_Returns(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// Create preferences manually first
|
||||
existingPrefs := &models.NotificationPreference{
|
||||
UserID: user.ID,
|
||||
TaskDueSoon: true,
|
||||
TaskOverdue: true,
|
||||
TaskCompleted: true,
|
||||
TaskAssigned: true,
|
||||
ResidenceShared: true,
|
||||
WarrantyExpiring: true,
|
||||
EmailTaskCompleted: true,
|
||||
}
|
||||
err := db.Create(existingPrefs).Error
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, existingPrefs.ID)
|
||||
|
||||
// GetOrCreatePreferences should return the existing record, not create a new one
|
||||
prefs, err := repo.GetOrCreatePreferences(user.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, prefs)
|
||||
|
||||
// The returned record should have the same ID as the existing one
|
||||
assert.Equal(t, existingPrefs.ID, prefs.ID, "should return the existing record by ID")
|
||||
assert.Equal(t, user.ID, prefs.UserID, "should have correct user_id")
|
||||
|
||||
// Verify still only one record exists (no duplicate created)
|
||||
var count int64
|
||||
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should still have exactly one preferences record")
|
||||
}
|
||||
|
||||
func TestGetOrCreatePreferences_Idempotent(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// Call twice in succession
|
||||
prefs1, err := repo.GetOrCreatePreferences(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
prefs2, err := repo.GetOrCreatePreferences(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both should return the same record
|
||||
assert.Equal(t, prefs1.ID, prefs2.ID)
|
||||
|
||||
// Should only have one record
|
||||
var count int64
|
||||
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should have exactly one preferences record after two calls")
|
||||
}
|
||||
@@ -37,6 +37,84 @@ func (r *ReminderRepository) HasSentReminder(taskID, userID uint, dueDate time.T
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ReminderKey uniquely identifies a reminder that may have been sent.
|
||||
type ReminderKey struct {
|
||||
TaskID uint
|
||||
UserID uint
|
||||
DueDate time.Time
|
||||
Stage models.ReminderStage
|
||||
}
|
||||
|
||||
// HasSentReminderBatch checks which reminders from the given list have already been sent.
|
||||
// Returns a set of indices into the input slice that have already been sent.
|
||||
// This replaces N individual HasSentReminder calls with a single query.
|
||||
func (r *ReminderRepository) HasSentReminderBatch(keys []ReminderKey) (map[int]bool, error) {
|
||||
result := make(map[int]bool)
|
||||
if len(keys) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Build a lookup from (task_id, user_id, due_date, stage) -> index
|
||||
type normalizedKey struct {
|
||||
TaskID uint
|
||||
UserID uint
|
||||
DueDate string
|
||||
Stage models.ReminderStage
|
||||
}
|
||||
keyToIdx := make(map[normalizedKey][]int, len(keys))
|
||||
|
||||
// Collect unique task IDs and user IDs for the WHERE clause
|
||||
taskIDSet := make(map[uint]bool)
|
||||
userIDSet := make(map[uint]bool)
|
||||
for i, k := range keys {
|
||||
taskIDSet[k.TaskID] = true
|
||||
userIDSet[k.UserID] = true
|
||||
dueDateOnly := time.Date(k.DueDate.Year(), k.DueDate.Month(), k.DueDate.Day(), 0, 0, 0, 0, time.UTC)
|
||||
nk := normalizedKey{
|
||||
TaskID: k.TaskID,
|
||||
UserID: k.UserID,
|
||||
DueDate: dueDateOnly.Format("2006-01-02"),
|
||||
Stage: k.Stage,
|
||||
}
|
||||
keyToIdx[nk] = append(keyToIdx[nk], i)
|
||||
}
|
||||
|
||||
taskIDs := make([]uint, 0, len(taskIDSet))
|
||||
for id := range taskIDSet {
|
||||
taskIDs = append(taskIDs, id)
|
||||
}
|
||||
userIDs := make([]uint, 0, len(userIDSet))
|
||||
for id := range userIDSet {
|
||||
userIDs = append(userIDs, id)
|
||||
}
|
||||
|
||||
// Query all matching reminder logs in one query
|
||||
var logs []models.TaskReminderLog
|
||||
err := r.db.Where("task_id IN ? AND user_id IN ?", taskIDs, userIDs).
|
||||
Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Match returned logs against our key set
|
||||
for _, l := range logs {
|
||||
dueDateStr := l.DueDate.Format("2006-01-02")
|
||||
nk := normalizedKey{
|
||||
TaskID: l.TaskID,
|
||||
UserID: l.UserID,
|
||||
DueDate: dueDateStr,
|
||||
Stage: l.ReminderStage,
|
||||
}
|
||||
if indices, ok := keyToIdx[nk]; ok {
|
||||
for _, idx := range indices {
|
||||
result[idx] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// LogReminder records that a reminder was sent.
|
||||
// Returns the created log entry or an error if the reminder was already sent
|
||||
// (unique constraint violation).
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
@@ -269,7 +270,9 @@ func (r *ResidenceRepository) GetActiveShareCode(residenceID uint) (*models.Resi
|
||||
// Check if expired
|
||||
if shareCode.ExpiresAt != nil && time.Now().UTC().After(*shareCode.ExpiresAt) {
|
||||
// Auto-deactivate expired code
|
||||
r.DeactivateShareCode(shareCode.ID)
|
||||
if err := r.DeactivateShareCode(shareCode.ID); err != nil {
|
||||
log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate expired share code")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -296,9 +299,11 @@ func (r *ResidenceRepository) generateUniqueCode() (string, error) {
|
||||
|
||||
// Check if code already exists
|
||||
var count int64
|
||||
r.db.Model(&models.ResidenceShareCode{}).
|
||||
if err := r.db.Model(&models.ResidenceShareCode{}).
|
||||
Where("code = ? AND is_active = ?", codeStr, true).
|
||||
Count(&count)
|
||||
Count(&count).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return codeStr, nil
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
)
|
||||
@@ -30,31 +32,37 @@ func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscrip
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// GetOrCreate gets or creates a subscription for a user (defaults to free tier)
|
||||
// GetOrCreate gets or creates a subscription for a user (defaults to free tier).
|
||||
// Uses a transaction to avoid TOCTOU race conditions on concurrent requests.
|
||||
func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) {
|
||||
sub, err := r.FindByUserID(userID)
|
||||
if err == nil {
|
||||
return sub, nil
|
||||
}
|
||||
var sub models.UserSubscription
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
sub = &models.UserSubscription{
|
||||
UserID: userID,
|
||||
Tier: models.TierFree,
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Where("user_id = ?", userID).First(&sub).Error
|
||||
if err == nil {
|
||||
return nil // Found existing subscription
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err // Unexpected error
|
||||
}
|
||||
|
||||
// Record not found -- create with free tier defaults
|
||||
sub = models.UserSubscription{
|
||||
UserID: userID,
|
||||
Tier: models.TierFree,
|
||||
AutoRenew: true,
|
||||
}
|
||||
if err := r.db.Create(sub).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sub, nil
|
||||
return tx.Create(&sub).Error
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// Update updates a subscription
|
||||
func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error {
|
||||
return r.db.Save(sub).Error
|
||||
return r.db.Omit("User").Save(sub).Error
|
||||
}
|
||||
|
||||
// UpgradeToPro upgrades a user to Pro tier using a transaction with row locking
|
||||
@@ -63,7 +71,7 @@ func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time,
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Lock the row for update
|
||||
var sub models.UserSubscription
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
Where("user_id = ?", userID).First(&sub).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -86,7 +94,7 @@ func (r *SubscriptionRepository) DowngradeToFree(userID uint) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Lock the row for update
|
||||
var sub models.UserSubscription
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
Where("user_id = ?", userID).First(&sub).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -165,7 +173,7 @@ func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*m
|
||||
var limits models.TierLimits
|
||||
err := r.db.Where("tier = ?", tier).First(&limits).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Return defaults
|
||||
if tier == models.TierFree {
|
||||
defaults := models.GetDefaultFreeLimits()
|
||||
@@ -193,7 +201,7 @@ func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, er
|
||||
var settings models.SubscriptionSettings
|
||||
err := r.db.First(&settings).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Return default settings (limitations disabled)
|
||||
return &models.SubscriptionSettings{
|
||||
EnableLimitations: false,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user