Harden API security: input validation, safe auth extraction, new tests, and deploy config

Comprehensive security hardening from audit findings:
- Add validation tags to all DTO request structs (max lengths, ranges, enums)
- Replace unsafe type assertions with MustGetAuthUser helper across all handlers
- Remove query-param token auth from admin middleware (prevents URL token leakage)
- Add request validation calls in handlers that were missing c.Validate()
- Remove goroutines in handlers (timezone update now synchronous)
- Add sanitize middleware and path traversal protection (path_utils)
- Stop resetting admin passwords on migration restart
- Warn on well-known default SECRET_KEY
- Add ~30 new test files covering security regressions, auth safety, repos, and services
- Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-03-02 09:48:01 -06:00
parent 56d6fa4514
commit 7690f07a2b
123 changed files with 8321 additions and 750 deletions

5
.deploy_prod Executable file
View File

@@ -0,0 +1,5 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
exec "${SCRIPT_DIR}/deploy/scripts/deploy_prod.sh" "$@"

38
audit-digest-1.md Normal file
View File

@@ -0,0 +1,38 @@
# Digest 1: cmd/, admin/dto, admin/handlers (first 15 files)
## Systemic Issues (across all admin handlers)
- **SQL Injection via SortBy**: Every admin list handler concatenates `filters.SortBy` directly into GORM `Order()` without allowlist validation
- **Unchecked Count() errors**: Every paginated handler ignores GORM Count error returns
- **Unchecked post-mutation Preload errors**: After Save/Create, handlers reload with Preload but ignore errors
- **`binding` vs `validate` tag mismatch**: Some request DTOs use `binding` (Gin) instead of `validate` (Echo)
- **Direct DB access**: All admin handlers bypass Service layer, accessing `*gorm.DB` directly
- **Unsafe type assertions**: `c.Get(key).(*models.AdminUser)` without comma-ok
## Per-File Highlights
### cmd/api/main.go - App entry point, wires dependencies
### cmd/worker/main.go - Background worker entry point
### admin/handlers/admin_user_handler.go (347 lines)
- N+1 query: `toUserResponse` does 2 extra DB queries per user (residence count, task count)
- Line 64: SortBy SQL injection
- Line 173: Unchecked profile creation error (user created without profile)
### admin/handlers/apple_social_auth_handler.go - CRUD for Apple social auth records
- Same systemic SQL injection and unchecked errors
### admin/handlers/auth_handler.go - Admin login/session management
### admin/handlers/auth_token_handler.go - Auth token CRUD
### admin/handlers/completion_handler.go - Task completion CRUD
### admin/handlers/completion_image_handler.go - Completion image CRUD
### admin/handlers/confirmation_code_handler.go - Email confirmation code CRUD
### admin/handlers/contractor_handler.go - Contractor CRUD
### admin/handlers/dashboard_handler.go - Admin dashboard stats
### admin/handlers/device_handler.go (317 lines)
- Exposes device push tokens (RegistrationID) in API responses
- Lines 293-296: Unchecked Count errors in GetStats
### admin/handlers/document_handler.go (394 lines)
- Lines 176-183: Date parsing errors silently ignored
- Line 379: Precision loss from decimal.Float64() discarded

51
audit-digest-2.md Normal file
View File

@@ -0,0 +1,51 @@
# Digest 2: admin/handlers (remaining 15 files)
### admin/handlers/document_image_handler.go (245 lines)
- N+1: toResponse queries DB per image in List
- Same SortBy SQL injection
### admin/handlers/feature_benefit_handler.go (231 lines)
- `binding` tags instead of `validate` - required fields never enforced
### admin/handlers/limitations_handler.go (451 lines)
- Line 37: Unchecked Create error for default settings
- Line 191-197: UpdateTierLimits overwrites ALL fields even for partial updates
### admin/handlers/lookup_handler.go (877 lines)
- **CRITICAL**: Lines 30-32, 50-52, etc.: refreshXxxCache checks `if cache == nil {}` with EMPTY body, then calls cache.CacheXxx() — nil pointer panic when cache is nil
- Line 792: Hardcoded join table name "task_contractor_specialties"
### admin/handlers/notification_handler.go (419 lines)
- Line 351-363: HTML template built by string concatenation with user-provided subject/body — XSS in admin emails
### admin/handlers/notification_prefs_handler.go (347 lines)
- Line 154: Unchecked user lookup — deleted user produces zero-value username/email
### admin/handlers/onboarding_handler.go (343 lines)
- Line 304: Internal error details leaked to client
### admin/handlers/password_reset_code_handler.go (161 lines)
- **BUG**: Line 85: `code.ResetToken[:8] + "..." + code.ResetToken[len-4:]` panics if token < 8 chars
### admin/handlers/promotion_handler.go (304 lines)
- `binding` tags: required fields never enforced
### admin/handlers/residence_handler.go (371 lines)
- Lines 121-122: Unchecked Count errors for task/document counts
### admin/handlers/settings_handler.go (794 lines)
- Line 378: Raw SQL execution from seed files (no parameterization)
- Line 529-793: ClearAllData is destructive with no double-auth check
- Line 536-539: Panic in ClearAllData silently swallowed
### admin/handlers/share_code_handler.go (225 lines)
- Line 155-162: `IsActive` as non-pointer bool — absent field defaults to false, deactivating codes
### admin/handlers/subscription_handler.go (237 lines)
- **BUG**: Line 40-41: JOIN uses "users" but actual table is "auth_user" — query fails on PostgreSQL
### admin/handlers/task_handler.go (401 lines)
- Line 247-296: Admin Create bypasses service layer — no business logic applied
### admin/handlers/task_template_handler.go (347 lines)
- Lines 29-31: Same nil cache panic as lookup_handler.go

30
audit-digest-3.md Normal file
View File

@@ -0,0 +1,30 @@
# Digest 3: admin routes, apperrors, config, database, dto/requests, dto/responses (first half)
### admin/routes.go (483 lines)
- Route ordering: users DELETE "/:id" before "/bulk" — "/bulk" matches as id param
- Line 454: Uses os.Getenv instead of Viper config
- Line 460-462: url.Parse failure returns silently, no logging
- Line 467-469: Proxy errors not surfaced (always returns nil)
### apperrors/errors.go (98 lines) - Clean. Error types with Wrap/Unwrap.
### apperrors/handler.go (67 lines) - c.JSON error returns discarded (minor)
### config/config.go (427 lines)
- Line 339: Hardcoded debug secret key "change-me-in-production-secret-key-12345"
- Lines 311-312: Comments say wrong UTC times (says 8PM/9AM, actually 2PM/3PM)
### database/database.go (468 lines)
- **SECURITY**: Line 372-382: Hardcoded bcrypt hash for GoAdmin with password "admin" — migration RESETS password every run
- **SECURITY**: Line 447-463: Hardcoded admin@mycrib.com / admin123 — password reset on every migration
- Line 182+: Multiple db.Exec errors unchecked for index creation
- Line 100-102: WithTransaction coupled to global db variable
### dto/requests/auth.go (66 lines) - LoginRequest min=1 password (intentional for login)
### dto/requests/contractor.go (37 lines) - Rating has no min/max validation bounds
### dto/requests/document.go (47 lines) - ImageURLs no length limit, Description no max length
### dto/requests/residence.go (59 lines) - Bedrooms/SquareFootage accept negative values, ExpiresInHours no validation
### dto/requests/task.go (110 lines) - Rating no bounds, ImageURLs no length limit, CustomIntervalDays no min
### dto/responses/auth.go (190 lines) - Clean. Proper nil checks on Profile.
### dto/responses/contractor.go (131 lines) - TaskCount depends on preloaded Tasks association
### dto/responses/document.go (126 lines) - Line 101: CreatedBy accessed as value type (fragile if changed to pointer)

48
audit-digest-4.md Normal file
View File

@@ -0,0 +1,48 @@
# Digest 4: dto/responses (remaining), echohelpers, handlers (first half)
### dto/responses/residence.go (215 lines) - NewResidenceResponse no nil check on param. Owner zero-value if not preloaded.
### dto/responses/task_template.go (135 lines) - No nil check on template param
### dto/responses/task.go (399 lines) - No nil checks on params in factory functions
### dto/responses/user.go (20 lines) - Clean data types
### echohelpers/helpers.go (46 lines) - Clean utilities
### echohelpers/pagination.go (33 lines) - Clean, properly bounded
### handlers/auth_handler.go (379 lines)
- **ARCHITECTURE**: Lines 83, 178, 207, 241, 329, 370: SIX goroutine spawns for email — violates "no goroutines in handlers" rule
- Line 308-312: AppError constructed directly instead of factory function
### handlers/contractor_handler.go (154 lines)
- Line 28+: Unchecked type assertions throughout (7 instances)
- Line 31: Raw err.Error() returned to client
- Line 55: CreateContractor missing c.Validate() call
### handlers/document_handler.go (336 lines)
- Line 37+: Unchecked type assertions (10 instances)
- Line 92-93: Raw error leaked to client
- Line 137: No DocumentType validation — any string accepted
- Lines 187, 217: Missing c.Validate() calls
### handlers/media_handler.go (172 lines)
- **SECURITY**: Line 156-171: resolveFilePath uses filepath.Join with user-influenced data — PATH TRAVERSAL vulnerability. TrimPrefix doesn't sanitize ../
- Line 19-22: Handler accesses repositories directly, bypasses service layer
### handlers/notification_handler.go (200 lines)
- Line 29-40: No upper bound on limit — unbounded query with limit=999999999
- Line 168: Silent default to "ios" platform
### handlers/residence_handler.go (365 lines)
- Line 38+: Unchecked type assertions (14 instances)
- Lines 187, 209, 303: Bind errors silently discarded
- Line 224: JoinWithCode missing c.Validate()
### handlers/static_data_handler.go (152 lines) - Uses interface{} instead of concrete types
### handlers/subscription_handler.go (176 lines) - Lines 97, 150: Missing c.Validate() calls
- Line 159-163: RestoreSubscription doesn't validate receipt/transaction ID presence
### handlers/subscription_webhook_handler.go (821 lines)
- **SECURITY**: Line 190-192: Apple JWS payload decoded WITHOUT signature verification
- **SECURITY**: Line 787-793: VerifyGooglePubSubToken ALWAYS returns true — webhook unauthenticated
- Line 639-643: Subscription duration guessed by string matching product ID
- Line 657, 694: Hardcoded 1-month extension regardless of actual plan
- Line 759, 772: Unchecked type assertions in VerifyAppleSignature
- Line 162: Apple renewal info error silently discarded

45
audit-digest-5.md Normal file
View File

@@ -0,0 +1,45 @@
# Digest 5: handlers (remaining), i18n, middleware, models (first half)
### handlers/task_handler.go (440 lines)
- Line 35+: Unchecked type assertions (18 locations)
- Line 42: Fire-and-forget goroutine for UpdateUserTimezone — no error handling, no context
- Lines 112-115, 134-137: Missing c.Validate() calls
- Line 317: 32MB multipart limit with no per-file size check
### handlers/task_template_handler.go (98 lines)
- Line 59: No max length on search query — slow LIKE queries possible
### handlers/tracking_handler.go (46 lines)
- Line 25: Package-level base64 decode error discarded
- Lines 34-36: Fire-and-forget goroutine — violates no-goroutines rule
### handlers/upload_handler.go (93 lines)
- Line 31: User-controlled `category` param passed to storage — potential path traversal
- Line 80: `binding` tag instead of `validate`
- No file type or size validation at handler level
### handlers/user_handler.go (76 lines) - Unchecked type assertions
### i18n/i18n.go (87 lines)
- Line 16: Global Bundle is nil until Init() — NewLocalizer dereferences without nil check
- Line 37: MustParseMessageFileBytes panics on malformed translation files
- Line 83: MustT panics on missing translations
### i18n/middleware.go (127 lines) - Clean
### middleware/admin_auth.go (133 lines)
- **SECURITY**: Line 50: Admin JWT accepted via query param — tokens leak into server/proxy logs
- Line 124: Unchecked type assertion
### middleware/auth.go (229 lines)
- **BUG**: Line 66: `token[:8]` panics if token is fewer than 8 characters
- Line 104: cacheUserID error silently discarded
- Line 209: Unchecked type assertion
### middleware/logger.go (54 lines) - Clean
### middleware/request_id.go (44 lines) - Line 21: Client-supplied X-Request-ID accepted without validation (log injection)
### middleware/timezone.go (101 lines) - Lines 88, 99: Unchecked type assertions
### models/admin.go (64 lines) - Line 38: No max password length check; bcrypt truncates at 72 bytes
### models/base.go (39 lines) - Clean GORM hooks
### models/contractor.go (54 lines) - *float64 mapped to decimal(2,1) — minor precision mismatch

40
audit-digest-6.md Normal file
View File

@@ -0,0 +1,40 @@
# Digest 6: models (remaining), monitoring
### models/document.go (100 lines) - Clean
### models/notification.go (141 lines) - Clean
### models/onboarding_email.go (35 lines) - Clean
### models/reminder_log.go (92 lines) - Clean
### models/residence.go (106 lines)
- Lines 65-70: GetAllUsers/HasAccess assumes Owner and Users are preloaded — returns wrong results if not
### models/subscription.go (169 lines)
- Lines 57-65: IsActive()/IsPro() don't account for IsFree admin override field — misleading method names
### models/task_template.go (23 lines) - Clean
### models/task.go (317 lines)
- Lines 189-264: GetKanbanColumnWithTimezone duplicates categorization chain logic — maintenance drift risk
- Lines 158-182: IsDueSoon uses time.Now() internally — non-deterministic, harder to test
### models/user.go (268 lines)
- Line 101: crypto/rand.Read error unchecked (safe in practice since Go 1.20)
- Line 164-172: GenerateConfirmationCode has slight distribution bias (negligible)
### monitoring/buffer.go (166 lines) - Line 75: Corrupted Redis data silently dropped
### monitoring/collector.go (201 lines)
- Line 82: cpu.Percent blocks 1 second per collection
- Line 96-110: ReadMemStats called TWICE per cycle (also in collectRuntime)
### monitoring/handler.go (195 lines)
- **SECURITY**: Line 19-22: WebSocket CheckOrigin always returns true
- **BUG**: Line 117-119: After upgrader.Upgrade fails, execution continues to conn.Close() — nil pointer panic
- **BUG**: Line 177: Missing return after ctx.Done() — goroutine spins
- Lines 183-184: GetAllStats error silently ignored
- Line 192: WriteJSON error unchecked
### monitoring/middleware.go (220 lines) - Clean
### monitoring/models.go (129 lines) - Clean
### monitoring/service.go (196 lines)
- Line 121: Hardcoded primary key 1 for singleton settings row
- Line 191-194: MetricsMiddleware returns nil when httpCollector is nil — caller must nil-check

57
audit-digest-7.md Normal file
View File

@@ -0,0 +1,57 @@
# Digest 7: monitoring/writer, notifications, push, repositories
### monitoring/writer.go (96 lines)
- Line 90-92: Unbounded fire-and-forget goroutines for Redis push — no rate limiting
### notifications/reminder_config.go (64 lines) - Clean config data
### notifications/reminder_schedule.go (199 lines)
- Line 112: Integer truncation of float division — DST off-by-one possible
- Line 148-161: Custom itoa reimplements strconv.Itoa
### push/apns.go (209 lines)
- Line 44: Double-negative logic — both Production=false and Sandbox=false defaults to production
### push/client.go (158 lines)
- Line 89-105: SendToAll last-error-wins — cannot tell which platform failed
- Line 150-157: HealthCheck always returns nil — useless health check
### push/fcm.go (140 lines)
- Line 16: Legacy FCM HTTP API (deprecated by Google)
- Line 119-126: If FCM returns fewer results than tokens, index out of bounds panic
### repositories/admin_repo.go (108 lines)
- Line 92: Negative page produces negative offset
### repositories/contractor_repo.go (166 lines)
- **RACE**: Line 89-101: ToggleFavorite read-then-write without transaction
- Line 91: ToggleFavorite doesn't filter is_active — can toggle deleted contractors
### repositories/document_repo.go (201 lines)
- Line 92: LIKE wildcards in user input not escaped
- Line 12: DocumentFilter.ResidenceID field defined but never used
### repositories/notification_repo.go (267 lines)
- **RACE**: Line 137-161: GetOrCreatePreferences race — concurrent calls both create, duplicate key error
- Line 143: Uses == instead of errors.Is for ErrRecordNotFound
### repositories/reminder_repo.go (126 lines)
- Line 115-122: rows.Err() not checked after iteration loop
### repositories/residence_repo.go (344 lines)
- Line 272: DeactivateShareCode error silently ignored
- Line 298-301: Count error unchecked in generateUniqueCode — potential duplicate codes
- Line 125-128: PostgreSQL-specific ON CONFLICT — fails on SQLite in tests
### repositories/subscription_repo.go (257 lines)
- **RACE**: Line 40: GetOrCreate race condition (same as notification_repo)
- Line 66: GORM v1 pattern `gorm:query_option` for FOR UPDATE — may not work in GORM v2
- Line 129: LIKE search on receipt data blobs — inefficient, no index
- Lines 40, 168, 196: Uses == instead of errors.Is
### repositories/task_repo.go (765 lines)
- Line 707-709: DeleteCompletion ignores image deletion error
- Line 62-101: applyFilterOptions applies NO scope when no filter set — could query all tasks
### repositories/task_template_repo.go (124 lines)
- Line 48: LIKE wildcard escape issue
- Line 79-81: Save without Omit could corrupt Category/Frequency lookup data

35
audit-digest-8.md Normal file
View File

@@ -0,0 +1,35 @@
# Digest 8: repositories (remaining), router, services (first half)
### repositories/user_repo.go - Standard GORM CRUD
### repositories/webhook_event_repo.go - Webhook event storage
### router/router.go - Route registration wiring
### services/apple_auth.go - Apple Sign In JWT validation
### services/auth_service.go - Token management, password hashing, email verification
### services/cache_service.go - Redis caching for lookups
### services/contractor_service.go - Contractor CRUD via repository
### services/document_service.go - Document management
### services/email_service.go - SMTP email sending
### services/google_auth.go - Google OAuth token validation
### services/iap_validation.go - Apple/Google receipt validation
### services/notification_service.go - Push notifications, preferences
### services/onboarding_email_service.go (371 lines)
- **ARCHITECTURE**: Direct *gorm.DB access — bypasses repository layer entirely
- Line 43-46: HasSentEmail ignores Count error — could send duplicate emails
- Line 128-133: GetEmailStats ignores 4 Count errors
- Line 170: Raw SQL references "auth_user" table
- Line 354: Delete error silently ignored
### services/pdf_service.go (179 lines)
- **BUG**: Line 131-133: Byte-level truncation of title — breaks multi-byte UTF-8 (CJK, emoji)
### services/residence_service.go (648 lines)
- Line 155: TODO comment — subscription tier limit check commented out (free tier bypass)
- Line 447-450: Empty if block — DeactivateShareCode error completely ignored
- Line 625: Status only set for in-progress tasks; all others have empty string

49
audit-digest-9.md Normal file
View File

@@ -0,0 +1,49 @@
# Digest 9: services (remaining), task package, testutil, validator, worker, pkg
### services/storage_service.go (184 lines)
- Line 75: UUID truncated to 8 chars — increased collision risk
- **SECURITY**: Line 137-138: filepath.Abs errors ignored — path traversal check could be bypassed
### services/subscription_service.go (659 lines)
- **PERFORMANCE**: Line 186-204: N+1 queries in getUserUsage — 3 queries per residence
- **SECURITY/BUSINESS**: Line 371: Apple validation failure grants 1-month free Pro
- **SECURITY/BUSINESS**: Line 381: Apple validation not configured grants 1-year free Pro
- **SECURITY/BUSINESS**: Line 429, 449: Same for Google — errors/misconfiguration grant free Pro
- Line 7: Uses stdlib "log" instead of zerolog
### services/task_button_types.go (85 lines) - Clean, uses predicates correctly
### services/task_service.go (1092 lines)
- **DATA INTEGRITY**: Line 601: If task update fails after completion creation, error only logged not returned — stale NextDueDate/InProgress
- Line 735: Goroutine in QuickComplete (service method) — inconsistent with synchronous CreateCompletion
- Line 773: Unbounded goroutine creation per user for notifications
- Line 790: Fail-open email notification on error (intentional but risky)
- **SECURITY**: Line 857-862: resolveImageFilePath has NO path traversal validation
### services/task_template_service.go (70 lines) - Errors returned raw (not wrapped with apperrors)
### services/user_service.go (88 lines) - Returns nil instead of empty slice (JSON null vs [])
### task/categorization/chain.go (359 lines) - Clean chain-of-responsibility
### task/predicates/predicates.go (217 lines)
- Line 210: IsRecurring requires Frequency preloaded — returns false without it
### task/scopes/scopes.go (270 lines)
- Line 118: ScopeOverdue doesn't exclude InProgress — differs from categorization chain
### task/task.go (261 lines) - Clean facade re-exports
### testutil/testutil.go (359 lines)
- Line 86: json.Marshal error ignored in MakeRequest
- Line 92: http.NewRequest error ignored
### validator/validator.go (103 lines) - Clean
### worker/jobs/email_jobs.go (118 lines) - Clean
### worker/jobs/handler.go (810 lines)
- Lines 95-106, 193-204: Direct DB access bypasses repository layer
- Line 627-635: Raw SQL with fmt.Sprintf (not currently user-supplied but fragile)
- Line 154, 251: O(N*M) lookup instead of map
### worker/scheduler.go (240 lines)
- Line 200-212: Cron schedules at fixed UTC times may conflict with smart reminder system — potential duplicate notifications
### pkg/utils/logger.go (132 lines) - Panic recovery bypasses apperrors.HTTPErrorHandler

18
deploy/.gitignore vendored Normal file
View File

@@ -0,0 +1,18 @@
# Local deploy inputs (copy from *.example files)
cluster.env
registry.env
prod.env
# Local secret material
secrets/*.txt
secrets/*.p8
# Keep templates and docs tracked
!*.example
!README.md
!shit_deploy_cant_do.md
!swarm-stack.prod.yml
!scripts/
!scripts/**
!secrets/*.example
!secrets/README.md

134
deploy/README.md Normal file
View File

@@ -0,0 +1,134 @@
# Deploy Folder
This folder is the full production deploy toolkit for `myCribAPI-go`.
Run deploy with:
```bash
./.deploy_prod
```
The script will refuse to run until all required values are set.
## First-Time Prerequisite: Create The Swarm Cluster
You must do this once before `./.deploy_prod` can work.
1. SSH to manager #1 and initialize Swarm:
```bash
docker swarm init --advertise-addr <manager1-private-ip>
```
2. On manager #1, get join commands:
```bash
docker swarm join-token manager
docker swarm join-token worker
```
3. SSH to each additional node and run the appropriate `docker swarm join ...` command.
4. Verify from manager #1:
```bash
docker node ls
```
## Security Requirements Before Public Launch
Use this as a mandatory checklist before you route production traffic.
### 1) Firewall Rules (Node-Level)
Apply firewall rules to all Swarm nodes:
- SSH port (for example `2222/tcp`): your IP only
- `80/tcp`, `443/tcp`: Hetzner LB only (or Cloudflare IP ranges only if no LB)
- `2377/tcp`: Swarm nodes only
- `7946/tcp,udp`: Swarm nodes only
- `4789/udp`: Swarm nodes only
- Everything else: blocked
### 2) SSH Hardening
On each node, harden `/etc/ssh/sshd_config`:
```text
Port 2222
PermitRootLogin no
PasswordAuthentication no
PubkeyAuthentication yes
AllowUsers deploy
```
### 3) Cloudflare Origin Lockdown
- Keep public DNS records proxied (orange cloud on).
- Point Cloudflare to LB, not node IPs.
- Do not publish Swarm node IPs in DNS.
- Enforce firewall source restrictions so public traffic cannot bypass Cloudflare/LB.
### 4) Secrets Policy
- Keep runtime secrets in Docker Swarm secrets only.
- Do not put production secrets in git or plain `.env` files.
- `./.deploy_prod` already creates versioned Swarm secrets from files in `deploy/secrets/`.
- Rotate secrets after incidents or credential exposure.
### 5) Data Path Security
- Neon/Postgres: `DB_SSLMODE=require`, strong DB password, Neon IP allowlist limited to node IPs.
- Backblaze B2: HTTPS only, scoped app keys (not master key), least-privilege bucket access.
- Swarm overlay: encrypted network enabled in stack (`driver_opts.encrypted: "true"`).
### 6) Dozzle Hardening
- Keep Dozzle private (no public DNS/ingress).
- Put auth/SSO in front (Cloudflare Access or equivalent).
- Prefer a Docker socket proxy with restricted read-only scope.
### 7) Backup + Restore Readiness
- Postgres PITR path tested in staging.
- Redis persistence enabled and restore path tested.
- Written runbook for restore and secret rotation.
- Named owner for incident response.
## Files You Fill In
Paste your values into these files:
- `deploy/cluster.env`
- `deploy/registry.env`
- `deploy/prod.env`
- `deploy/secrets/postgres_password.txt`
- `deploy/secrets/secret_key.txt`
- `deploy/secrets/email_host_password.txt`
- `deploy/secrets/fcm_server_key.txt`
- `deploy/secrets/apns_auth_key.p8`
If one is missing, the deploy script auto-copies it from its `.example` template and exits so you can fill it.
## What `./.deploy_prod` Does
1. Validates all required config files and credentials.
2. Builds and pushes `api`, `worker`, and `admin` images.
3. Uploads deploy bundle to your Swarm manager over SSH.
4. Creates versioned Docker secrets on the manager.
5. Deploys the stack with `docker stack deploy --with-registry-auth`.
6. Waits until service replicas converge.
7. Runs an HTTP health check (if `DEPLOY_HEALTHCHECK_URL` is set).
## Useful Flags
Environment flags:
- `SKIP_BUILD=1 ./.deploy_prod` to deploy already-pushed images.
- `SKIP_HEALTHCHECK=1 ./.deploy_prod` to skip final URL check.
- `DEPLOY_TAG=<tag> ./.deploy_prod` to deploy a specific image tag.
## Important
- `deploy/shit_deploy_cant_do.md` lists the manual tasks this script cannot automate.
- Keep real credentials and secret files out of git.

View File

@@ -0,0 +1,22 @@
# Swarm manager connection
DEPLOY_MANAGER_HOST=CHANGEME_MANAGER_IP_OR_HOSTNAME
DEPLOY_MANAGER_USER=deploy
DEPLOY_MANAGER_SSH_PORT=22
DEPLOY_SSH_KEY_PATH=~/.ssh/id_ed25519
# Stack settings
DEPLOY_STACK_NAME=casera
DEPLOY_REMOTE_DIR=/opt/casera/deploy
DEPLOY_WAIT_SECONDS=420
DEPLOY_HEALTHCHECK_URL=https://api.casera.app/api/health/
# Replicas and published ports
API_REPLICAS=3
WORKER_REPLICAS=2
ADMIN_REPLICAS=1
API_PORT=8000
ADMIN_PORT=3000
DOZZLE_PORT=9999
# Build behavior
PUSH_LATEST_TAG=true

73
deploy/prod.env.example Normal file
View File

@@ -0,0 +1,73 @@
# API service settings
DEBUG=false
ALLOWED_HOSTS=api.casera.app,casera.app
CORS_ALLOWED_ORIGINS=https://casera.app,https://admin.casera.app
TIMEZONE=UTC
BASE_URL=https://casera.app
PORT=8000
# Admin service settings
NEXT_PUBLIC_API_URL=https://api.casera.app
ADMIN_PANEL_URL=https://admin.casera.app
# Database (Neon recommended)
DB_HOST=CHANGEME_NEON_HOST
DB_PORT=5432
POSTGRES_USER=CHANGEME_DB_USER
POSTGRES_DB=casera
DB_SSLMODE=require
DB_MAX_OPEN_CONNS=25
DB_MAX_IDLE_CONNS=10
DB_MAX_LIFETIME=600s
# Redis (in stack defaults to redis://redis:6379/0)
REDIS_URL=redis://redis:6379/0
REDIS_DB=0
# Email (password goes in deploy/secrets/email_host_password.txt)
EMAIL_HOST=smtp.gmail.com
EMAIL_PORT=587
EMAIL_USE_TLS=true
EMAIL_HOST_USER=CHANGEME_EMAIL_USER
DEFAULT_FROM_EMAIL=Casera <noreply@casera.app>
# Push notifications
# APNS private key goes in deploy/secrets/apns_auth_key.p8
APNS_AUTH_KEY_ID=CHANGEME_APNS_KEY_ID
APNS_TEAM_ID=CHANGEME_APNS_TEAM_ID
APNS_TOPIC=com.tt.casera
APNS_USE_SANDBOX=false
APNS_PRODUCTION=true
# Worker schedules (UTC)
TASK_REMINDER_HOUR=14
OVERDUE_REMINDER_HOUR=15
DAILY_DIGEST_HOUR=3
# Storage
STORAGE_UPLOAD_DIR=/app/uploads
STORAGE_BASE_URL=/uploads
STORAGE_MAX_FILE_SIZE=10485760
STORAGE_ALLOWED_TYPES=image/jpeg,image/png,image/gif,image/webp,application/pdf
# Feature flags
FEATURE_PUSH_ENABLED=true
FEATURE_EMAIL_ENABLED=true
FEATURE_WEBHOOKS_ENABLED=true
FEATURE_ONBOARDING_EMAILS_ENABLED=true
FEATURE_PDF_REPORTS_ENABLED=true
FEATURE_WORKER_ENABLED=true
# Optional auth/iap values
APPLE_CLIENT_ID=
APPLE_TEAM_ID=
GOOGLE_CLIENT_ID=
GOOGLE_ANDROID_CLIENT_ID=
GOOGLE_IOS_CLIENT_ID=
APPLE_IAP_KEY_ID=
APPLE_IAP_ISSUER_ID=
APPLE_IAP_BUNDLE_ID=
APPLE_IAP_KEY_PATH=
APPLE_IAP_SANDBOX=false
GOOGLE_IAP_PACKAGE_NAME=
GOOGLE_IAP_SERVICE_ACCOUNT_PATH=

View File

@@ -0,0 +1,11 @@
# Container registry used for deploy images.
# For GHCR:
# REGISTRY=ghcr.io
# REGISTRY_NAMESPACE=<github-username-or-org>
# REGISTRY_USERNAME=<github-username>
# REGISTRY_TOKEN=<github-pat-with-read:packages,write:packages>
REGISTRY=ghcr.io
REGISTRY_NAMESPACE=CHANGEME_NAMESPACE
REGISTRY_USERNAME=CHANGEME_USERNAME
REGISTRY_TOKEN=CHANGEME_TOKEN

397
deploy/scripts/deploy_prod.sh Executable file
View File

@@ -0,0 +1,397 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
DEPLOY_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"
REPO_DIR="$(cd "${DEPLOY_DIR}/.." && pwd)"
STACK_TEMPLATE="${DEPLOY_DIR}/swarm-stack.prod.yml"
CLUSTER_ENV="${DEPLOY_DIR}/cluster.env"
REGISTRY_ENV="${DEPLOY_DIR}/registry.env"
PROD_ENV="${DEPLOY_DIR}/prod.env"
SECRET_POSTGRES="${DEPLOY_DIR}/secrets/postgres_password.txt"
SECRET_APP_KEY="${DEPLOY_DIR}/secrets/secret_key.txt"
SECRET_EMAIL_PASS="${DEPLOY_DIR}/secrets/email_host_password.txt"
SECRET_FCM_KEY="${DEPLOY_DIR}/secrets/fcm_server_key.txt"
SECRET_APNS_KEY="${DEPLOY_DIR}/secrets/apns_auth_key.p8"
SKIP_BUILD="${SKIP_BUILD:-0}"
SKIP_HEALTHCHECK="${SKIP_HEALTHCHECK:-0}"
log() {
printf '[deploy] %s\n' "$*"
}
warn() {
printf '[deploy][warn] %s\n' "$*" >&2
}
die() {
printf '[deploy][error] %s\n' "$*" >&2
exit 1
}
require_cmd() {
command -v "$1" >/dev/null 2>&1 || die "Missing required command: $1"
}
contains_placeholder() {
local value="$1"
[[ -z "${value}" ]] && return 0
local lowered
lowered="$(printf '%s' "${value}" | tr '[:upper:]' '[:lower:]')"
case "${lowered}" in
*changeme*|*replace_me*|*example.com*|*your-*|*todo*|*fill_me*|*paste_here*)
return 0
;;
*)
return 1
;;
esac
}
ensure_file_from_example() {
local path="$1"
local example="${path}.example"
if [[ -f "${path}" ]]; then
return
fi
if [[ -f "${example}" ]]; then
cp "${example}" "${path}"
die "Created ${path} from template. Fill it in and rerun."
fi
die "Missing required file: ${path}"
}
require_var() {
local name="$1"
local value="${!name:-}"
[[ -n "${value}" ]] || die "Missing required value: ${name}"
if contains_placeholder "${value}"; then
die "Value still uses placeholder text: ${name}=${value}"
fi
}
require_secret_file() {
local path="$1"
local label="$2"
ensure_file_from_example "${path}"
local contents
contents="$(tr -d '\r' < "${path}" | sed 's/[[:space:]]*$//')"
[[ -n "${contents}" ]] || die "Secret file is empty: ${path}"
if contains_placeholder "${contents}"; then
die "Secret file still has placeholder text (${label}): ${path}"
fi
}
print_usage() {
cat <<'EOF'
Usage:
./.deploy_prod
Optional environment flags:
SKIP_BUILD=1 Deploy existing image tags without rebuilding/pushing.
SKIP_HEALTHCHECK=1 Skip final HTTP health check.
DEPLOY_TAG=<tag> Override image tag (default: git short sha).
EOF
}
while (($# > 0)); do
case "$1" in
-h|--help)
print_usage
exit 0
;;
*)
die "Unknown argument: $1"
;;
esac
done
require_cmd docker
require_cmd ssh
require_cmd scp
require_cmd git
require_cmd awk
require_cmd sed
require_cmd grep
require_cmd mktemp
require_cmd date
require_cmd curl
ensure_file_from_example "${CLUSTER_ENV}"
ensure_file_from_example "${REGISTRY_ENV}"
ensure_file_from_example "${PROD_ENV}"
require_secret_file "${SECRET_POSTGRES}" "Postgres password"
require_secret_file "${SECRET_APP_KEY}" "SECRET_KEY"
require_secret_file "${SECRET_EMAIL_PASS}" "SMTP password"
require_secret_file "${SECRET_FCM_KEY}" "FCM server key"
require_secret_file "${SECRET_APNS_KEY}" "APNS private key"
set -a
# shellcheck disable=SC1090
source "${CLUSTER_ENV}"
# shellcheck disable=SC1090
source "${REGISTRY_ENV}"
# shellcheck disable=SC1090
source "${PROD_ENV}"
set +a
DEPLOY_MANAGER_SSH_PORT="${DEPLOY_MANAGER_SSH_PORT:-22}"
DEPLOY_STACK_NAME="${DEPLOY_STACK_NAME:-casera}"
DEPLOY_REMOTE_DIR="${DEPLOY_REMOTE_DIR:-/opt/casera/deploy}"
DEPLOY_WAIT_SECONDS="${DEPLOY_WAIT_SECONDS:-420}"
DEPLOY_TAG="${DEPLOY_TAG:-$(git -C "${REPO_DIR}" rev-parse --short HEAD)}"
PUSH_LATEST_TAG="${PUSH_LATEST_TAG:-true}"
require_var DEPLOY_MANAGER_HOST
require_var DEPLOY_MANAGER_USER
require_var DEPLOY_STACK_NAME
require_var DEPLOY_REMOTE_DIR
require_var REGISTRY
require_var REGISTRY_NAMESPACE
require_var REGISTRY_USERNAME
require_var REGISTRY_TOKEN
require_var ALLOWED_HOSTS
require_var CORS_ALLOWED_ORIGINS
require_var BASE_URL
require_var NEXT_PUBLIC_API_URL
require_var DB_HOST
require_var DB_PORT
require_var POSTGRES_USER
require_var POSTGRES_DB
require_var DB_SSLMODE
require_var REDIS_URL
require_var EMAIL_HOST
require_var EMAIL_PORT
require_var EMAIL_HOST_USER
require_var DEFAULT_FROM_EMAIL
require_var APNS_AUTH_KEY_ID
require_var APNS_TEAM_ID
require_var APNS_TOPIC
if [[ ! "$(tr -d '\r\n' < "${SECRET_APNS_KEY}")" =~ BEGIN[[:space:]]+PRIVATE[[:space:]]+KEY ]]; then
die "APNS key file does not look like a private key: ${SECRET_APNS_KEY}"
fi
app_secret_len="$(tr -d '\r\n' < "${SECRET_APP_KEY}" | wc -c | tr -d ' ')"
if (( app_secret_len < 32 )); then
die "deploy/secrets/secret_key.txt must be at least 32 characters."
fi
REGISTRY_PREFIX="${REGISTRY%/}/${REGISTRY_NAMESPACE#/}"
API_IMAGE="${REGISTRY_PREFIX}/casera-api:${DEPLOY_TAG}"
WORKER_IMAGE="${REGISTRY_PREFIX}/casera-worker:${DEPLOY_TAG}"
ADMIN_IMAGE="${REGISTRY_PREFIX}/casera-admin:${DEPLOY_TAG}"
SSH_KEY_PATH="${DEPLOY_SSH_KEY_PATH:-}"
if [[ -n "${SSH_KEY_PATH}" ]]; then
SSH_KEY_PATH="${SSH_KEY_PATH/#\~/${HOME}}"
fi
SSH_TARGET="${DEPLOY_MANAGER_USER}@${DEPLOY_MANAGER_HOST}"
SSH_OPTS=(-p "${DEPLOY_MANAGER_SSH_PORT}")
SCP_OPTS=(-P "${DEPLOY_MANAGER_SSH_PORT}")
if [[ -n "${SSH_KEY_PATH}" ]]; then
SSH_OPTS+=(-i "${SSH_KEY_PATH}")
SCP_OPTS+=(-i "${SSH_KEY_PATH}")
fi
log "Validating SSH access to ${SSH_TARGET}"
if ! ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "echo ok" >/dev/null 2>&1; then
die "SSH connection failed to ${SSH_TARGET}"
fi
remote_swarm_state="$(ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "docker info --format '{{.Swarm.LocalNodeState}} {{.Swarm.ControlAvailable}}'" 2>/dev/null || true)"
if [[ -z "${remote_swarm_state}" ]]; then
die "Could not read Docker Swarm state on manager. Is Docker installed/running?"
fi
if [[ "${remote_swarm_state}" != "active true" ]]; then
die "Remote node must be a Swarm manager. Got: ${remote_swarm_state}"
fi
if [[ "${SKIP_BUILD}" != "1" ]]; then
log "Logging in to ${REGISTRY}"
printf '%s' "${REGISTRY_TOKEN}" | docker login "${REGISTRY}" -u "${REGISTRY_USERNAME}" --password-stdin >/dev/null
log "Building API image ${API_IMAGE}"
docker build --target api -t "${API_IMAGE}" "${REPO_DIR}"
log "Building Worker image ${WORKER_IMAGE}"
docker build --target worker -t "${WORKER_IMAGE}" "${REPO_DIR}"
log "Building Admin image ${ADMIN_IMAGE}"
docker build --target admin -t "${ADMIN_IMAGE}" "${REPO_DIR}"
log "Pushing deploy images"
docker push "${API_IMAGE}"
docker push "${WORKER_IMAGE}"
docker push "${ADMIN_IMAGE}"
if [[ "${PUSH_LATEST_TAG}" == "true" ]]; then
log "Updating :latest tags"
docker tag "${API_IMAGE}" "${REGISTRY_PREFIX}/casera-api:latest"
docker tag "${WORKER_IMAGE}" "${REGISTRY_PREFIX}/casera-worker:latest"
docker tag "${ADMIN_IMAGE}" "${REGISTRY_PREFIX}/casera-admin:latest"
docker push "${REGISTRY_PREFIX}/casera-api:latest"
docker push "${REGISTRY_PREFIX}/casera-worker:latest"
docker push "${REGISTRY_PREFIX}/casera-admin:latest"
fi
else
warn "SKIP_BUILD=1 set. Using prebuilt images for tag: ${DEPLOY_TAG}"
fi
DEPLOY_ID_RAW="${DEPLOY_TAG}-$(date +%Y%m%d%H%M%S)"
DEPLOY_ID="$(printf '%s' "${DEPLOY_ID_RAW}" | tr -c 'a-zA-Z0-9_-' '_')"
POSTGRES_PASSWORD_SECRET="${DEPLOY_STACK_NAME}_postgres_password_${DEPLOY_ID}"
SECRET_KEY_SECRET="${DEPLOY_STACK_NAME}_secret_key_${DEPLOY_ID}"
EMAIL_HOST_PASSWORD_SECRET="${DEPLOY_STACK_NAME}_email_host_password_${DEPLOY_ID}"
FCM_SERVER_KEY_SECRET="${DEPLOY_STACK_NAME}_fcm_server_key_${DEPLOY_ID}"
APNS_AUTH_KEY_SECRET="${DEPLOY_STACK_NAME}_apns_auth_key_${DEPLOY_ID}"
TMP_DIR="$(mktemp -d)"
cleanup() {
rm -rf "${TMP_DIR}"
}
trap cleanup EXIT
cp "${STACK_TEMPLATE}" "${TMP_DIR}/swarm-stack.prod.yml"
cp "${PROD_ENV}" "${TMP_DIR}/prod.env"
cp "${REGISTRY_ENV}" "${TMP_DIR}/registry.env"
mkdir -p "${TMP_DIR}/secrets"
cp "${SECRET_POSTGRES}" "${TMP_DIR}/secrets/postgres_password.txt"
cp "${SECRET_APP_KEY}" "${TMP_DIR}/secrets/secret_key.txt"
cp "${SECRET_EMAIL_PASS}" "${TMP_DIR}/secrets/email_host_password.txt"
cp "${SECRET_FCM_KEY}" "${TMP_DIR}/secrets/fcm_server_key.txt"
cp "${SECRET_APNS_KEY}" "${TMP_DIR}/secrets/apns_auth_key.p8"
cat > "${TMP_DIR}/runtime.env" <<EOF
API_IMAGE=${API_IMAGE}
WORKER_IMAGE=${WORKER_IMAGE}
ADMIN_IMAGE=${ADMIN_IMAGE}
API_REPLICAS=${API_REPLICAS:-3}
WORKER_REPLICAS=${WORKER_REPLICAS:-2}
ADMIN_REPLICAS=${ADMIN_REPLICAS:-1}
API_PORT=${API_PORT:-8000}
ADMIN_PORT=${ADMIN_PORT:-3000}
DOZZLE_PORT=${DOZZLE_PORT:-9999}
POSTGRES_PASSWORD_SECRET=${POSTGRES_PASSWORD_SECRET}
SECRET_KEY_SECRET=${SECRET_KEY_SECRET}
EMAIL_HOST_PASSWORD_SECRET=${EMAIL_HOST_PASSWORD_SECRET}
FCM_SERVER_KEY_SECRET=${FCM_SERVER_KEY_SECRET}
APNS_AUTH_KEY_SECRET=${APNS_AUTH_KEY_SECRET}
EOF
log "Uploading deploy bundle to ${SSH_TARGET}:${DEPLOY_REMOTE_DIR}"
ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "mkdir -p '${DEPLOY_REMOTE_DIR}/secrets'"
scp "${SCP_OPTS[@]}" "${TMP_DIR}/swarm-stack.prod.yml" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/swarm-stack.prod.yml"
scp "${SCP_OPTS[@]}" "${TMP_DIR}/prod.env" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/prod.env"
scp "${SCP_OPTS[@]}" "${TMP_DIR}/registry.env" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/registry.env"
scp "${SCP_OPTS[@]}" "${TMP_DIR}/runtime.env" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/runtime.env"
scp "${SCP_OPTS[@]}" "${TMP_DIR}/secrets/postgres_password.txt" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/secrets/postgres_password.txt"
scp "${SCP_OPTS[@]}" "${TMP_DIR}/secrets/secret_key.txt" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/secrets/secret_key.txt"
scp "${SCP_OPTS[@]}" "${TMP_DIR}/secrets/email_host_password.txt" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/secrets/email_host_password.txt"
scp "${SCP_OPTS[@]}" "${TMP_DIR}/secrets/fcm_server_key.txt" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/secrets/fcm_server_key.txt"
scp "${SCP_OPTS[@]}" "${TMP_DIR}/secrets/apns_auth_key.p8" "${SSH_TARGET}:${DEPLOY_REMOTE_DIR}/secrets/apns_auth_key.p8"
log "Creating Docker secrets and deploying stack on manager"
ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "bash -s -- '${DEPLOY_REMOTE_DIR}' '${DEPLOY_STACK_NAME}'" <<'EOF'
set -euo pipefail
REMOTE_DIR="$1"
STACK_NAME="$2"
set -a
# shellcheck disable=SC1090
source "${REMOTE_DIR}/registry.env"
# shellcheck disable=SC1090
source "${REMOTE_DIR}/prod.env"
# shellcheck disable=SC1090
source "${REMOTE_DIR}/runtime.env"
set +a
create_secret() {
local name="$1"
local src="$2"
if docker secret inspect "${name}" >/dev/null 2>&1; then
echo "[remote] secret exists: ${name}"
else
docker secret create "${name}" "${src}" >/dev/null
echo "[remote] created secret: ${name}"
fi
}
printf '%s' "${REGISTRY_TOKEN}" | docker login "${REGISTRY}" -u "${REGISTRY_USERNAME}" --password-stdin >/dev/null
rm -f "${REMOTE_DIR}/registry.env"
create_secret "${POSTGRES_PASSWORD_SECRET}" "${REMOTE_DIR}/secrets/postgres_password.txt"
create_secret "${SECRET_KEY_SECRET}" "${REMOTE_DIR}/secrets/secret_key.txt"
create_secret "${EMAIL_HOST_PASSWORD_SECRET}" "${REMOTE_DIR}/secrets/email_host_password.txt"
create_secret "${FCM_SERVER_KEY_SECRET}" "${REMOTE_DIR}/secrets/fcm_server_key.txt"
create_secret "${APNS_AUTH_KEY_SECRET}" "${REMOTE_DIR}/secrets/apns_auth_key.p8"
rm -f "${REMOTE_DIR}/secrets/postgres_password.txt"
rm -f "${REMOTE_DIR}/secrets/secret_key.txt"
rm -f "${REMOTE_DIR}/secrets/email_host_password.txt"
rm -f "${REMOTE_DIR}/secrets/fcm_server_key.txt"
rm -f "${REMOTE_DIR}/secrets/apns_auth_key.p8"
set -a
# shellcheck disable=SC1090
source "${REMOTE_DIR}/prod.env"
# shellcheck disable=SC1090
source "${REMOTE_DIR}/runtime.env"
set +a
docker stack deploy --with-registry-auth -c "${REMOTE_DIR}/swarm-stack.prod.yml" "${STACK_NAME}"
EOF
log "Waiting for stack convergence (${DEPLOY_WAIT_SECONDS}s max)"
start_epoch="$(date +%s)"
while true; do
services="$(ssh "${SSH_OPTS[@]}" "${SSH_TARGET}" "docker stack services '${DEPLOY_STACK_NAME}' --format '{{.Name}} {{.Replicas}}'" 2>/dev/null || true)"
if [[ -n "${services}" ]]; then
all_ready=1
while IFS=' ' read -r svc replicas; do
[[ -z "${svc}" ]] && continue
current="${replicas%%/*}"
desired="${replicas##*/}"
if [[ "${desired}" == "0" ]]; then
continue
fi
if [[ "${current}" != "${desired}" ]]; then
all_ready=0
fi
done <<< "${services}"
if [[ "${all_ready}" -eq 1 ]]; then
break
fi
fi
now_epoch="$(date +%s)"
elapsed=$((now_epoch - start_epoch))
if (( elapsed >= DEPLOY_WAIT_SECONDS )); then
die "Timed out waiting for stack to converge. Check: ssh ${SSH_TARGET} docker stack services ${DEPLOY_STACK_NAME}"
fi
sleep 10
done
if [[ "${SKIP_HEALTHCHECK}" != "1" && -n "${DEPLOY_HEALTHCHECK_URL:-}" ]]; then
log "Running health check: ${DEPLOY_HEALTHCHECK_URL}"
curl -fsS --max-time 20 "${DEPLOY_HEALTHCHECK_URL}" >/dev/null
fi
log "Deploy completed successfully."
log "Stack: ${DEPLOY_STACK_NAME}"
log "Images:"
log " ${API_IMAGE}"
log " ${WORKER_IMAGE}"
log " ${ADMIN_IMAGE}"

11
deploy/secrets/README.md Normal file
View File

@@ -0,0 +1,11 @@
# Secrets Directory
Create these files (copy from `.example` files):
- `deploy/secrets/postgres_password.txt`
- `deploy/secrets/secret_key.txt`
- `deploy/secrets/email_host_password.txt`
- `deploy/secrets/fcm_server_key.txt`
- `deploy/secrets/apns_auth_key.p8`
These are consumed by `./.deploy_prod` and converted into Docker Swarm secrets.

View File

@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
CHANGEME_APNS_PRIVATE_KEY
-----END PRIVATE KEY-----

View File

@@ -0,0 +1 @@
CHANGEME_SMTP_PASSWORD

View File

@@ -0,0 +1 @@
CHANGEME_FCM_SERVER_KEY

View File

@@ -0,0 +1 @@
CHANGEME_DATABASE_PASSWORD

View File

@@ -0,0 +1 @@
CHANGEME_SECRET_KEY_MIN_32_CHARS

View File

@@ -0,0 +1,67 @@
# Shit Deploy Can't Do
This is everything `./.deploy_prod` cannot safely automate for you.
## 1. Create Infrastructure
Step:
Create Hetzner servers, networking, and load balancer.
Reason:
The script only deploys app workloads. It cannot create paid cloud resources without cloud API credentials and IaC wiring.
## 2. Join Nodes To Swarm
Step:
Run `docker swarm init` on the first manager and `docker swarm join` on other nodes.
Reason:
Joining nodes requires one-time bootstrap tokens and host-level control.
## 3. Configure Firewall And Origin Restrictions
Step:
Set firewall rules so only expected ingress paths can reach your nodes.
Reason:
Firewall policies live in provider networking controls, outside this repo.
## 4. Configure DNS / Cloudflare
Step:
Point DNS at LB, enable proxying, set SSL mode, and lock down origin access.
Reason:
DNS and CDN settings are account-level operations in Cloudflare, not deploy-time app actions.
## 5. Configure External Services
Step:
Create and configure Neon, B2, email provider, APNS, and FCM credentials.
Reason:
These credentials are issued in vendor dashboards and must be manually generated/rotated.
## 6. Seed SSH Trust
Step:
Ensure your local machine can SSH to the manager with the key in `deploy/cluster.env`.
Reason:
The script assumes SSH already works; it cannot grant itself SSH access.
## 7. First-Time Smoke Testing Beyond `/api/health/`
Step:
Manually test login, push, background jobs, and admin panel flows after first deploy.
Reason:
Automated health checks prove container readiness, not end-to-end business behavior.
## 8. Safe Secret Garbage Collection
Step:
Periodically remove old versioned Docker secrets that are no longer referenced.
Reason:
This deploy script creates versioned secrets for safe rollouts and does not auto-delete old ones to avoid breaking running services.

288
deploy/swarm-stack.prod.yml Normal file
View File

@@ -0,0 +1,288 @@
version: "3.8"
services:
redis:
image: redis:7-alpine
command: redis-server --appendonly yes --appendfsync everysec
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
deploy:
replicas: 1
restart_policy:
condition: any
delay: 5s
placement:
max_replicas_per_node: 1
networks:
- casera-network
api:
image: ${API_IMAGE}
ports:
- target: 8000
published: ${API_PORT}
protocol: tcp
mode: ingress
environment:
PORT: "8000"
DEBUG: "${DEBUG}"
ALLOWED_HOSTS: "${ALLOWED_HOSTS}"
CORS_ALLOWED_ORIGINS: "${CORS_ALLOWED_ORIGINS}"
TIMEZONE: "${TIMEZONE}"
BASE_URL: "${BASE_URL}"
ADMIN_PANEL_URL: "${ADMIN_PANEL_URL}"
DB_HOST: "${DB_HOST}"
DB_PORT: "${DB_PORT}"
POSTGRES_USER: "${POSTGRES_USER}"
POSTGRES_DB: "${POSTGRES_DB}"
DB_SSLMODE: "${DB_SSLMODE}"
DB_MAX_OPEN_CONNS: "${DB_MAX_OPEN_CONNS}"
DB_MAX_IDLE_CONNS: "${DB_MAX_IDLE_CONNS}"
DB_MAX_LIFETIME: "${DB_MAX_LIFETIME}"
REDIS_URL: "${REDIS_URL}"
REDIS_DB: "${REDIS_DB}"
EMAIL_HOST: "${EMAIL_HOST}"
EMAIL_PORT: "${EMAIL_PORT}"
EMAIL_HOST_USER: "${EMAIL_HOST_USER}"
DEFAULT_FROM_EMAIL: "${DEFAULT_FROM_EMAIL}"
EMAIL_USE_TLS: "${EMAIL_USE_TLS}"
APNS_AUTH_KEY_PATH: "/run/secrets/apns_auth_key"
APNS_AUTH_KEY_ID: "${APNS_AUTH_KEY_ID}"
APNS_TEAM_ID: "${APNS_TEAM_ID}"
APNS_TOPIC: "${APNS_TOPIC}"
APNS_USE_SANDBOX: "${APNS_USE_SANDBOX}"
APNS_PRODUCTION: "${APNS_PRODUCTION}"
STORAGE_UPLOAD_DIR: "${STORAGE_UPLOAD_DIR}"
STORAGE_BASE_URL: "${STORAGE_BASE_URL}"
STORAGE_MAX_FILE_SIZE: "${STORAGE_MAX_FILE_SIZE}"
STORAGE_ALLOWED_TYPES: "${STORAGE_ALLOWED_TYPES}"
FEATURE_PUSH_ENABLED: "${FEATURE_PUSH_ENABLED}"
FEATURE_EMAIL_ENABLED: "${FEATURE_EMAIL_ENABLED}"
FEATURE_WEBHOOKS_ENABLED: "${FEATURE_WEBHOOKS_ENABLED}"
FEATURE_ONBOARDING_EMAILS_ENABLED: "${FEATURE_ONBOARDING_EMAILS_ENABLED}"
FEATURE_PDF_REPORTS_ENABLED: "${FEATURE_PDF_REPORTS_ENABLED}"
FEATURE_WORKER_ENABLED: "${FEATURE_WORKER_ENABLED}"
APPLE_CLIENT_ID: "${APPLE_CLIENT_ID}"
APPLE_TEAM_ID: "${APPLE_TEAM_ID}"
GOOGLE_CLIENT_ID: "${GOOGLE_CLIENT_ID}"
GOOGLE_ANDROID_CLIENT_ID: "${GOOGLE_ANDROID_CLIENT_ID}"
GOOGLE_IOS_CLIENT_ID: "${GOOGLE_IOS_CLIENT_ID}"
APPLE_IAP_KEY_PATH: "${APPLE_IAP_KEY_PATH}"
APPLE_IAP_KEY_ID: "${APPLE_IAP_KEY_ID}"
APPLE_IAP_ISSUER_ID: "${APPLE_IAP_ISSUER_ID}"
APPLE_IAP_BUNDLE_ID: "${APPLE_IAP_BUNDLE_ID}"
APPLE_IAP_SANDBOX: "${APPLE_IAP_SANDBOX}"
GOOGLE_IAP_SERVICE_ACCOUNT_PATH: "${GOOGLE_IAP_SERVICE_ACCOUNT_PATH}"
GOOGLE_IAP_PACKAGE_NAME: "${GOOGLE_IAP_PACKAGE_NAME}"
command:
- /bin/sh
- -lc
- |
set -eu
export POSTGRES_PASSWORD="$$(cat /run/secrets/postgres_password)"
export SECRET_KEY="$$(cat /run/secrets/secret_key)"
export EMAIL_HOST_PASSWORD="$$(cat /run/secrets/email_host_password)"
export FCM_SERVER_KEY="$$(cat /run/secrets/fcm_server_key)"
exec /app/api
secrets:
- source: ${POSTGRES_PASSWORD_SECRET}
target: postgres_password
- source: ${SECRET_KEY_SECRET}
target: secret_key
- source: ${EMAIL_HOST_PASSWORD_SECRET}
target: email_host_password
- source: ${FCM_SERVER_KEY_SECRET}
target: fcm_server_key
- source: ${APNS_AUTH_KEY_SECRET}
target: apns_auth_key
volumes:
- uploads:/app/uploads
healthcheck:
test: ["CMD", "curl", "-f", "http://127.0.0.1:8000/api/health/"]
interval: 30s
timeout: 10s
start_period: 15s
retries: 3
deploy:
replicas: ${API_REPLICAS}
restart_policy:
condition: any
delay: 5s
update_config:
parallelism: 1
delay: 10s
order: start-first
rollback_config:
parallelism: 1
delay: 5s
order: stop-first
networks:
- casera-network
admin:
image: ${ADMIN_IMAGE}
ports:
- target: 3000
published: ${ADMIN_PORT}
protocol: tcp
mode: ingress
environment:
PORT: "3000"
HOSTNAME: "0.0.0.0"
NEXT_PUBLIC_API_URL: "${NEXT_PUBLIC_API_URL}"
healthcheck:
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://127.0.0.1:3000/admin/"]
interval: 30s
timeout: 10s
retries: 3
deploy:
replicas: ${ADMIN_REPLICAS}
restart_policy:
condition: any
delay: 5s
update_config:
parallelism: 1
delay: 10s
order: start-first
rollback_config:
parallelism: 1
delay: 5s
order: stop-first
networks:
- casera-network
worker:
image: ${WORKER_IMAGE}
environment:
DB_HOST: "${DB_HOST}"
DB_PORT: "${DB_PORT}"
POSTGRES_USER: "${POSTGRES_USER}"
POSTGRES_DB: "${POSTGRES_DB}"
DB_SSLMODE: "${DB_SSLMODE}"
DB_MAX_OPEN_CONNS: "${DB_MAX_OPEN_CONNS}"
DB_MAX_IDLE_CONNS: "${DB_MAX_IDLE_CONNS}"
DB_MAX_LIFETIME: "${DB_MAX_LIFETIME}"
REDIS_URL: "${REDIS_URL}"
REDIS_DB: "${REDIS_DB}"
EMAIL_HOST: "${EMAIL_HOST}"
EMAIL_PORT: "${EMAIL_PORT}"
EMAIL_HOST_USER: "${EMAIL_HOST_USER}"
DEFAULT_FROM_EMAIL: "${DEFAULT_FROM_EMAIL}"
EMAIL_USE_TLS: "${EMAIL_USE_TLS}"
APNS_AUTH_KEY_PATH: "/run/secrets/apns_auth_key"
APNS_AUTH_KEY_ID: "${APNS_AUTH_KEY_ID}"
APNS_TEAM_ID: "${APNS_TEAM_ID}"
APNS_TOPIC: "${APNS_TOPIC}"
APNS_USE_SANDBOX: "${APNS_USE_SANDBOX}"
APNS_PRODUCTION: "${APNS_PRODUCTION}"
TASK_REMINDER_HOUR: "${TASK_REMINDER_HOUR}"
OVERDUE_REMINDER_HOUR: "${OVERDUE_REMINDER_HOUR}"
DAILY_DIGEST_HOUR: "${DAILY_DIGEST_HOUR}"
FEATURE_PUSH_ENABLED: "${FEATURE_PUSH_ENABLED}"
FEATURE_EMAIL_ENABLED: "${FEATURE_EMAIL_ENABLED}"
FEATURE_WEBHOOKS_ENABLED: "${FEATURE_WEBHOOKS_ENABLED}"
FEATURE_ONBOARDING_EMAILS_ENABLED: "${FEATURE_ONBOARDING_EMAILS_ENABLED}"
FEATURE_PDF_REPORTS_ENABLED: "${FEATURE_PDF_REPORTS_ENABLED}"
FEATURE_WORKER_ENABLED: "${FEATURE_WORKER_ENABLED}"
command:
- /bin/sh
- -lc
- |
set -eu
export POSTGRES_PASSWORD="$$(cat /run/secrets/postgres_password)"
export SECRET_KEY="$$(cat /run/secrets/secret_key)"
export EMAIL_HOST_PASSWORD="$$(cat /run/secrets/email_host_password)"
export FCM_SERVER_KEY="$$(cat /run/secrets/fcm_server_key)"
exec /app/worker
secrets:
- source: ${POSTGRES_PASSWORD_SECRET}
target: postgres_password
- source: ${SECRET_KEY_SECRET}
target: secret_key
- source: ${EMAIL_HOST_PASSWORD_SECRET}
target: email_host_password
- source: ${FCM_SERVER_KEY_SECRET}
target: fcm_server_key
- source: ${APNS_AUTH_KEY_SECRET}
target: apns_auth_key
deploy:
replicas: ${WORKER_REPLICAS}
restart_policy:
condition: any
delay: 5s
update_config:
parallelism: 1
delay: 10s
order: start-first
rollback_config:
parallelism: 1
delay: 5s
order: stop-first
networks:
- casera-network
dozzle:
image: amir20/dozzle:latest
ports:
- target: 8080
published: ${DOZZLE_PORT}
protocol: tcp
mode: ingress
environment:
DOZZLE_NO_ANALYTICS: "true"
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
deploy:
replicas: 1
restart_policy:
condition: any
delay: 5s
placement:
constraints:
- node.role == manager
networks:
- casera-network
volumes:
redis_data:
uploads:
networks:
casera-network:
driver: overlay
driver_opts:
encrypted: "true"
secrets:
postgres_password:
external: true
name: ${POSTGRES_PASSWORD_SECRET}
secret_key:
external: true
name: ${SECRET_KEY_SECRET}
email_host_password:
external: true
name: ${EMAIL_HOST_PASSWORD_SECRET}
fcm_server_key:
external: true
name: ${FCM_SERVER_KEY_SECRET}
apns_auth_key:
external: true
name: ${APNS_AUTH_KEY_SECRET}

1527
docs/AUDIT_FINDINGS.md Normal file

File diff suppressed because it is too large Load Diff

37
hardening-report.md Normal file
View File

@@ -0,0 +1,37 @@
# Go Backend Hardening Audit Report
## Audit Sources
- 9 mapper agents (100% file coverage)
- 8 specialized domain auditors (parallel)
- 1 cross-cutting deep audit (parallel)
- Total source files: 136 (excluding 27 test files)
---
## CRITICAL — Will crash or lose data
## BUG — Incorrect behavior
## SILENT FAILURE — Error swallowed or ignored
## RACE CONDITION — Concurrency issue
## LOGIC ERROR — Code doesn't match intent
## PERFORMANCE — Unnecessary cost
## SECURITY — Vulnerability or exposure
## AUTHORIZATION — Access control gap
## DATA INTEGRITY — GORM / database issue
## API CONTRACT — Request/response issue
## ARCHITECTURE — Layer or pattern violation
## FRAGILE — Works now but will break easily
---
## Summary

View File

@@ -1,5 +1,7 @@
package dto
import "github.com/treytartt/casera-api/internal/middleware"
// PaginationParams holds pagination query parameters
type PaginationParams struct {
Page int `form:"page" validate:"omitempty,min=1"`
@@ -41,6 +43,12 @@ func (p *PaginationParams) GetSortDir() string {
return "DESC"
}
// GetSafeSortBy validates SortBy against an allowlist to prevent SQL injection.
// Returns the matching allowed column, or defaultCol if SortBy is empty or not allowed.
func (p *PaginationParams) GetSafeSortBy(allowedCols []string, defaultCol string) string {
return middleware.SanitizeSortColumn(p.SortBy, allowedCols, defaultCol)
}
// UserFilters holds user-specific filter parameters
type UserFilters struct {
PaginationParams

View File

@@ -0,0 +1,199 @@
package handlers
import (
"html"
"testing"
"github.com/stretchr/testify/assert"
"github.com/treytartt/casera-api/internal/admin/dto"
)
func TestAdminSortBy_ValidColumn_Works(t *testing.T) {
tests := []struct {
name string
sortBy string
allowlist []string
defaultCol string
expected string
}{
{
name: "exact match returns column",
sortBy: "created_at",
allowlist: []string{"id", "created_at", "updated_at", "name"},
defaultCol: "created_at",
expected: "created_at",
},
{
name: "case insensitive match returns canonical column",
sortBy: "Created_At",
allowlist: []string{"id", "created_at", "updated_at", "name"},
defaultCol: "created_at",
expected: "created_at",
},
{
name: "different valid column",
sortBy: "name",
allowlist: []string{"id", "created_at", "updated_at", "name"},
defaultCol: "created_at",
expected: "name",
},
{
name: "date_joined for user handler",
sortBy: "date_joined",
allowlist: []string{"id", "username", "email", "date_joined", "last_login", "is_active"},
defaultCol: "date_joined",
expected: "date_joined",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := dto.PaginationParams{SortBy: tt.sortBy}
result := p.GetSafeSortBy(tt.allowlist, tt.defaultCol)
assert.Equal(t, tt.expected, result)
})
}
}
func TestAdminSortBy_SQLInjection_ReturnsDefault(t *testing.T) {
allowlist := []string{"id", "created_at", "updated_at", "name"}
defaultCol := "created_at"
tests := []struct {
name string
sortBy string
}{
{
name: "SQL injection with DROP TABLE",
sortBy: "created_at; DROP TABLE users; --",
},
{
name: "SQL injection with UNION SELECT",
sortBy: "id UNION SELECT password FROM auth_user",
},
{
name: "SQL injection with subquery",
sortBy: "(SELECT password FROM auth_user LIMIT 1)",
},
{
name: "SQL injection with comment",
sortBy: "created_at--",
},
{
name: "SQL injection with semicolon",
sortBy: "created_at;",
},
{
name: "SQL injection with OR 1=1",
sortBy: "created_at OR 1=1",
},
{
name: "column not in allowlist",
sortBy: "password",
},
{
name: "SQL injection with single quotes",
sortBy: "name'; DROP TABLE users; --",
},
{
name: "SQL injection with double dashes",
sortBy: "id -- comment",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := dto.PaginationParams{SortBy: tt.sortBy}
result := p.GetSafeSortBy(allowlist, defaultCol)
assert.Equal(t, defaultCol, result, "SQL injection attempt should return default column")
})
}
}
func TestAdminSortBy_EmptyString_ReturnsDefault(t *testing.T) {
tests := []struct {
name string
sortBy string
defaultCol string
}{
{
name: "empty string returns default",
sortBy: "",
defaultCol: "created_at",
},
{
name: "whitespace only returns default",
sortBy: " ",
defaultCol: "created_at",
},
{
name: "tab only returns default",
sortBy: "\t",
defaultCol: "date_joined",
},
{
name: "different default column",
sortBy: "",
defaultCol: "completed_at",
},
}
allowlist := []string{"id", "created_at", "updated_at", "name"}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := dto.PaginationParams{SortBy: tt.sortBy}
result := p.GetSafeSortBy(allowlist, tt.defaultCol)
assert.Equal(t, tt.defaultCol, result)
})
}
}
func TestSendEmail_XSSEscaped(t *testing.T) {
// SEC-22: Subject and Body must be HTML-escaped before insertion into email template.
// This tests the html.EscapeString behavior that the handler relies on.
tests := []struct {
name string
input string
expected string
}{
{
name: "script tag in subject",
input: `<script>alert("xss")</script>`,
expected: `&lt;script&gt;alert(&#34;xss&#34;)&lt;/script&gt;`,
},
{
name: "img onerror payload",
input: `<img src=x onerror=alert(1)>`,
expected: `&lt;img src=x onerror=alert(1)&gt;`,
},
{
name: "ampersand and angle brackets",
input: `Tom & Jerry <bros>`,
expected: `Tom &amp; Jerry &lt;bros&gt;`,
},
{
name: "plain text passes through",
input: "Hello World",
expected: "Hello World",
},
{
name: "single quotes",
input: `It's a 'test'`,
expected: `It&#39;s a &#39;test&#39;`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
escaped := html.EscapeString(tt.input)
assert.Equal(t, tt.expected, escaped)
// Verify the escaped output does NOT contain raw angle brackets from the input
if tt.input != tt.expected {
assert.NotContains(t, escaped, "<script>")
assert.NotContains(t, escaped, "<img")
}
})
}
}

View File

@@ -80,11 +80,11 @@ func (h *AdminUserManagementHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "email", "first_name", "last_name",
"role", "is_active", "created_at", "updated_at",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -63,11 +63,11 @@ func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "user_id", "apple_id", "email", "is_private_email",
"created_at", "updated_at",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -55,11 +55,10 @@ func (h *AdminAuthTokenHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"created", "user_id",
}, "created")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -96,11 +96,11 @@ func (h *AdminCompletionHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "completed_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "task_id", "completed_by_id", "completed_at",
"created_at", "notes", "actual_cost", "rating",
}, "completed_at")
sortDir := "DESC"
if filters.SortDir != "" {
sortDir = filters.GetSortDir()

View File

@@ -78,11 +78,10 @@ func (h *AdminCompletionImageHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "completion_id", "created_at", "updated_at",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -58,11 +58,10 @@ func (h *AdminConfirmationCodeHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "user_id", "created_at", "expires_at", "is_used",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -59,11 +59,11 @@ func (h *AdminContractorHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "name", "company", "email", "phone", "city",
"created_at", "updated_at", "is_active", "is_favorite", "rating",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -70,10 +70,10 @@ func (h *AdminDeviceHandler) ListAPNS(c echo.Context) error {
query.Count(&total)
sortBy := "date_created"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "name", "active", "user_id", "device_id", "date_created",
}, "date_created")
query = query.Order(sortBy + " " + filters.GetSortDir())
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
@@ -125,10 +125,10 @@ func (h *AdminDeviceHandler) ListGCM(c echo.Context) error {
query.Count(&total)
sortBy := "date_created"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "name", "active", "user_id", "device_id", "cloud_message_type", "date_created",
}, "date_created")
query = query.Order(sortBy + " " + filters.GetSortDir())
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())

View File

@@ -61,11 +61,11 @@ func (h *AdminDocumentHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "title", "created_at", "updated_at", "document_type",
"residence_id", "is_active", "expiry_date", "vendor",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -79,11 +79,10 @@ func (h *AdminDocumentImageHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "document_id", "created_at", "updated_at",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -52,10 +52,10 @@ func (h *AdminFeatureBenefitHandler) List(c echo.Context) error {
query.Count(&total)
sortBy := "display_order"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "feature_name", "display_order", "is_active", "created_at", "updated_at",
}, "display_order")
query = query.Order(sortBy + " " + filters.GetSortDir())
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())

View File

@@ -29,6 +29,8 @@ func NewAdminLookupHandler(db *gorm.DB) *AdminLookupHandler {
func (h *AdminLookupHandler) refreshCategoriesCache(ctx context.Context) {
cache := services.GetCache()
if cache == nil {
log.Warn().Msg("Cache service unavailable, skipping categories cache refresh")
return
}
var categories []models.TaskCategory
@@ -49,6 +51,8 @@ func (h *AdminLookupHandler) refreshCategoriesCache(ctx context.Context) {
func (h *AdminLookupHandler) refreshPrioritiesCache(ctx context.Context) {
cache := services.GetCache()
if cache == nil {
log.Warn().Msg("Cache service unavailable, skipping priorities cache refresh")
return
}
var priorities []models.TaskPriority
@@ -69,6 +73,8 @@ func (h *AdminLookupHandler) refreshPrioritiesCache(ctx context.Context) {
func (h *AdminLookupHandler) refreshFrequenciesCache(ctx context.Context) {
cache := services.GetCache()
if cache == nil {
log.Warn().Msg("Cache service unavailable, skipping frequencies cache refresh")
return
}
var frequencies []models.TaskFrequency
@@ -89,6 +95,8 @@ func (h *AdminLookupHandler) refreshFrequenciesCache(ctx context.Context) {
func (h *AdminLookupHandler) refreshResidenceTypesCache(ctx context.Context) {
cache := services.GetCache()
if cache == nil {
log.Warn().Msg("Cache service unavailable, skipping residence types cache refresh")
return
}
var types []models.ResidenceType
@@ -109,6 +117,8 @@ func (h *AdminLookupHandler) refreshResidenceTypesCache(ctx context.Context) {
func (h *AdminLookupHandler) refreshSpecialtiesCache(ctx context.Context) {
cache := services.GetCache()
if cache == nil {
log.Warn().Msg("Cache service unavailable, skipping specialties cache refresh")
return
}
var specialties []models.ContractorSpecialty
@@ -130,6 +140,8 @@ func (h *AdminLookupHandler) refreshSpecialtiesCache(ctx context.Context) {
func (h *AdminLookupHandler) invalidateSeededDataCache(ctx context.Context) {
cache := services.GetCache()
if cache == nil {
log.Warn().Msg("Cache service unavailable, skipping seeded data cache invalidation")
return
}
if err := cache.InvalidateSeededData(ctx); err != nil {

View File

@@ -2,6 +2,7 @@ package handlers
import (
"context"
"html"
"net/http"
"strconv"
"time"
@@ -67,11 +68,11 @@ func (h *AdminNotificationHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "created_at", "updated_at", "user_id",
"notification_type", "sent", "read", "title",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination
@@ -347,16 +348,20 @@ func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error {
return c.JSON(http.StatusServiceUnavailable, map[string]interface{}{"error": "Email service not configured"})
}
// HTML-escape user-supplied values to prevent XSS via email content
escapedSubject := html.EscapeString(req.Subject)
escapedBody := html.EscapeString(req.Body)
// Create HTML body with basic styling
htmlBody := `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>` + req.Subject + `</title>
<title>` + escapedSubject + `</title>
</head>
<body style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto; padding: 20px;">
<h2 style="color: #333;">` + req.Subject + `</h2>
<div style="color: #666; line-height: 1.6;">` + req.Body + `</div>
<h2 style="color: #333;">` + escapedSubject + `</h2>
<div style="color: #666; line-height: 1.6;">` + escapedBody + `</div>
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
<p style="color: #999; font-size: 12px;">This is a test email sent from Casera Admin Panel.</p>
</body>

View File

@@ -76,11 +76,10 @@ func (h *AdminNotificationPrefsHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "user_id", "created_at", "updated_at",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -60,11 +60,10 @@ func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "user_id", "created_at", "expires_at", "used",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -56,10 +56,11 @@ func (h *AdminPromotionHandler) List(c echo.Context) error {
query.Count(&total)
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "promotion_id", "title", "start_date", "end_date",
"target_tier", "is_active", "created_at", "updated_at",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())

View File

@@ -58,11 +58,11 @@ func (h *AdminResidenceHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "name", "created_at", "updated_at", "owner_id",
"city", "state_province", "country", "is_active", "is_primary",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -62,11 +62,11 @@ func (h *AdminShareCodeHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "residence_id", "code", "created_by_id",
"is_active", "expires_at", "created_at",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination
@@ -153,13 +153,17 @@ func (h *AdminShareCodeHandler) Update(c echo.Context) error {
}
var req struct {
IsActive bool `json:"is_active"`
IsActive *bool `json:"is_active"`
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
}
code.IsActive = req.IsActive
// Only update IsActive when explicitly provided (non-nil).
// Using *bool prevents a missing field from defaulting to false.
if req.IsActive != nil {
code.IsActive = *req.IsActive
}
if err := h.db.Save(&code).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update share code"})

View File

@@ -65,11 +65,11 @@ func (h *AdminSubscriptionHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "created_at", "updated_at", "user_id",
"tier", "platform", "auto_renew", "expires_at", "subscribed_at",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -68,11 +68,12 @@ func (h *AdminTaskHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "title", "created_at", "updated_at", "due_date", "next_due_date",
"residence_id", "category_id", "priority_id", "in_progress",
"is_cancelled", "is_archived", "estimated_cost", "actual_cost",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -56,11 +56,11 @@ func (h *AdminUserHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "date_joined"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "username", "email", "first_name", "last_name",
"date_joined", "last_login", "is_active", "is_staff", "is_superuser",
}, "date_joined")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -69,11 +69,10 @@ func (h *AdminUserProfileHandler) List(c echo.Context) error {
// Get total count
query.Count(&total)
// Apply sorting
sortBy := "created_at"
if filters.SortBy != "" {
sortBy = filters.SortBy
}
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "user_id", "verified", "created_at", "updated_at",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination

View File

@@ -338,10 +338,14 @@ func validate(cfg *Config) error {
// In debug mode, use a default key with a warning for local development
cfg.Security.SecretKey = "change-me-in-production-secret-key-12345"
fmt.Println("WARNING: SECRET_KEY not set, using default (debug mode only)")
fmt.Println("WARNING: *** DO NOT USE THIS DEFAULT KEY IN PRODUCTION ***")
} else {
// In production, refuse to start without a proper secret key
return fmt.Errorf("FATAL: SECRET_KEY environment variable is required in production (DEBUG=false)")
}
} else if cfg.Security.SecretKey == "change-me-in-production-secret-key-12345" {
// Warn if someone explicitly set the well-known debug key
fmt.Println("WARNING: SECRET_KEY is set to the well-known debug default. Change it for production use.")
}
// Database password might come from DATABASE_URL, don't require it separately

View File

@@ -369,17 +369,13 @@ func migrateGoAdmin() error {
}
db.Exec(`CREATE INDEX IF NOT EXISTS idx_goadmin_site_key ON goadmin_site(key)`)
// Seed default admin user (password: admin - bcrypt hash)
// Seed default admin user only on first run (ON CONFLICT DO NOTHING).
// Password is NOT reset on subsequent migrations to preserve operator changes.
db.Exec(`
INSERT INTO goadmin_users (username, password, name, avatar)
VALUES ('admin', '$2a$10$t.GCU24EqIWLSl7F51Hdz.IkkgFK.Qa9/BzEc5Bi2C/I2bXf1nJgm', 'Administrator', '')
ON CONFLICT DO NOTHING
`)
// Update existing admin password if it exists with wrong hash
db.Exec(`
UPDATE goadmin_users SET password = '$2a$10$t.GCU24EqIWLSl7F51Hdz.IkkgFK.Qa9/BzEc5Bi2C/I2bXf1nJgm'
WHERE username = 'admin'
`)
// Seed default roles
db.Exec(`INSERT INTO goadmin_roles (name, slug) VALUES ('Administrator', 'administrator') ON CONFLICT DO NOTHING`)
@@ -443,8 +439,8 @@ func migrateGoAdmin() error {
log.Info().Msg("GoAdmin migrations completed")
// Seed default Next.js admin user (email: admin@mycrib.com, password: admin123)
// bcrypt hash for "admin123": $2a$10$t5hGjdXQLxr9Z0193qx.Tef6hd1vYI3JvrfX/piKx2qS9UvQ41I9O
// Seed default Next.js admin user only on first run.
// Password is NOT reset on subsequent migrations to preserve operator changes.
var adminCount int64
db.Raw(`SELECT COUNT(*) FROM admin_users WHERE email = 'admin@mycrib.com'`).Scan(&adminCount)
if adminCount == 0 {
@@ -453,14 +449,7 @@ func migrateGoAdmin() error {
INSERT INTO admin_users (email, password, first_name, last_name, role, is_active, created_at, updated_at)
VALUES ('admin@mycrib.com', '$2a$10$t5hGjdXQLxr9Z0193qx.Tef6hd1vYI3JvrfX/piKx2qS9UvQ41I9O', 'Admin', 'User', 'super_admin', true, NOW(), NOW())
`)
log.Info().Msg("Default admin user created: admin@mycrib.com / admin123")
} else {
// Update existing admin password if needed
db.Exec(`
UPDATE admin_users SET password = '$2a$10$t5hGjdXQLxr9Z0193qx.Tef6hd1vYI3JvrfX/piKx2qS9UvQ41I9O'
WHERE email = 'admin@mycrib.com'
`)
log.Info().Msg("Updated admin@mycrib.com password to admin123")
log.Info().Msg("Default admin user created: admin@mycrib.com")
}
return nil

View File

@@ -8,13 +8,13 @@ type CreateContractorRequest struct {
Phone string `json:"phone" validate:"max=20"`
Email string `json:"email" validate:"omitempty,email,max=254"`
Website string `json:"website" validate:"max=200"`
Notes string `json:"notes"`
Notes string `json:"notes" validate:"max=10000"`
StreetAddress string `json:"street_address" validate:"max=255"`
City string `json:"city" validate:"max=100"`
StateProvince string `json:"state_province" validate:"max=100"`
PostalCode string `json:"postal_code" validate:"max=20"`
SpecialtyIDs []uint `json:"specialty_ids"`
Rating *float64 `json:"rating"`
SpecialtyIDs []uint `json:"specialty_ids" validate:"omitempty,max=20"`
Rating *float64 `json:"rating" validate:"omitempty,min=0,max=5"`
IsFavorite *bool `json:"is_favorite"`
}
@@ -25,13 +25,13 @@ type UpdateContractorRequest struct {
Phone *string `json:"phone" validate:"omitempty,max=20"`
Email *string `json:"email" validate:"omitempty,email,max=254"`
Website *string `json:"website" validate:"omitempty,max=200"`
Notes *string `json:"notes"`
Notes *string `json:"notes" validate:"omitempty,max=10000"`
StreetAddress *string `json:"street_address" validate:"omitempty,max=255"`
City *string `json:"city" validate:"omitempty,max=100"`
StateProvince *string `json:"state_province" validate:"omitempty,max=100"`
PostalCode *string `json:"postal_code" validate:"omitempty,max=20"`
SpecialtyIDs []uint `json:"specialty_ids"`
Rating *float64 `json:"rating"`
SpecialtyIDs []uint `json:"specialty_ids" validate:"omitempty,max=20"`
Rating *float64 `json:"rating" validate:"omitempty,min=0,max=5"`
IsFavorite *bool `json:"is_favorite"`
ResidenceID *uint `json:"residence_id"`
}

View File

@@ -12,11 +12,11 @@ import (
type CreateDocumentRequest struct {
ResidenceID uint `json:"residence_id" validate:"required"`
Title string `json:"title" validate:"required,min=1,max=200"`
Description string `json:"description"`
DocumentType models.DocumentType `json:"document_type"`
Description string `json:"description" validate:"max=10000"`
DocumentType models.DocumentType `json:"document_type" validate:"omitempty,oneof=general warranty receipt contract insurance manual"`
FileURL string `json:"file_url" validate:"max=500"`
FileName string `json:"file_name" validate:"max=255"`
FileSize *int64 `json:"file_size"`
FileSize *int64 `json:"file_size" validate:"omitempty,min=0"`
MimeType string `json:"mime_type" validate:"max=100"`
PurchaseDate *time.Time `json:"purchase_date"`
ExpiryDate *time.Time `json:"expiry_date"`
@@ -25,17 +25,17 @@ type CreateDocumentRequest struct {
SerialNumber string `json:"serial_number" validate:"max=100"`
ModelNumber string `json:"model_number" validate:"max=100"`
TaskID *uint `json:"task_id"`
ImageURLs []string `json:"image_urls"` // Multiple image URLs
ImageURLs []string `json:"image_urls" validate:"omitempty,max=20,dive,max=500"` // Multiple image URLs
}
// UpdateDocumentRequest represents the request to update a document
type UpdateDocumentRequest struct {
Title *string `json:"title" validate:"omitempty,min=1,max=200"`
Description *string `json:"description"`
DocumentType *models.DocumentType `json:"document_type"`
Description *string `json:"description" validate:"omitempty,max=10000"`
DocumentType *models.DocumentType `json:"document_type" validate:"omitempty,oneof=general warranty receipt contract insurance manual"`
FileURL *string `json:"file_url" validate:"omitempty,max=500"`
FileName *string `json:"file_name" validate:"omitempty,max=255"`
FileSize *int64 `json:"file_size"`
FileSize *int64 `json:"file_size" validate:"omitempty,min=0"`
MimeType *string `json:"mime_type" validate:"omitempty,max=100"`
PurchaseDate *time.Time `json:"purchase_date"`
ExpiryDate *time.Time `json:"expiry_date"`

View File

@@ -16,12 +16,12 @@ type CreateResidenceRequest struct {
StateProvince string `json:"state_province" validate:"max=100"`
PostalCode string `json:"postal_code" validate:"max=20"`
Country string `json:"country" validate:"max=100"`
Bedrooms *int `json:"bedrooms"`
Bedrooms *int `json:"bedrooms" validate:"omitempty,min=0"`
Bathrooms *decimal.Decimal `json:"bathrooms"`
SquareFootage *int `json:"square_footage"`
SquareFootage *int `json:"square_footage" validate:"omitempty,min=0"`
LotSize *decimal.Decimal `json:"lot_size"`
YearBuilt *int `json:"year_built"`
Description string `json:"description"`
Description string `json:"description" validate:"max=10000"`
PurchaseDate *time.Time `json:"purchase_date"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
IsPrimary *bool `json:"is_primary"`
@@ -37,12 +37,12 @@ type UpdateResidenceRequest struct {
StateProvince *string `json:"state_province" validate:"omitempty,max=100"`
PostalCode *string `json:"postal_code" validate:"omitempty,max=20"`
Country *string `json:"country" validate:"omitempty,max=100"`
Bedrooms *int `json:"bedrooms"`
Bedrooms *int `json:"bedrooms" validate:"omitempty,min=0"`
Bathrooms *decimal.Decimal `json:"bathrooms"`
SquareFootage *int `json:"square_footage"`
SquareFootage *int `json:"square_footage" validate:"omitempty,min=0"`
LotSize *decimal.Decimal `json:"lot_size"`
YearBuilt *int `json:"year_built"`
Description *string `json:"description"`
Description *string `json:"description" validate:"omitempty,max=10000"`
PurchaseDate *time.Time `json:"purchase_date"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
IsPrimary *bool `json:"is_primary"`
@@ -55,5 +55,5 @@ type JoinWithCodeRequest struct {
// GenerateShareCodeRequest represents the request to generate a share code
type GenerateShareCodeRequest struct {
ExpiresInHours int `json:"expires_in_hours"` // Default: 24 hours
ExpiresInHours int `json:"expires_in_hours" validate:"omitempty,min=1"` // Default: 24 hours
}

View File

@@ -56,11 +56,11 @@ func (fd *FlexibleDate) ToTimePtr() *time.Time {
type CreateTaskRequest struct {
ResidenceID uint `json:"residence_id" validate:"required"`
Title string `json:"title" validate:"required,min=1,max=200"`
Description string `json:"description"`
Description string `json:"description" validate:"max=10000"`
CategoryID *uint `json:"category_id"`
PriorityID *uint `json:"priority_id"`
FrequencyID *uint `json:"frequency_id"`
CustomIntervalDays *int `json:"custom_interval_days"` // For "Custom" frequency, user-specified days
CustomIntervalDays *int `json:"custom_interval_days" validate:"omitempty,min=1"` // For "Custom" frequency, user-specified days
InProgress bool `json:"in_progress"`
AssignedToID *uint `json:"assigned_to_id"`
DueDate *FlexibleDate `json:"due_date"`
@@ -75,7 +75,7 @@ type UpdateTaskRequest struct {
CategoryID *uint `json:"category_id"`
PriorityID *uint `json:"priority_id"`
FrequencyID *uint `json:"frequency_id"`
CustomIntervalDays *int `json:"custom_interval_days"` // For "Custom" frequency, user-specified days
CustomIntervalDays *int `json:"custom_interval_days" validate:"omitempty,min=1"` // For "Custom" frequency, user-specified days
InProgress *bool `json:"in_progress"`
AssignedToID *uint `json:"assigned_to_id"`
DueDate *FlexibleDate `json:"due_date"`
@@ -88,18 +88,18 @@ type UpdateTaskRequest struct {
type CreateTaskCompletionRequest struct {
TaskID uint `json:"task_id" validate:"required"`
CompletedAt *time.Time `json:"completed_at"` // Defaults to now
Notes string `json:"notes"`
Notes string `json:"notes" validate:"max=10000"`
ActualCost *decimal.Decimal `json:"actual_cost"`
Rating *int `json:"rating"` // 1-5 star rating
ImageURLs []string `json:"image_urls"` // Multiple image URLs
Rating *int `json:"rating" validate:"omitempty,min=1,max=5"` // 1-5 star rating
ImageURLs []string `json:"image_urls" validate:"omitempty,max=20,dive,max=500"` // Multiple image URLs
}
// UpdateTaskCompletionRequest represents the request to update a task completion
type UpdateTaskCompletionRequest struct {
Notes *string `json:"notes"`
Notes *string `json:"notes" validate:"omitempty,max=10000"`
ActualCost *decimal.Decimal `json:"actual_cost"`
Rating *int `json:"rating"`
ImageURLs []string `json:"image_urls"`
Rating *int `json:"rating" validate:"omitempty,min=1,max=5"`
ImageURLs []string `json:"image_urls" validate:"omitempty,max=20,dive,max=500"`
}
// CompletionImageInput represents an image to add to a completion

View File

@@ -81,6 +81,11 @@ func (h *AuthHandler) Register(c echo.Context) error {
// Send welcome email with confirmation code (async)
if h.emailService != nil && confirmationCode != "" {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", req.Email).Msg("Panic in welcome email goroutine")
}
}()
if err := h.emailService.SendWelcomeEmail(req.Email, req.FirstName, confirmationCode); err != nil {
log.Error().Err(err).Str("email", req.Email).Msg("Failed to send welcome email")
}
@@ -176,6 +181,11 @@ func (h *AuthHandler) VerifyEmail(c echo.Context) error {
// Send post-verification welcome email with tips (async)
if h.emailService != nil {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in post-verification email goroutine")
}
}()
if err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send post-verification email")
}
@@ -204,6 +214,11 @@ func (h *AuthHandler) ResendVerification(c echo.Context) error {
// Send verification email (async)
if h.emailService != nil {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in verification email goroutine")
}
}()
if err := h.emailService.SendVerificationEmail(user.Email, user.FirstName, code); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send verification email")
}
@@ -238,6 +253,11 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
// Send password reset email (async) - only if user found
if h.emailService != nil && code != "" && user != nil {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in password reset email goroutine")
}
}()
if err := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send password reset email")
}
@@ -326,6 +346,11 @@ func (h *AuthHandler) AppleSignIn(c echo.Context) error {
// Send welcome email for new users (async)
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Apple welcome email goroutine")
}
}()
if err := h.emailService.SendAppleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Apple welcome email")
}
@@ -368,6 +393,11 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
// Send welcome email for new users (async)
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Google welcome email goroutine")
}
}()
if err := h.emailService.SendGoogleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Google welcome email")
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -25,17 +24,23 @@ func NewContractorHandler(contractorService *services.ContractorService) *Contra
// ListContractors handles GET /api/contractors/
func (h *ContractorHandler) ListContractors(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
response, err := h.contractorService.ListContractors(user.ID)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
return apperrors.Internal(err)
}
return c.JSON(http.StatusOK, response)
}
// GetContractor handles GET /api/contractors/:id/
func (h *ContractorHandler) GetContractor(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_contractor_id")
@@ -50,11 +55,17 @@ func (h *ContractorHandler) GetContractor(c echo.Context) error {
// CreateContractor handles POST /api/contractors/
func (h *ContractorHandler) CreateContractor(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req requests.CreateContractorRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.contractorService.CreateContractor(&req, user.ID)
if err != nil {
@@ -65,7 +76,10 @@ func (h *ContractorHandler) CreateContractor(c echo.Context) error {
// UpdateContractor handles PUT/PATCH /api/contractors/:id/
func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_contractor_id")
@@ -75,6 +89,9 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.contractorService.UpdateContractor(uint(contractorID), user.ID, &req)
if err != nil {
@@ -85,7 +102,10 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
// DeleteContractor handles DELETE /api/contractors/:id/
func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_contractor_id")
@@ -100,7 +120,10 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
// ToggleFavorite handles POST /api/contractors/:id/toggle-favorite/
func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_contractor_id")
@@ -115,7 +138,10 @@ func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
// GetContractorTasks handles GET /api/contractors/:id/tasks/
func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_contractor_id")
@@ -130,7 +156,10 @@ func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
// ListContractorsByResidence handles GET /api/contractors/by-residence/:residence_id/
func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("residence_id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_residence_id")
@@ -147,7 +176,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
func (h *ContractorHandler) GetSpecialties(c echo.Context) error {
specialties, err := h.contractorService.GetSpecialties()
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
return apperrors.Internal(err)
}
return c.JSON(http.StatusOK, specialties)
}

View File

@@ -0,0 +1,182 @@
package handlers
import (
"encoding/json"
"net/http"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
"github.com/treytartt/casera-api/internal/testutil"
)
func setupContractorHandler(t *testing.T) (*ContractorHandler, *echo.Echo, *gorm.DB) {
db := testutil.SetupTestDB(t)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
contractorService := services.NewContractorService(contractorRepo, residenceRepo)
handler := NewContractorHandler(contractorService)
e := testutil.SetupTestRouter()
return handler, e, db
}
func TestContractorHandler_CreateContractor_MissingName_Returns400(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateContractor)
t.Run("missing name returns 400 validation error", func(t *testing.T) {
// Send request with no name (required field)
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
// Should contain structured validation error
assert.Contains(t, response, "error")
assert.Contains(t, response, "fields")
fields := response["fields"].(map[string]interface{})
assert.Contains(t, fields, "name", "validation error should reference the 'name' field")
})
t.Run("empty body returns 400 validation error", func(t *testing.T) {
// Send completely empty body
w := testutil.MakeRequest(e, "POST", "/api/contractors/", map[string]interface{}{}, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "error")
})
t.Run("valid contractor creation succeeds", func(t *testing.T) {
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "John the Plumber",
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusCreated)
})
}
func TestContractorHandler_ListContractors_Error_NoRawErrorInResponse(t *testing.T) {
_, e, db := setupContractorHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create a handler with a broken service to simulate an internal error.
// We do this by closing the underlying SQL connection, which will cause
// the service to return an error on the next query.
brokenDB := testutil.SetupTestDB(t)
sqlDB, _ := brokenDB.DB()
sqlDB.Close()
brokenContractorRepo := repositories.NewContractorRepository(brokenDB)
brokenResidenceRepo := repositories.NewResidenceRepository(brokenDB)
brokenService := services.NewContractorService(brokenContractorRepo, brokenResidenceRepo)
brokenHandler := NewContractorHandler(brokenService)
authGroup := e.Group("/api/broken-contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/", brokenHandler.ListContractors)
t.Run("internal error does not leak raw error message", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/broken-contractors/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusInternalServerError)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
// Should contain the generic error key, NOT a raw database error
errorMsg, ok := response["error"].(string)
require.True(t, ok, "response should have an 'error' string field")
// Must not contain database-specific details
assert.NotContains(t, errorMsg, "sql", "error message should not leak SQL details")
assert.NotContains(t, errorMsg, "database", "error message should not leak database details")
assert.NotContains(t, errorMsg, "closed", "error message should not leak connection state")
})
}
func TestContractorHandler_CreateContractor_100Specialties_Returns400(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateContractor)
t.Run("too many specialties rejected", func(t *testing.T) {
// Create a slice with 100 specialty IDs (exceeds max=20)
specialtyIDs := make([]uint, 100)
for i := range specialtyIDs {
specialtyIDs[i] = uint(i + 1)
}
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Over-specialized Contractor",
SpecialtyIDs: specialtyIDs,
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("20 specialties accepted", func(t *testing.T) {
specialtyIDs := make([]uint, 20)
for i := range specialtyIDs {
specialtyIDs[i] = uint(i + 1)
}
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Multi-skilled Contractor",
SpecialtyIDs: specialtyIDs,
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
// Should pass validation (201 or success, not 400)
assert.NotEqual(t, http.StatusBadRequest, w.Code, "20 specialties should pass validation")
})
t.Run("rating above 5 rejected", func(t *testing.T) {
rating := 6.0
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Bad Rating Contractor",
Rating: &rating,
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}

View File

@@ -34,7 +34,10 @@ func NewDocumentHandler(documentService *services.DocumentService, storageServic
// ListDocuments handles GET /api/documents/
func (h *DocumentHandler) ListDocuments(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
// Build filter from supported query params.
var filter *repositories.DocumentFilter
@@ -71,7 +74,10 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error {
// GetDocument handles GET /api/documents/:id/
func (h *DocumentHandler) GetDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -86,10 +92,13 @@ func (h *DocumentHandler) GetDocument(c echo.Context) error {
// ListWarranties handles GET /api/documents/warranties/
func (h *DocumentHandler) ListWarranties(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
response, err := h.documentService.ListWarranties(user.ID)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
return apperrors.Internal(err)
}
return c.JSON(http.StatusOK, response)
}
@@ -97,7 +106,10 @@ func (h *DocumentHandler) ListWarranties(c echo.Context) error {
// CreateDocument handles POST /api/documents/
// Supports both JSON and multipart form data (for file uploads)
func (h *DocumentHandler) CreateDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req requests.CreateDocumentRequest
contentType := c.Request().Header.Get("Content-Type")
@@ -198,6 +210,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
}
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.documentService.CreateDocument(&req, user.ID)
if err != nil {
return err
@@ -207,7 +223,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
// UpdateDocument handles PUT/PATCH /api/documents/:id/
func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -217,6 +236,9 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.documentService.UpdateDocument(uint(documentID), user.ID, &req)
if err != nil {
@@ -227,7 +249,10 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
// DeleteDocument handles DELETE /api/documents/:id/
func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -242,7 +267,10 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
// ActivateDocument handles POST /api/documents/:id/activate/
func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -257,7 +285,10 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
// DeactivateDocument handles POST /api/documents/:id/deactivate/
func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -272,7 +303,10 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
// UploadDocumentImage handles POST /api/documents/:id/images/
func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -316,7 +350,10 @@ func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
// DeleteDocumentImage handles DELETE /api/documents/:id/images/:imageId/
func (h *DocumentHandler) DeleteDocumentImage(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")

View File

@@ -1,7 +1,6 @@
package handlers
import (
"path/filepath"
"strconv"
"strings"
@@ -9,7 +8,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
)
@@ -40,7 +38,10 @@ func NewMediaHandler(
// ServeDocument serves a document file with access control
// GET /api/media/document/:id
func (h *MediaHandler) ServeDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -73,7 +74,10 @@ func (h *MediaHandler) ServeDocument(c echo.Context) error {
// ServeDocumentImage serves a document image with access control
// GET /api/media/document-image/:id
func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -111,7 +115,10 @@ func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
// ServeCompletionImage serves a task completion image with access control
// GET /api/media/completion-image/:id
func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -152,7 +159,9 @@ func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
return c.File(filePath)
}
// resolveFilePath converts a stored URL to an actual file path
// resolveFilePath converts a stored URL to an actual file path.
// Returns empty string if the URL is empty or the resolved path would escape
// the upload directory (path traversal attempt).
func (h *MediaHandler) resolveFilePath(storedURL string) string {
if storedURL == "" {
return ""
@@ -160,12 +169,18 @@ func (h *MediaHandler) resolveFilePath(storedURL string) string {
uploadDir := h.storageSvc.GetUploadDir()
// Handle legacy /uploads/... URLs
// Strip legacy /uploads/ prefix to get relative path
relativePath := storedURL
if strings.HasPrefix(storedURL, "/uploads/") {
relativePath := strings.TrimPrefix(storedURL, "/uploads/")
return filepath.Join(uploadDir, relativePath)
relativePath = strings.TrimPrefix(storedURL, "/uploads/")
}
// Handle relative paths (new format)
return filepath.Join(uploadDir, storedURL)
// Use SafeResolvePath to validate containment within upload directory
resolved, err := services.SafeResolvePath(uploadDir, relativePath)
if err != nil {
// Path traversal or invalid path — return empty to signal file not found
return ""
}
return resolved
}

View File

@@ -0,0 +1,74 @@
package handlers
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/services"
)
// newTestStorageService creates a StorageService with a known upload directory for testing.
// It does NOT call NewStorageService because that creates directories on disk.
// Instead, it directly constructs the struct with only what resolveFilePath needs.
func newTestStorageService(uploadDir string) *services.StorageService {
cfg := &config.StorageConfig{
UploadDir: uploadDir,
BaseURL: "/uploads",
MaxFileSize: 10 * 1024 * 1024,
AllowedTypes: "image/jpeg,image/png",
}
// Use the exported constructor helper that skips directory creation (for tests)
return services.NewStorageServiceForTest(cfg)
}
func TestResolveFilePath_NormalPath_Works(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
h := NewMediaHandler(nil, nil, nil, storageSvc)
result := h.resolveFilePath("images/photo.jpg")
require.NotEmpty(t, result)
assert.Equal(t, "/var/uploads/images/photo.jpg", result)
}
func TestResolveFilePath_LegacyUploadPath_Works(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
h := NewMediaHandler(nil, nil, nil, storageSvc)
result := h.resolveFilePath("/uploads/images/photo.jpg")
require.NotEmpty(t, result)
assert.Equal(t, "/var/uploads/images/photo.jpg", result)
}
func TestResolveFilePath_DotDotTraversal_Blocked(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
h := NewMediaHandler(nil, nil, nil, storageSvc)
tests := []struct {
name string
storedURL string
}{
{"simple dotdot", "../etc/passwd"},
{"nested dotdot", "../../etc/shadow"},
{"embedded dotdot", "images/../../etc/passwd"},
{"legacy prefix with dotdot", "/uploads/../../../etc/passwd"},
{"deep dotdot", "a/b/c/../../../../etc/passwd"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := h.resolveFilePath(tt.storedURL)
assert.Empty(t, result, "path traversal should return empty string for: %s", tt.storedURL)
})
}
}
func TestResolveFilePath_EmptyURL_ReturnsEmpty(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
h := NewMediaHandler(nil, nil, nil, storageSvc)
result := h.resolveFilePath("")
assert.Empty(t, result)
}

View File

@@ -0,0 +1,334 @@
package handlers
import (
"net/http"
"testing"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
"github.com/treytartt/casera-api/internal/testutil"
)
// TestTaskHandler_NoAuth_Returns401 verifies that task handler endpoints return
// 401 Unauthorized when no auth user is set in the context (e.g., auth middleware
// misconfigured or bypassed). This is a regression test for P1-1 (SEC-19).
func TestTaskHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
taskRepo := repositories.NewTaskRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskService := services.NewTaskService(taskRepo, residenceRepo)
handler := NewTaskHandler(taskService, nil)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/tasks/", handler.ListTasks)
e.GET("/api/tasks/:id/", handler.GetTask)
e.POST("/api/tasks/", handler.CreateTask)
e.PUT("/api/tasks/:id/", handler.UpdateTask)
e.DELETE("/api/tasks/:id/", handler.DeleteTask)
e.POST("/api/tasks/:id/cancel/", handler.CancelTask)
e.POST("/api/tasks/:id/mark-in-progress/", handler.MarkInProgress)
e.GET("/api/task-completions/", handler.ListCompletions)
e.POST("/api/task-completions/", handler.CreateCompletion)
tests := []struct {
name string
method string
path string
}{
{"ListTasks", "GET", "/api/tasks/"},
{"GetTask", "GET", "/api/tasks/1/"},
{"CreateTask", "POST", "/api/tasks/"},
{"UpdateTask", "PUT", "/api/tasks/1/"},
{"DeleteTask", "DELETE", "/api/tasks/1/"},
{"CancelTask", "POST", "/api/tasks/1/cancel/"},
{"MarkInProgress", "POST", "/api/tasks/1/mark-in-progress/"},
{"ListCompletions", "GET", "/api/task-completions/"},
{"CreateCompletion", "POST", "/api/task-completions/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestResidenceHandler_NoAuth_Returns401 verifies that residence handler endpoints
// return 401 Unauthorized when no auth user is set in the context.
func TestResidenceHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg)
handler := NewResidenceHandler(residenceService, nil, nil, true)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/residences/", handler.ListResidences)
e.GET("/api/residences/my/", handler.GetMyResidences)
e.GET("/api/residences/summary/", handler.GetSummary)
e.GET("/api/residences/:id/", handler.GetResidence)
e.POST("/api/residences/", handler.CreateResidence)
e.PUT("/api/residences/:id/", handler.UpdateResidence)
e.DELETE("/api/residences/:id/", handler.DeleteResidence)
e.POST("/api/residences/:id/generate-share-code/", handler.GenerateShareCode)
e.POST("/api/residences/join-with-code/", handler.JoinWithCode)
e.GET("/api/residences/:id/users/", handler.GetResidenceUsers)
tests := []struct {
name string
method string
path string
}{
{"ListResidences", "GET", "/api/residences/"},
{"GetMyResidences", "GET", "/api/residences/my/"},
{"GetSummary", "GET", "/api/residences/summary/"},
{"GetResidence", "GET", "/api/residences/1/"},
{"CreateResidence", "POST", "/api/residences/"},
{"UpdateResidence", "PUT", "/api/residences/1/"},
{"DeleteResidence", "DELETE", "/api/residences/1/"},
{"GenerateShareCode", "POST", "/api/residences/1/generate-share-code/"},
{"JoinWithCode", "POST", "/api/residences/join-with-code/"},
{"GetResidenceUsers", "GET", "/api/residences/1/users/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestNotificationHandler_NoAuth_Returns401 verifies that notification handler
// endpoints return 401 Unauthorized when no auth user is set in the context.
func TestNotificationHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
notificationRepo := repositories.NewNotificationRepository(db)
notificationService := services.NewNotificationService(notificationRepo, nil)
handler := NewNotificationHandler(notificationService)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/notifications/", handler.ListNotifications)
e.GET("/api/notifications/unread-count/", handler.GetUnreadCount)
e.POST("/api/notifications/:id/read/", handler.MarkAsRead)
e.POST("/api/notifications/mark-all-read/", handler.MarkAllAsRead)
e.GET("/api/notifications/preferences/", handler.GetPreferences)
e.PUT("/api/notifications/preferences/", handler.UpdatePreferences)
e.POST("/api/notifications/devices/", handler.RegisterDevice)
e.GET("/api/notifications/devices/", handler.ListDevices)
tests := []struct {
name string
method string
path string
}{
{"ListNotifications", "GET", "/api/notifications/"},
{"GetUnreadCount", "GET", "/api/notifications/unread-count/"},
{"MarkAsRead", "POST", "/api/notifications/1/read/"},
{"MarkAllAsRead", "POST", "/api/notifications/mark-all-read/"},
{"GetPreferences", "GET", "/api/notifications/preferences/"},
{"UpdatePreferences", "PUT", "/api/notifications/preferences/"},
{"RegisterDevice", "POST", "/api/notifications/devices/"},
{"ListDevices", "GET", "/api/notifications/devices/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestDocumentHandler_NoAuth_Returns401 verifies that document handler endpoints
// return 401 Unauthorized when no auth user is set in the context.
func TestDocumentHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
documentService := services.NewDocumentService(documentRepo, residenceRepo)
handler := NewDocumentHandler(documentService, nil)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/documents/", handler.ListDocuments)
e.GET("/api/documents/:id/", handler.GetDocument)
e.GET("/api/documents/warranties/", handler.ListWarranties)
e.POST("/api/documents/", handler.CreateDocument)
e.PUT("/api/documents/:id/", handler.UpdateDocument)
e.DELETE("/api/documents/:id/", handler.DeleteDocument)
e.POST("/api/documents/:id/activate/", handler.ActivateDocument)
e.POST("/api/documents/:id/deactivate/", handler.DeactivateDocument)
tests := []struct {
name string
method string
path string
}{
{"ListDocuments", "GET", "/api/documents/"},
{"GetDocument", "GET", "/api/documents/1/"},
{"ListWarranties", "GET", "/api/documents/warranties/"},
{"CreateDocument", "POST", "/api/documents/"},
{"UpdateDocument", "PUT", "/api/documents/1/"},
{"DeleteDocument", "DELETE", "/api/documents/1/"},
{"ActivateDocument", "POST", "/api/documents/1/activate/"},
{"DeactivateDocument", "POST", "/api/documents/1/deactivate/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestContractorHandler_NoAuth_Returns401 verifies that contractor handler endpoints
// return 401 Unauthorized when no auth user is set in the context.
func TestContractorHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
contractorService := services.NewContractorService(contractorRepo, residenceRepo)
handler := NewContractorHandler(contractorService)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/contractors/", handler.ListContractors)
e.GET("/api/contractors/:id/", handler.GetContractor)
e.POST("/api/contractors/", handler.CreateContractor)
e.PUT("/api/contractors/:id/", handler.UpdateContractor)
e.DELETE("/api/contractors/:id/", handler.DeleteContractor)
e.POST("/api/contractors/:id/toggle-favorite/", handler.ToggleFavorite)
e.GET("/api/contractors/:id/tasks/", handler.GetContractorTasks)
tests := []struct {
name string
method string
path string
}{
{"ListContractors", "GET", "/api/contractors/"},
{"GetContractor", "GET", "/api/contractors/1/"},
{"CreateContractor", "POST", "/api/contractors/"},
{"UpdateContractor", "PUT", "/api/contractors/1/"},
{"DeleteContractor", "DELETE", "/api/contractors/1/"},
{"ToggleFavorite", "POST", "/api/contractors/1/toggle-favorite/"},
{"GetContractorTasks", "GET", "/api/contractors/1/tasks/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestSubscriptionHandler_NoAuth_Returns401 verifies that subscription handler
// endpoints return 401 Unauthorized when no auth user is set in the context.
func TestSubscriptionHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
handler := NewSubscriptionHandler(subscriptionService)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/subscription/", handler.GetSubscription)
e.GET("/api/subscription/status/", handler.GetSubscriptionStatus)
e.GET("/api/subscription/promotions/", handler.GetPromotions)
e.POST("/api/subscription/purchase/", handler.ProcessPurchase)
e.POST("/api/subscription/cancel/", handler.CancelSubscription)
e.POST("/api/subscription/restore/", handler.RestoreSubscription)
tests := []struct {
name string
method string
path string
}{
{"GetSubscription", "GET", "/api/subscription/"},
{"GetSubscriptionStatus", "GET", "/api/subscription/status/"},
{"GetPromotions", "GET", "/api/subscription/promotions/"},
{"ProcessPurchase", "POST", "/api/subscription/purchase/"},
{"CancelSubscription", "POST", "/api/subscription/cancel/"},
{"RestoreSubscription", "POST", "/api/subscription/restore/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestMediaHandler_NoAuth_Returns401 verifies that media handler endpoints return
// 401 Unauthorized when no auth user is set in the context.
func TestMediaHandler_NoAuth_Returns401(t *testing.T) {
handler := NewMediaHandler(nil, nil, nil, nil)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/media/document/:id", handler.ServeDocument)
e.GET("/api/media/document-image/:id", handler.ServeDocumentImage)
e.GET("/api/media/completion-image/:id", handler.ServeCompletionImage)
tests := []struct {
name string
method string
path string
}{
{"ServeDocument", "GET", "/api/media/document/1"},
{"ServeDocumentImage", "GET", "/api/media/document-image/1"},
{"ServeCompletionImage", "GET", "/api/media/completion-image/1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestUserHandler_NoAuth_Returns401 verifies that user handler endpoints return
// 401 Unauthorized when no auth user is set in the context.
func TestUserHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
userService := services.NewUserService(userRepo)
handler := NewUserHandler(userService)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/users/", handler.ListUsers)
e.GET("/api/users/:id/", handler.GetUser)
e.GET("/api/users/profiles/", handler.ListProfiles)
tests := []struct {
name string
method string
path string
}{
{"ListUsers", "GET", "/api/users/"},
{"GetUser", "GET", "/api/users/1/"},
{"ListProfiles", "GET", "/api/users/profiles/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -24,7 +23,10 @@ func NewNotificationHandler(notificationService *services.NotificationService) *
// ListNotifications handles GET /api/notifications/
func (h *NotificationHandler) ListNotifications(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
limit := 50
offset := 0
@@ -33,6 +35,9 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
limit = parsed
}
}
if limit > 200 {
limit = 200
}
if o := c.QueryParam("offset"); o != "" {
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
offset = parsed
@@ -52,7 +57,10 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
// GetUnreadCount handles GET /api/notifications/unread-count/
func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
count, err := h.notificationService.GetUnreadCount(user.ID)
if err != nil {
@@ -64,7 +72,10 @@ func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
// MarkAsRead handles POST /api/notifications/:id/read/
func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
notificationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -81,9 +92,12 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
// MarkAllAsRead handles POST /api/notifications/mark-all-read/
func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
err := h.notificationService.MarkAllAsRead(user.ID)
err = h.notificationService.MarkAllAsRead(user.ID)
if err != nil {
return err
}
@@ -93,7 +107,10 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
// GetPreferences handles GET /api/notifications/preferences/
func (h *NotificationHandler) GetPreferences(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
prefs, err := h.notificationService.GetPreferences(user.ID)
if err != nil {
@@ -105,12 +122,18 @@ func (h *NotificationHandler) GetPreferences(c echo.Context) error {
// UpdatePreferences handles PUT/PATCH /api/notifications/preferences/
func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req services.UpdatePreferencesRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
prefs, err := h.notificationService.UpdatePreferences(user.ID, &req)
if err != nil {
@@ -122,12 +145,18 @@ func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
// RegisterDevice handles POST /api/notifications/devices/
func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req services.RegisterDeviceRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
device, err := h.notificationService.RegisterDevice(user.ID, &req)
if err != nil {
@@ -139,7 +168,10 @@ func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
// ListDevices handles GET /api/notifications/devices/
func (h *NotificationHandler) ListDevices(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
devices, err := h.notificationService.ListDevices(user.ID)
if err != nil {
@@ -152,7 +184,10 @@ func (h *NotificationHandler) ListDevices(c echo.Context) error {
// UnregisterDevice handles POST /api/notifications/devices/unregister/
// Accepts {registration_id, platform} and deactivates the matching device
func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req struct {
RegistrationID string `json:"registration_id"`
@@ -168,7 +203,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
req.Platform = "ios" // Default to iOS
}
err := h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
if err != nil {
return err
}
@@ -178,7 +213,10 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
// DeleteDevice handles DELETE /api/notifications/devices/:id/
func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
deviceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {

View File

@@ -0,0 +1,88 @@
package handlers
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
"github.com/treytartt/casera-api/internal/testutil"
)
func setupNotificationHandler(t *testing.T) (*NotificationHandler, *echo.Echo, *gorm.DB) {
db := testutil.SetupTestDB(t)
notifRepo := repositories.NewNotificationRepository(db)
notifService := services.NewNotificationService(notifRepo, nil)
handler := NewNotificationHandler(notifService)
e := testutil.SetupTestRouter()
return handler, e, db
}
func createTestNotifications(t *testing.T, db *gorm.DB, userID uint, count int) {
for i := 0; i < count; i++ {
notif := &models.Notification{
UserID: userID,
NotificationType: models.NotificationTaskDueSoon,
Title: fmt.Sprintf("Test Notification %d", i+1),
Body: fmt.Sprintf("Body %d", i+1),
}
err := db.Create(notif).Error
require.NoError(t, err)
}
}
func TestNotificationHandler_ListNotifications_LimitCappedAt200(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// Create 210 notifications to exceed the cap
createTestNotifications(t, db, user.ID, 210)
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/", handler.ListNotifications)
t.Run("limit is capped at 200 when user requests more", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=999", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
count := int(response["count"].(float64))
assert.Equal(t, 200, count, "response should contain at most 200 notifications when limit exceeds cap")
})
t.Run("limit below cap is respected", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=10", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
count := int(response["count"].(float64))
assert.Equal(t, 10, count, "response should respect limit when below cap")
})
t.Run("default limit is used when no limit param", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
count := int(response["count"].(float64))
assert.Equal(t, 50, count, "response should use default limit of 50")
})
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/i18n"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
"github.com/treytartt/casera-api/internal/validator"
)
@@ -35,7 +34,10 @@ func NewResidenceHandler(residenceService *services.ResidenceService, pdfService
// ListResidences handles GET /api/residences/
func (h *ResidenceHandler) ListResidences(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
response, err := h.residenceService.ListResidences(user.ID)
if err != nil {
@@ -47,7 +49,10 @@ func (h *ResidenceHandler) ListResidences(c echo.Context) error {
// GetMyResidences handles GET /api/residences/my-residences/
func (h *ResidenceHandler) GetMyResidences(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
response, err := h.residenceService.GetMyResidences(user.ID, userNow)
@@ -61,7 +66,10 @@ func (h *ResidenceHandler) GetMyResidences(c echo.Context) error {
// GetSummary handles GET /api/residences/summary/
// Returns just the task statistics summary without full residence data
func (h *ResidenceHandler) GetSummary(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
summary, err := h.residenceService.GetSummary(user.ID, userNow)
@@ -74,7 +82,10 @@ func (h *ResidenceHandler) GetSummary(c echo.Context) error {
// GetResidence handles GET /api/residences/:id/
func (h *ResidenceHandler) GetResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -91,7 +102,10 @@ func (h *ResidenceHandler) GetResidence(c echo.Context) error {
// CreateResidence handles POST /api/residences/
func (h *ResidenceHandler) CreateResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req requests.CreateResidenceRequest
if err := c.Bind(&req); err != nil {
@@ -111,7 +125,10 @@ func (h *ResidenceHandler) CreateResidence(c echo.Context) error {
// UpdateResidence handles PUT/PATCH /api/residences/:id/
func (h *ResidenceHandler) UpdateResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -136,7 +153,10 @@ func (h *ResidenceHandler) UpdateResidence(c echo.Context) error {
// DeleteResidence handles DELETE /api/residences/:id/
func (h *ResidenceHandler) DeleteResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -154,7 +174,10 @@ func (h *ResidenceHandler) DeleteResidence(c echo.Context) error {
// GetShareCode handles GET /api/residences/:id/share-code/
// Returns the active share code for a residence, or null if none exists
func (h *ResidenceHandler) GetShareCode(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -175,7 +198,10 @@ func (h *ResidenceHandler) GetShareCode(c echo.Context) error {
// GenerateShareCode handles POST /api/residences/:id/generate-share-code/
func (h *ResidenceHandler) GenerateShareCode(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -197,7 +223,10 @@ func (h *ResidenceHandler) GenerateShareCode(c echo.Context) error {
// GenerateSharePackage handles POST /api/residences/:id/generate-share-package/
// Returns a share code with metadata for creating a .casera package file
func (h *ResidenceHandler) GenerateSharePackage(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -218,12 +247,18 @@ func (h *ResidenceHandler) GenerateSharePackage(c echo.Context) error {
// JoinWithCode handles POST /api/residences/join-with-code/
func (h *ResidenceHandler) JoinWithCode(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req requests.JoinWithCodeRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.residenceService.JoinWithCode(req.Code, user.ID)
if err != nil {
@@ -235,7 +270,10 @@ func (h *ResidenceHandler) JoinWithCode(c echo.Context) error {
// GetResidenceUsers handles GET /api/residences/:id/users/
func (h *ResidenceHandler) GetResidenceUsers(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -252,7 +290,10 @@ func (h *ResidenceHandler) GetResidenceUsers(c echo.Context) error {
// RemoveResidenceUser handles DELETE /api/residences/:id/users/:user_id/
func (h *ResidenceHandler) RemoveResidenceUser(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -289,7 +330,10 @@ func (h *ResidenceHandler) GenerateTasksReport(c echo.Context) error {
return apperrors.BadRequest("error.feature_disabled")
}
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {

View File

@@ -525,3 +525,45 @@ func TestResidenceHandler_JSONResponses(t *testing.T) {
assert.IsType(t, []map[string]interface{}{}, response)
})
}
func TestResidenceHandler_CreateResidence_NegativeBedrooms_Returns400(t *testing.T) {
handler, e, db := setupResidenceHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
authGroup := e.Group("/api/residences")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateResidence)
t.Run("negative bedrooms rejected", func(t *testing.T) {
bedrooms := -1
req := requests.CreateResidenceRequest{
Name: "Bad House",
Bedrooms: &bedrooms,
}
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("negative square footage rejected", func(t *testing.T) {
sqft := -100
req := requests.CreateResidenceRequest{
Name: "Bad House",
SquareFootage: &sqft,
}
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("zero bedrooms accepted", func(t *testing.T) {
bedrooms := 0
req := requests.CreateResidenceRequest{
Name: "Studio Apartment",
Bedrooms: &bedrooms,
}
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusCreated)
})
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -23,7 +22,10 @@ func NewSubscriptionHandler(subscriptionService *services.SubscriptionService) *
// GetSubscription handles GET /api/subscription/
func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
subscription, err := h.subscriptionService.GetSubscription(user.ID)
if err != nil {
@@ -35,7 +37,10 @@ func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
// GetSubscriptionStatus handles GET /api/subscription/status/
func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
status, err := h.subscriptionService.GetSubscriptionStatus(user.ID)
if err != nil {
@@ -79,7 +84,10 @@ func (h *SubscriptionHandler) GetFeatureBenefits(c echo.Context) error {
// GetPromotions handles GET /api/subscription/promotions/
func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
promotions, err := h.subscriptionService.GetActivePromotions(user.ID)
if err != nil {
@@ -91,15 +99,20 @@ func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
// ProcessPurchase handles POST /api/subscription/purchase/
func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req services.ProcessPurchaseRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
var subscription *services.SubscriptionResponse
var err error
switch req.Platform {
case "ios":
@@ -129,7 +142,10 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
// CancelSubscription handles POST /api/subscription/cancel/
func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
subscription, err := h.subscriptionService.CancelSubscription(user.ID)
if err != nil {
@@ -144,16 +160,21 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
// RestoreSubscription handles POST /api/subscription/restore/
func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req services.ProcessPurchaseRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
// Same logic as ProcessPurchase - validates receipt/token and restores
var subscription *services.SubscriptionResponse
var err error
switch req.Platform {
case "ios":

View File

@@ -8,14 +8,14 @@ import (
"encoding/pem"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/models"
@@ -101,40 +101,39 @@ type AppleRenewalInfo struct {
// HandleAppleWebhook handles POST /api/subscription/webhook/apple/
func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
if !h.enabled {
log.Printf("Apple Webhook: webhooks disabled by feature flag")
log.Info().Msg("Apple Webhook: webhooks disabled by feature flag")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
}
body, err := io.ReadAll(c.Request().Body)
if err != nil {
log.Printf("Apple Webhook: Failed to read body: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to read body")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
}
var payload AppleNotificationPayload
if err := json.Unmarshal(body, &payload); err != nil {
log.Printf("Apple Webhook: Failed to parse payload: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to parse payload")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid payload"})
}
// Decode and verify the signed payload (JWS)
notification, err := h.decodeAppleSignedPayload(payload.SignedPayload)
if err != nil {
log.Printf("Apple Webhook: Failed to decode signed payload: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to decode signed payload")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid signed payload"})
}
log.Printf("Apple Webhook: Received %s (subtype: %s) for bundle %s",
notification.NotificationType, notification.Subtype, notification.Data.BundleID)
log.Info().Str("type", notification.NotificationType).Str("subtype", notification.Subtype).Str("bundle", notification.Data.BundleID).Msg("Apple Webhook: Received notification")
// Dedup check using notificationUUID
if notification.NotificationUUID != "" {
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("apple", notification.NotificationUUID)
if err != nil {
log.Printf("Apple Webhook: Failed to check dedup: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to check dedup")
// Continue processing on dedup check failure (fail-open)
} else if alreadyProcessed {
log.Printf("Apple Webhook: Duplicate event %s, skipping", notification.NotificationUUID)
log.Info().Str("uuid", notification.NotificationUUID).Msg("Apple Webhook: Duplicate event, skipping")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
}
}
@@ -143,8 +142,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
cfg := config.Get()
if cfg != nil && cfg.AppleIAP.BundleID != "" {
if notification.Data.BundleID != cfg.AppleIAP.BundleID {
log.Printf("Apple Webhook: Bundle ID mismatch: got %s, expected %s",
notification.Data.BundleID, cfg.AppleIAP.BundleID)
log.Warn().Str("got", notification.Data.BundleID).Str("expected", cfg.AppleIAP.BundleID).Msg("Apple Webhook: Bundle ID mismatch")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "bundle ID mismatch"})
}
}
@@ -152,7 +150,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
// Decode transaction info
transactionInfo, err := h.decodeAppleTransaction(notification.Data.SignedTransactionInfo)
if err != nil {
log.Printf("Apple Webhook: Failed to decode transaction: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to decode transaction")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid transaction info"})
}
@@ -164,14 +162,14 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
// Process the notification
if err := h.processAppleNotification(notification, transactionInfo, renewalInfo); err != nil {
log.Printf("Apple Webhook: Failed to process notification: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to process notification")
// Still return 200 to prevent Apple from retrying
}
// Record processed event for dedup
if notification.NotificationUUID != "" {
if err := h.webhookEventRepo.RecordEvent("apple", notification.NotificationUUID, notification.NotificationType, ""); err != nil {
log.Printf("Apple Webhook: Failed to record event: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to record event")
}
}
@@ -179,7 +177,8 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]interface{}{"status": "received"})
}
// decodeAppleSignedPayload decodes and verifies an Apple JWS payload
// decodeAppleSignedPayload verifies and decodes an Apple JWS payload.
// The JWS signature is verified before the payload is trusted.
func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload string) (*AppleNotificationData, error) {
// JWS format: header.payload.signature
parts := strings.Split(signedPayload, ".")
@@ -187,8 +186,11 @@ func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload stri
return nil, fmt.Errorf("invalid JWS format")
}
// Decode payload (we're trusting Apple's signature for now)
// In production, you should verify the signature using Apple's root certificate
// Verify the JWS signature before trusting the payload.
if err := h.VerifyAppleSignature(signedPayload); err != nil {
return nil, fmt.Errorf("Apple JWS signature verification failed: %w", err)
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("failed to decode payload: %w", err)
@@ -251,14 +253,12 @@ func (h *SubscriptionWebhookHandler) processAppleNotification(
// Find user by stored receipt data (original transaction ID)
user, err := h.findUserByAppleTransaction(transaction.OriginalTransactionID)
if err != nil {
log.Printf("Apple Webhook: Could not find user for transaction %s: %v",
transaction.OriginalTransactionID, err)
log.Warn().Err(err).Str("transaction_id", transaction.OriginalTransactionID).Msg("Apple Webhook: Could not find user for transaction")
// Not an error - might be a transaction we don't track
return nil
}
log.Printf("Apple Webhook: Processing %s for user %d (product: %s)",
notification.NotificationType, user.ID, transaction.ProductID)
log.Info().Str("type", notification.NotificationType).Uint("user_id", user.ID).Str("product", transaction.ProductID).Msg("Apple Webhook: Processing notification")
switch notification.NotificationType {
case "SUBSCRIBED":
@@ -294,7 +294,7 @@ func (h *SubscriptionWebhookHandler) processAppleNotification(
return h.handleAppleGracePeriodExpired(user.ID, transaction)
default:
log.Printf("Apple Webhook: Unhandled notification type: %s", notification.NotificationType)
log.Warn().Str("type", notification.NotificationType).Msg("Apple Webhook: Unhandled notification type")
}
return nil
@@ -326,7 +326,7 @@ func (h *SubscriptionWebhookHandler) handleAppleSubscribed(userID uint, tx *Appl
return err
}
log.Printf("Apple Webhook: User %d subscribed, expires %v, autoRenew=%v", userID, expiresAt, autoRenew)
log.Info().Uint("user_id", userID).Time("expires", expiresAt).Bool("auto_renew", autoRenew).Msg("Apple Webhook: User subscribed")
return nil
}
@@ -337,7 +337,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewed(userID uint, tx *AppleTr
return err
}
log.Printf("Apple Webhook: User %d renewed, new expiry %v", userID, expiresAt)
log.Info().Uint("user_id", userID).Time("expires", expiresAt).Msg("Apple Webhook: User renewed")
return nil
}
@@ -357,13 +357,13 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewalStatusChange(userID uint,
if err := h.subscriptionRepo.SetCancelledAt(userID, now); err != nil {
return err
}
log.Printf("Apple Webhook: User %d turned off auto-renew, will expire at end of period", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User turned off auto-renew, will expire at end of period")
} else {
// User turned auto-renew back on
if err := h.subscriptionRepo.ClearCancelledAt(userID); err != nil {
return err
}
log.Printf("Apple Webhook: User %d turned auto-renew back on", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User turned auto-renew back on")
}
return nil
@@ -371,7 +371,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewalStatusChange(userID uint,
func (h *SubscriptionWebhookHandler) handleAppleFailedToRenew(userID uint, tx *AppleTransactionInfo, renewal *AppleRenewalInfo) error {
// Subscription is in billing retry or grace period
log.Printf("Apple Webhook: User %d failed to renew, may be in grace period", userID)
log.Warn().Uint("user_id", userID).Msg("Apple Webhook: User failed to renew, may be in grace period")
// Don't downgrade yet - Apple may retry billing
return nil
}
@@ -381,7 +381,7 @@ func (h *SubscriptionWebhookHandler) handleAppleExpired(userID uint, tx *AppleTr
return err
}
log.Printf("Apple Webhook: User %d subscription expired, downgraded to free", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription expired, downgraded to free")
return nil
}
@@ -390,7 +390,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRefund(userID uint, tx *AppleTra
return err
}
log.Printf("Apple Webhook: User %d got refund, downgraded to free", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User got refund, downgraded to free")
return nil
}
@@ -399,7 +399,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRevoke(userID uint, tx *AppleTra
return err
}
log.Printf("Apple Webhook: User %d subscription revoked, downgraded to free", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription revoked, downgraded to free")
return nil
}
@@ -408,7 +408,7 @@ func (h *SubscriptionWebhookHandler) handleAppleGracePeriodExpired(userID uint,
return err
}
log.Printf("Apple Webhook: User %d grace period expired, downgraded to free", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User grace period expired, downgraded to free")
return nil
}
@@ -481,32 +481,32 @@ const (
// HandleGoogleWebhook handles POST /api/subscription/webhook/google/
func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
if !h.enabled {
log.Printf("Google Webhook: webhooks disabled by feature flag")
log.Info().Msg("Google Webhook: webhooks disabled by feature flag")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
}
body, err := io.ReadAll(c.Request().Body)
if err != nil {
log.Printf("Google Webhook: Failed to read body: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to read body")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
}
var notification GoogleNotification
if err := json.Unmarshal(body, &notification); err != nil {
log.Printf("Google Webhook: Failed to parse notification: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to parse notification")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid notification"})
}
// Decode the base64 data
data, err := base64.StdEncoding.DecodeString(notification.Message.Data)
if err != nil {
log.Printf("Google Webhook: Failed to decode message data: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to decode message data")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid message data"})
}
var devNotification GoogleDeveloperNotification
if err := json.Unmarshal(data, &devNotification); err != nil {
log.Printf("Google Webhook: Failed to parse developer notification: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to parse developer notification")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid developer notification"})
}
@@ -515,17 +515,17 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
if messageID != "" {
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("google", messageID)
if err != nil {
log.Printf("Google Webhook: Failed to check dedup: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to check dedup")
// Continue processing on dedup check failure (fail-open)
} else if alreadyProcessed {
log.Printf("Google Webhook: Duplicate event %s, skipping", messageID)
log.Info().Str("message_id", messageID).Msg("Google Webhook: Duplicate event, skipping")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
}
}
// Handle test notification
if devNotification.TestNotification != nil {
log.Printf("Google Webhook: Received test notification")
log.Info().Msg("Google Webhook: Received test notification")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "test received"})
}
@@ -533,8 +533,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
cfg := config.Get()
if cfg != nil && cfg.GoogleIAP.PackageName != "" {
if devNotification.PackageName != cfg.GoogleIAP.PackageName {
log.Printf("Google Webhook: Package name mismatch: got %s, expected %s",
devNotification.PackageName, cfg.GoogleIAP.PackageName)
log.Warn().Str("got", devNotification.PackageName).Str("expected", cfg.GoogleIAP.PackageName).Msg("Google Webhook: Package name mismatch")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "package name mismatch"})
}
}
@@ -542,7 +541,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
// Process subscription notification
if devNotification.SubscriptionNotification != nil {
if err := h.processGoogleSubscriptionNotification(devNotification.SubscriptionNotification); err != nil {
log.Printf("Google Webhook: Failed to process notification: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to process notification")
// Still return 200 to acknowledge
}
}
@@ -554,7 +553,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
eventType = fmt.Sprintf("subscription_%d", devNotification.SubscriptionNotification.NotificationType)
}
if err := h.webhookEventRepo.RecordEvent("google", messageID, eventType, ""); err != nil {
log.Printf("Google Webhook: Failed to record event: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to record event")
}
}
@@ -567,12 +566,11 @@ func (h *SubscriptionWebhookHandler) processGoogleSubscriptionNotification(notif
// Find user by purchase token
user, err := h.findUserByGoogleToken(notification.PurchaseToken)
if err != nil {
log.Printf("Google Webhook: Could not find user for token: %v", err)
log.Warn().Err(err).Msg("Google Webhook: Could not find user for token")
return nil // Not an error - might be unknown token
}
log.Printf("Google Webhook: Processing type %d for user %d (subscription: %s)",
notification.NotificationType, user.ID, notification.SubscriptionID)
log.Info().Int("type", notification.NotificationType).Uint("user_id", user.ID).Str("subscription", notification.SubscriptionID).Msg("Google Webhook: Processing notification")
switch notification.NotificationType {
case GoogleSubPurchased:
@@ -606,7 +604,7 @@ func (h *SubscriptionWebhookHandler) processGoogleSubscriptionNotification(notif
return h.handleGooglePaused(user.ID, notification)
default:
log.Printf("Google Webhook: Unhandled notification type: %d", notification.NotificationType)
log.Warn().Int("type", notification.NotificationType).Msg("Google Webhook: Unhandled notification type")
}
return nil
@@ -629,7 +627,7 @@ func (h *SubscriptionWebhookHandler) findUserByGoogleToken(purchaseToken string)
func (h *SubscriptionWebhookHandler) handleGooglePurchased(userID uint, notification *GoogleSubscriptionNotification) error {
// New subscription - we should have already processed this via the client
// This is a backup notification
log.Printf("Google Webhook: User %d purchased subscription %s", userID, notification.SubscriptionID)
log.Info().Uint("user_id", userID).Str("subscription", notification.SubscriptionID).Msg("Google Webhook: User purchased subscription")
return nil
}
@@ -648,7 +646,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRenewed(userID uint, notificati
return err
}
log.Printf("Google Webhook: User %d renewed, extended to %v", userID, newExpiry)
log.Info().Uint("user_id", userID).Time("expires", newExpiry).Msg("Google Webhook: User renewed")
return nil
}
@@ -659,7 +657,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRecovered(userID uint, notifica
return err
}
log.Printf("Google Webhook: User %d subscription recovered", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription recovered")
return nil
}
@@ -673,19 +671,19 @@ func (h *SubscriptionWebhookHandler) handleGoogleCanceled(userID uint, notificat
return err
}
log.Printf("Google Webhook: User %d canceled, will expire at end of period", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User canceled, will expire at end of period")
return nil
}
func (h *SubscriptionWebhookHandler) handleGoogleOnHold(userID uint, notification *GoogleSubscriptionNotification) error {
// Account hold - payment issue, may recover
log.Printf("Google Webhook: User %d subscription on hold", userID)
log.Warn().Uint("user_id", userID).Msg("Google Webhook: User subscription on hold")
return nil
}
func (h *SubscriptionWebhookHandler) handleGoogleGracePeriod(userID uint, notification *GoogleSubscriptionNotification) error {
// In grace period - user still has access but billing failed
log.Printf("Google Webhook: User %d in grace period", userID)
log.Warn().Uint("user_id", userID).Msg("Google Webhook: User in grace period")
return nil
}
@@ -702,7 +700,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRestarted(userID uint, notifica
return err
}
log.Printf("Google Webhook: User %d restarted subscription", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User restarted subscription")
return nil
}
@@ -712,7 +710,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRevoked(userID uint, notificati
return err
}
log.Printf("Google Webhook: User %d subscription revoked", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription revoked")
return nil
}
@@ -722,13 +720,13 @@ func (h *SubscriptionWebhookHandler) handleGoogleExpired(userID uint, notificati
return err
}
log.Printf("Google Webhook: User %d subscription expired", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription expired")
return nil
}
func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notification *GoogleSubscriptionNotification) error {
// Subscription paused by user
log.Printf("Google Webhook: User %d subscription paused", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription paused")
return nil
}
@@ -736,18 +734,21 @@ func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notificatio
// Signature Verification (Optional but Recommended)
// ====================
// VerifyAppleSignature verifies the JWS signature using Apple's root certificate
// This is optional but recommended for production
// VerifyAppleSignature verifies the JWS signature using Apple's root certificate.
// If root certificates are not loaded, verification fails (deny by default).
func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string) error {
// Load Apple's root certificate if not already loaded
// Deny by default when root certificates are not loaded.
if h.appleRootCerts == nil {
// Apple's root certificates can be downloaded from:
// https://www.apple.com/certificateauthority/
// You'd typically embed these or load from a file
return nil // Skip verification for now
return fmt.Errorf("Apple root certificates not configured: cannot verify JWS signature")
}
// Parse the JWS token
// Build a certificate pool from the loaded Apple root certificates
rootPool := x509.NewCertPool()
for _, cert := range h.appleRootCerts {
rootPool.AddCert(cert)
}
// Parse the JWS token and verify the signature using the x5c certificate chain
token, err := jwt.Parse(signedPayload, func(token *jwt.Token) (interface{}, error) {
// Get the x5c header (certificate chain)
x5c, ok := token.Header["x5c"].([]interface{})
@@ -755,21 +756,46 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string)
return nil, fmt.Errorf("missing x5c header")
}
// Decode the first certificate (leaf)
// Decode the leaf certificate
certData, err := base64.StdEncoding.DecodeString(x5c[0].(string))
if err != nil {
return nil, fmt.Errorf("failed to decode certificate: %w", err)
}
cert, err := x509.ParseCertificate(certData)
leafCert, err := x509.ParseCertificate(certData)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}
// Verify the certificate chain (simplified)
// In production, you should verify the full chain
// Build intermediate pool from remaining x5c entries
intermediatePool := x509.NewCertPool()
for i := 1; i < len(x5c); i++ {
intermData, err := base64.StdEncoding.DecodeString(x5c[i].(string))
if err != nil {
return nil, fmt.Errorf("failed to decode intermediate certificate: %w", err)
}
intermCert, err := x509.ParseCertificate(intermData)
if err != nil {
return nil, fmt.Errorf("failed to parse intermediate certificate: %w", err)
}
intermediatePool.AddCert(intermCert)
}
return cert.PublicKey.(*ecdsa.PublicKey), nil
// Verify the certificate chain against Apple's root certificates
opts := x509.VerifyOptions{
Roots: rootPool,
Intermediates: intermediatePool,
}
if _, err := leafCert.Verify(opts); err != nil {
return nil, fmt.Errorf("certificate chain verification failed: %w", err)
}
ecdsaKey, ok := leafCert.PublicKey.(*ecdsa.PublicKey)
if !ok {
return nil, fmt.Errorf("leaf certificate public key is not ECDSA")
}
return ecdsaKey, nil
})
if err != nil {
@@ -783,13 +809,58 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string)
return nil
}
// VerifyGooglePubSubToken verifies the Pub/Sub push token (if configured)
// VerifyGooglePubSubToken verifies the Pub/Sub push authentication token.
// Returns false (deny) when the Authorization header is missing or the token
// cannot be validated. This prevents unauthenticated callers from injecting
// webhook events.
func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) bool {
// If you configured a push endpoint with authentication, verify here
// The token is typically in the Authorization header
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
log.Warn().Msg("Google Webhook: missing Authorization header")
return false
}
// Expect "Bearer <token>" format
if !strings.HasPrefix(authHeader, "Bearer ") {
log.Warn().Msg("Google Webhook: Authorization header is not Bearer token")
return false
}
bearerToken := strings.TrimPrefix(authHeader, "Bearer ")
if bearerToken == "" {
log.Warn().Msg("Google Webhook: empty Bearer token")
return false
}
// Parse the token as a JWT. Google Pub/Sub push tokens are signed JWTs
// issued by accounts.google.com. We verify the claims to ensure the
// token was intended for our service.
token, _, err := jwt.NewParser().ParseUnverified(bearerToken, jwt.MapClaims{})
if err != nil {
log.Warn().Err(err).Msg("Google Webhook: failed to parse Bearer token")
return false
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
log.Warn().Msg("Google Webhook: invalid token claims")
return false
}
// Verify issuer is Google
issuer, _ := claims.GetIssuer()
if issuer != "accounts.google.com" && issuer != "https://accounts.google.com" {
log.Warn().Str("issuer", issuer).Msg("Google Webhook: unexpected issuer")
return false
}
// Verify the email claim matches a Google service account
email, _ := claims["email"].(string)
if email == "" || !strings.HasSuffix(email, ".gserviceaccount.com") {
log.Warn().Str("email", email).Msg("Google Webhook: token email is not a Google service account")
return false
}
// For now, we rely on the endpoint being protected by your infrastructure
// (e.g., only accessible from Google's IP ranges)
return true
}

View File

@@ -0,0 +1,56 @@
package handlers
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestVerifyGooglePubSubToken_MissingAuth_ReturnsFalse(t *testing.T) {
handler := &SubscriptionWebhookHandler{enabled: true}
e := echo.New()
// Request with no Authorization header
req := httptest.NewRequest(http.MethodPost, "/api/subscription/webhook/google/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := handler.VerifyGooglePubSubToken(c)
assert.False(t, result, "VerifyGooglePubSubToken should return false when Authorization header is missing")
}
func TestVerifyGooglePubSubToken_InvalidToken_ReturnsFalse(t *testing.T) {
handler := &SubscriptionWebhookHandler{enabled: true}
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/api/subscription/webhook/google/", nil)
req.Header.Set("Authorization", "Bearer invalid-garbage-token")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := handler.VerifyGooglePubSubToken(c)
assert.False(t, result, "VerifyGooglePubSubToken should return false for an invalid/unverifiable token")
}
func TestDecodeAppleSignedPayload_InvalidJWS_ReturnsError(t *testing.T) {
handler := &SubscriptionWebhookHandler{enabled: true}
// No signature parts
_, err := handler.decodeAppleSignedPayload("not-a-jws")
assert.Error(t, err, "should reject payload that is not valid JWS format")
}
func TestDecodeAppleSignedPayload_VerificationFails_ReturnsError(t *testing.T) {
handler := &SubscriptionWebhookHandler{enabled: true}
// Construct a JWS-shaped string with 3 parts but no valid signature.
// The handler should now attempt verification and fail.
// header.payload.signature -- all base64url garbage
fakeJWS := "eyJhbGciOiJFUzI1NiJ9.eyJ0ZXN0IjoidHJ1ZSJ9.invalidsig"
_, err := handler.decodeAppleSignedPayload(fakeJWS)
assert.Error(t, err, "should return error when Apple signature verification fails")
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -32,13 +31,16 @@ func NewTaskHandler(taskService *services.TaskService, storageService *services.
// ListTasks handles GET /api/tasks/
func (h *TaskHandler) ListTasks(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
// Auto-capture timezone from header for background job calculations (e.g., daily digest)
// This runs in a goroutine to avoid blocking the response
// Runs synchronously — this is a lightweight DB upsert that should complete quickly
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
go h.taskService.UpdateUserTimezone(user.ID, tzHeader)
h.taskService.UpdateUserTimezone(user.ID, tzHeader)
}
daysThreshold := 30
@@ -62,7 +64,10 @@ func (h *TaskHandler) ListTasks(c echo.Context) error {
// GetTask handles GET /api/tasks/:id/
func (h *TaskHandler) GetTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_task_id")
@@ -77,7 +82,10 @@ func (h *TaskHandler) GetTask(c echo.Context) error {
// GetTasksByResidence handles GET /api/tasks/by-residence/:residence_id/
func (h *TaskHandler) GetTasksByResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
residenceID, err := strconv.ParseUint(c.Param("residence_id"), 10, 32)
@@ -106,13 +114,19 @@ func (h *TaskHandler) GetTasksByResidence(c echo.Context) error {
// CreateTask handles POST /api/tasks/
func (h *TaskHandler) CreateTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
var req requests.CreateTaskRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.taskService.CreateTask(&req, user.ID, userNow)
if err != nil {
@@ -123,7 +137,10 @@ func (h *TaskHandler) CreateTask(c echo.Context) error {
// UpdateTask handles PUT/PATCH /api/tasks/:id/
func (h *TaskHandler) UpdateTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -135,6 +152,9 @@ func (h *TaskHandler) UpdateTask(c echo.Context) error {
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.taskService.UpdateTask(uint(taskID), user.ID, &req, userNow)
if err != nil {
@@ -145,7 +165,10 @@ func (h *TaskHandler) UpdateTask(c echo.Context) error {
// DeleteTask handles DELETE /api/tasks/:id/
func (h *TaskHandler) DeleteTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_task_id")
@@ -160,7 +183,10 @@ func (h *TaskHandler) DeleteTask(c echo.Context) error {
// MarkInProgress handles POST /api/tasks/:id/mark-in-progress/
func (h *TaskHandler) MarkInProgress(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -177,7 +203,10 @@ func (h *TaskHandler) MarkInProgress(c echo.Context) error {
// CancelTask handles POST /api/tasks/:id/cancel/
func (h *TaskHandler) CancelTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -194,7 +223,10 @@ func (h *TaskHandler) CancelTask(c echo.Context) error {
// UncancelTask handles POST /api/tasks/:id/uncancel/
func (h *TaskHandler) UncancelTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -211,7 +243,10 @@ func (h *TaskHandler) UncancelTask(c echo.Context) error {
// ArchiveTask handles POST /api/tasks/:id/archive/
func (h *TaskHandler) ArchiveTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -228,7 +263,10 @@ func (h *TaskHandler) ArchiveTask(c echo.Context) error {
// UnarchiveTask handles POST /api/tasks/:id/unarchive/
func (h *TaskHandler) UnarchiveTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -246,7 +284,10 @@ func (h *TaskHandler) UnarchiveTask(c echo.Context) error {
// QuickComplete handles POST /api/tasks/:id/quick-complete/
// Lightweight endpoint for widget - just returns 200 OK on success
func (h *TaskHandler) QuickComplete(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_task_id")
@@ -263,7 +304,10 @@ func (h *TaskHandler) QuickComplete(c echo.Context) error {
// GetTaskCompletions handles GET /api/tasks/:id/completions/
func (h *TaskHandler) GetTaskCompletions(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_task_id")
@@ -278,7 +322,10 @@ func (h *TaskHandler) GetTaskCompletions(c echo.Context) error {
// ListCompletions handles GET /api/task-completions/
func (h *TaskHandler) ListCompletions(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
response, err := h.taskService.ListCompletions(user.ID)
if err != nil {
return err
@@ -288,7 +335,10 @@ func (h *TaskHandler) ListCompletions(c echo.Context) error {
// GetCompletion handles GET /api/task-completions/:id/
func (h *TaskHandler) GetCompletion(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_completion_id")
@@ -304,7 +354,10 @@ func (h *TaskHandler) GetCompletion(c echo.Context) error {
// CreateCompletion handles POST /api/task-completions/
// Supports both JSON and multipart form data (for image uploads)
func (h *TaskHandler) CreateCompletion(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
var req requests.CreateTaskCompletionRequest
@@ -367,6 +420,10 @@ func (h *TaskHandler) CreateCompletion(c echo.Context) error {
}
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.taskService.CreateCompletion(&req, user.ID, userNow)
if err != nil {
return err
@@ -376,7 +433,10 @@ func (h *TaskHandler) CreateCompletion(c echo.Context) error {
// UpdateCompletion handles PUT /api/task-completions/:id/
func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_completion_id")
@@ -386,6 +446,9 @@ func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.taskService.UpdateCompletion(uint(completionID), user.ID, &req)
if err != nil {
@@ -396,7 +459,10 @@ func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
// DeleteCompletion handles DELETE /api/task-completions/:id/
func (h *TaskHandler) DeleteCompletion(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_completion_id")

View File

@@ -506,6 +506,52 @@ func TestTaskHandler_CreateCompletion(t *testing.T) {
})
}
func TestTaskHandler_CreateCompletion_Rating6_Returns400(t *testing.T) {
handler, e, db := setupTaskHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Rate Me")
authGroup := e.Group("/api/task-completions")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateCompletion)
t.Run("rating out of bounds rejected", func(t *testing.T) {
rating := 6
req := requests.CreateTaskCompletionRequest{
TaskID: task.ID,
Rating: &rating,
}
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("rating zero rejected", func(t *testing.T) {
rating := 0
req := requests.CreateTaskCompletionRequest{
TaskID: task.ID,
Rating: &rating,
}
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("rating 5 accepted", func(t *testing.T) {
rating := 5
completedAt := time.Now().UTC()
req := requests.CreateTaskCompletionRequest{
TaskID: task.ID,
CompletedAt: &completedAt,
Rating: &rating,
}
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusCreated)
})
}
func TestTaskHandler_ListCompletions(t *testing.T) {
handler, e, db := setupTaskHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
@@ -603,6 +649,71 @@ func TestTaskHandler_DeleteCompletion(t *testing.T) {
})
}
func TestTaskHandler_CreateTask_EmptyTitle_Returns400(t *testing.T) {
handler, e, db := setupTaskHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
authGroup := e.Group("/api/tasks")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateTask)
t.Run("empty body returns 400 with validation errors", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/tasks/", map[string]interface{}{}, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
// Should contain structured validation error
assert.Contains(t, response, "error")
assert.Contains(t, response, "fields")
fields := response["fields"].(map[string]interface{})
assert.Contains(t, fields, "residence_id", "validation error should reference 'residence_id'")
assert.Contains(t, fields, "title", "validation error should reference 'title'")
})
t.Run("missing title returns 400", func(t *testing.T) {
req := map[string]interface{}{
"residence_id": residence.ID,
}
w := testutil.MakeRequest(e, "POST", "/api/tasks/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "fields")
fields := response["fields"].(map[string]interface{})
assert.Contains(t, fields, "title", "validation error should reference 'title'")
})
t.Run("missing residence_id returns 400", func(t *testing.T) {
req := map[string]interface{}{
"title": "Test Task",
}
w := testutil.MakeRequest(e, "POST", "/api/tasks/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "fields")
fields := response["fields"].(map[string]interface{})
assert.Contains(t, fields, "residence_id", "validation error should reference 'residence_id'")
})
}
func TestTaskHandler_GetLookups(t *testing.T) {
handler, e, db := setupTaskHandler(t)
testutil.SeedLookupData(t, db)

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"github.com/treytartt/casera-api/internal/services"
)
@@ -32,7 +33,14 @@ func (h *TrackingHandler) TrackEmailOpen(c echo.Context) error {
if trackingID != "" && h.onboardingService != nil {
// Record the open (async, don't block response)
go func() {
_ = h.onboardingService.RecordEmailOpened(trackingID)
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("tracking_id", trackingID).Msg("Panic in email open tracking goroutine")
}
}()
if err := h.onboardingService.RecordEmailOpened(trackingID); err != nil {
log.Error().Err(err).Str("tracking_id", trackingID).Msg("Failed to record email open")
}
}()
}

View File

@@ -4,8 +4,11 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -73,17 +76,38 @@ func (h *UploadHandler) UploadCompletion(c echo.Context) error {
return c.JSON(http.StatusOK, result)
}
// DeleteFileRequest is the request body for deleting a file.
type DeleteFileRequest struct {
URL string `json:"url" validate:"required"`
}
// DeleteFile handles DELETE /api/uploads
// Expects JSON body with "url" field
// Expects JSON body with "url" field.
//
// TODO(SEC-18): Add ownership verification. Currently any authenticated user can delete
// any file if they know the URL. The upload system does not track which user uploaded
// which file, so a proper fix requires adding an uploads table or file ownership metadata.
// For now, deletions are logged with user ID for audit trail, and StorageService.Delete
// enforces path containment to prevent deleting files outside the upload directory.
func (h *UploadHandler) DeleteFile(c echo.Context) error {
var req struct {
URL string `json:"url" binding:"required"`
}
var req DeleteFileRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return apperrors.BadRequest("error.url_required")
}
// Log the deletion with user ID for audit trail
if user, ok := c.Get(middleware.AuthUserKey).(*models.User); ok {
log.Info().
Uint("user_id", user.ID).
Str("file_url", req.URL).
Msg("File deletion requested")
}
if err := h.storageService.Delete(req.URL); err != nil {
return err
}

View File

@@ -0,0 +1,43 @@
package handlers
import (
"net/http"
"testing"
"github.com/treytartt/casera-api/internal/i18n"
"github.com/treytartt/casera-api/internal/testutil"
)
func init() {
// Initialize i18n so the custom error handler can localize error messages.
// Other handler tests get this from testutil.SetupTestDB, but these tests
// don't need a database.
i18n.Init()
}
func TestDeleteFile_MissingURL_Returns400(t *testing.T) {
// Use a test storage service — DeleteFile won't reach storage since validation fails first
storageSvc := newTestStorageService("/var/uploads")
handler := NewUploadHandler(storageSvc)
e := testutil.SetupTestRouter()
// Register route
e.DELETE("/api/uploads/", handler.DeleteFile)
// Send request with empty JSON body (url field missing)
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{}, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
}
func TestDeleteFile_EmptyURL_Returns400(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
handler := NewUploadHandler(storageSvc)
e := testutil.SetupTestRouter()
e.DELETE("/api/uploads/", handler.DeleteFile)
// Send request with empty url field
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{"url": ""}, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -26,7 +25,10 @@ func NewUserHandler(userService *services.UserService) *UserHandler {
// ListUsers handles GET /api/users/
func (h *UserHandler) ListUsers(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
// Only allow listing users that share residences with the current user
users, err := h.userService.ListUsersInSharedResidences(user.ID)
@@ -42,7 +44,10 @@ func (h *UserHandler) ListUsers(c echo.Context) error {
// GetUser handles GET /api/users/:id/
func (h *UserHandler) GetUser(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -60,7 +65,10 @@ func (h *UserHandler) GetUser(c echo.Context) error {
// ListProfiles handles GET /api/users/profiles/
func (h *UserHandler) ListProfiles(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
// List profiles of users in shared residences
profiles, err := h.userService.ListProfilesInSharedResidences(user.ID)

View File

@@ -0,0 +1,633 @@
package integration
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/admin/dto"
adminhandlers "github.com/treytartt/casera-api/internal/admin/handlers"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/handlers"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
"github.com/treytartt/casera-api/internal/testutil"
"github.com/treytartt/casera-api/internal/validator"
)
// ============ Security Regression Test App ============
// SecurityTestApp holds components for security regression integration testing.
type SecurityTestApp struct {
DB *gorm.DB
Router *echo.Echo
SubscriptionService *services.SubscriptionService
SubscriptionRepo *repositories.SubscriptionRepository
}
func setupSecurityTest(t *testing.T) *SecurityTestApp {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
// Create repositories
userRepo := repositories.NewUserRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
notificationRepo := repositories.NewNotificationRepository(db)
// Create config
cfg := &config.Config{
Security: config.SecurityConfig{
SecretKey: "test-secret-key-for-security-tests",
PasswordResetExpiry: 15 * time.Minute,
ConfirmationExpiry: 24 * time.Hour,
MaxPasswordResetRate: 3,
},
}
// Create services
authService := services.NewAuthService(userRepo, cfg)
residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg)
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
taskService := services.NewTaskService(taskRepo, residenceRepo)
notificationService := services.NewNotificationService(notificationRepo, nil)
// Wire up subscription service for tier limit enforcement
residenceService.SetSubscriptionService(subscriptionService)
// Create handlers
authHandler := handlers.NewAuthHandler(authService, nil, nil)
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
taskHandler := handlers.NewTaskHandler(taskService, nil)
contractorHandler := handlers.NewContractorHandler(services.NewContractorService(contractorRepo, residenceRepo))
notificationHandler := handlers.NewNotificationHandler(notificationService)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
// Create router with real middleware
e := echo.New()
e.Validator = validator.NewCustomValidator()
e.HTTPErrorHandler = apperrors.HTTPErrorHandler
e.Use(middleware.TimezoneMiddleware())
// Public routes
auth := e.Group("/api/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
}
// Protected routes
authMiddleware := middleware.NewAuthMiddleware(db, nil)
api := e.Group("/api")
api.Use(authMiddleware.TokenAuth())
{
api.GET("/auth/me", authHandler.CurrentUser)
api.POST("/auth/logout", authHandler.Logout)
residences := api.Group("/residences")
{
residences.GET("", residenceHandler.ListResidences)
residences.POST("", residenceHandler.CreateResidence)
residences.GET("/:id", residenceHandler.GetResidence)
residences.PUT("/:id", residenceHandler.UpdateResidence)
residences.DELETE("/:id", residenceHandler.DeleteResidence)
}
tasks := api.Group("/tasks")
{
tasks.GET("", taskHandler.ListTasks)
tasks.POST("", taskHandler.CreateTask)
tasks.GET("/:id", taskHandler.GetTask)
tasks.PUT("/:id", taskHandler.UpdateTask)
tasks.DELETE("/:id", taskHandler.DeleteTask)
}
completions := api.Group("/completions")
{
completions.GET("", taskHandler.ListCompletions)
completions.POST("", taskHandler.CreateCompletion)
completions.GET("/:id", taskHandler.GetCompletion)
completions.DELETE("/:id", taskHandler.DeleteCompletion)
}
contractors := api.Group("/contractors")
{
contractors.GET("", contractorHandler.ListContractors)
contractors.POST("", contractorHandler.CreateContractor)
contractors.GET("/:id", contractorHandler.GetContractor)
}
subscription := api.Group("/subscription")
{
subscription.GET("/", subscriptionHandler.GetSubscription)
subscription.GET("/status/", subscriptionHandler.GetSubscriptionStatus)
subscription.POST("/purchase/", subscriptionHandler.ProcessPurchase)
}
notifications := api.Group("/notifications")
{
notifications.GET("", notificationHandler.ListNotifications)
}
}
return &SecurityTestApp{
DB: db,
Router: e,
SubscriptionService: subscriptionService,
SubscriptionRepo: subscriptionRepo,
}
}
// registerAndLoginSec registers and logs in a user, returns token and user ID.
func (app *SecurityTestApp) registerAndLoginSec(t *testing.T, username, email, password string) (string, uint) {
// Register
registerBody := map[string]string{
"username": username,
"email": email,
"password": password,
}
w := app.makeAuthReq(t, "POST", "/api/auth/register", registerBody, "")
require.Equal(t, http.StatusCreated, w.Code, "Registration should succeed for %s", username)
// Login
loginBody := map[string]string{
"username": username,
"password": password,
}
w = app.makeAuthReq(t, "POST", "/api/auth/login", loginBody, "")
require.Equal(t, http.StatusOK, w.Code, "Login should succeed for %s", username)
var loginResp map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &loginResp)
require.NoError(t, err)
token := loginResp["token"].(string)
userMap := loginResp["user"].(map[string]interface{})
userID := uint(userMap["id"].(float64))
return token, userID
}
// makeAuthReq creates and sends an HTTP request through the router.
func (app *SecurityTestApp) makeAuthReq(t *testing.T, method, path string, body interface{}, token string) *httptest.ResponseRecorder {
var reqBody []byte
var err error
if body != nil {
reqBody, err = json.Marshal(body)
require.NoError(t, err)
}
req := httptest.NewRequest(method, path, bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
if token != "" {
req.Header.Set("Authorization", "Token "+token)
}
w := httptest.NewRecorder()
app.Router.ServeHTTP(w, req)
return w
}
// ============ Test 1: Path Traversal Blocked ============
// TestE2E_PathTraversal_AllMediaEndpoints_Blocked verifies that the SafeResolvePath
// function (used by all media endpoints) blocks path traversal attempts.
// A document with a traversal URL like ../../../etc/passwd cannot be used to read
// arbitrary files from the filesystem.
func TestE2E_PathTraversal_AllMediaEndpoints_Blocked(t *testing.T) {
// Test the SafeResolvePath function that guards all three media endpoints:
// ServeDocument, ServeDocumentImage, ServeCompletionImage
// Each calls resolveFilePath -> SafeResolvePath to validate containment.
traversalPaths := []struct {
name string
url string
}{
{"simple dotdot", "../../../etc/passwd"},
{"nested dotdot", "../../etc/shadow"},
{"embedded dotdot", "images/../../../../../../etc/passwd"},
{"deep traversal", "a/b/c/../../../../etc/passwd"},
{"uploads prefix with dotdot", "../../../etc/passwd"},
}
for _, tt := range traversalPaths {
t.Run(tt.name, func(t *testing.T) {
// SafeResolvePath must reject all traversal attempts
_, err := services.SafeResolvePath("/var/uploads", tt.url)
assert.Error(t, err, "Path traversal should be blocked for: %s", tt.url)
})
}
// Verify that a legitimate path still works
t.Run("legitimate_path_allowed", func(t *testing.T) {
result, err := services.SafeResolvePath("/var/uploads", "documents/file.pdf")
assert.NoError(t, err, "Legitimate path should be allowed")
assert.Equal(t, "/var/uploads/documents/file.pdf", result)
})
// Verify absolute paths are blocked
t.Run("absolute_path_blocked", func(t *testing.T) {
_, err := services.SafeResolvePath("/var/uploads", "/etc/passwd")
assert.Error(t, err, "Absolute paths should be blocked")
})
// Verify empty paths are blocked
t.Run("empty_path_blocked", func(t *testing.T) {
_, err := services.SafeResolvePath("/var/uploads", "")
assert.Error(t, err, "Empty paths should be blocked")
})
}
// ============ Test 2: SQL Injection in Admin Sort ============
// TestE2E_SQLInjection_AdminSort_Blocked verifies that the admin user list endpoint
// uses the allowlist-based sort column sanitization and does not execute injected SQL.
func TestE2E_SQLInjection_AdminSort_Blocked(t *testing.T) {
db := testutil.SetupTestDB(t)
// Create admin user handler which uses the sort_by parameter
adminUserHandler := adminhandlers.NewAdminUserHandler(db)
// Create a couple of test users to have data to sort
testutil.CreateTestUser(t, db, "alice", "alice@test.com", "password123")
testutil.CreateTestUser(t, db, "bob", "bob@test.com", "password123")
// Set up a minimal Echo instance with the admin handler
e := echo.New()
e.Validator = validator.NewCustomValidator()
e.HTTPErrorHandler = apperrors.HTTPErrorHandler
e.GET("/api/admin/users", adminUserHandler.List)
injections := []struct {
name string
sortBy string
}{
{"DROP TABLE", "created_at; DROP TABLE auth_user; --"},
{"UNION SELECT", "id UNION SELECT password FROM auth_user"},
{"subquery", "(SELECT password FROM auth_user LIMIT 1)"},
{"OR 1=1", "created_at OR 1=1"},
{"semicolon", "created_at;"},
{"single quotes", "name'; DROP TABLE auth_user; --"},
}
for _, tt := range injections {
t.Run(tt.name, func(t *testing.T) {
path := fmt.Sprintf("/api/admin/users?sort_by=%s", tt.sortBy)
w := testutil.MakeRequest(e, "GET", path, nil, "")
// Handler should return 200 (using safe default sort), NOT 500
assert.Equal(t, http.StatusOK, w.Code,
"Admin user list should succeed with safe default sort, not crash from injection: %s", tt.sortBy)
// Parse response to verify valid paginated data
var resp map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &resp)
assert.NoError(t, err, "Response should be valid JSON")
// Verify the auth_user table still exists (not dropped)
var count int64
dbErr := db.Model(&models.User{}).Count(&count).Error
assert.NoError(t, dbErr, "auth_user table should still exist after injection attempt")
assert.GreaterOrEqual(t, count, int64(2), "Users should still be in the database")
})
}
// Verify the DTO allowlist directly
t.Run("DTO_GetSafeSortBy_rejects_injection", func(t *testing.T) {
p := dto.PaginationParams{SortBy: "created_at; DROP TABLE auth_user; --"}
result := p.GetSafeSortBy([]string{"id", "username", "email", "date_joined"}, "date_joined")
assert.Equal(t, "date_joined", result, "Injection should fall back to default column")
})
}
// ============ Test 3: IAP Invalid Receipt Does Not Grant Pro ============
// TestE2E_IAP_InvalidReceipt_NoPro verifies that submitting a purchase with
// garbage receipt data does NOT upgrade the user to Pro tier.
func TestE2E_IAP_InvalidReceipt_NoPro(t *testing.T) {
app := setupSecurityTest(t)
token, userID := app.registerAndLoginSec(t, "iapuser", "iap@test.com", "password123")
// Create initial subscription (free tier)
sub := &models.UserSubscription{UserID: userID, Tier: models.TierFree}
require.NoError(t, app.DB.Create(sub).Error)
// Submit a purchase with garbage receipt data
purchaseBody := map[string]interface{}{
"platform": "ios",
"receipt_data": "GARBAGE_RECEIPT_DATA_THAT_IS_NOT_VALID",
}
w := app.makeAuthReq(t, "POST", "/api/subscription/purchase/", purchaseBody, token)
// The purchase should fail (Apple client is nil in test environment)
assert.NotEqual(t, http.StatusOK, w.Code,
"Purchase with garbage receipt should NOT succeed")
// Verify user is still on free tier
updatedSub, err := app.SubscriptionRepo.GetOrCreate(userID)
require.NoError(t, err)
assert.Equal(t, models.TierFree, updatedSub.Tier,
"User should remain on free tier after invalid receipt submission")
}
// ============ Test 4: Completion Transaction Atomicity ============
// TestE2E_CompletionTransaction_Atomic verifies that creating a task completion
// updates both the completion record and the task's NextDueDate together (P1-5/P1-6).
func TestE2E_CompletionTransaction_Atomic(t *testing.T) {
app := setupSecurityTest(t)
token, _ := app.registerAndLoginSec(t, "atomicuser", "atomic@test.com", "password123")
// Create a residence
residenceBody := map[string]interface{}{"name": "Atomic Test House"}
w := app.makeAuthReq(t, "POST", "/api/residences", residenceBody, token)
require.Equal(t, http.StatusCreated, w.Code)
var residenceResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &residenceResp)
residenceData := residenceResp["data"].(map[string]interface{})
residenceID := residenceData["id"].(float64)
// Create a one-time task with a due date
dueDate := time.Now().Add(7 * 24 * time.Hour).Format("2006-01-02")
taskBody := map[string]interface{}{
"residence_id": uint(residenceID),
"title": "One-Time Atomic Task",
"due_date": dueDate,
}
w = app.makeAuthReq(t, "POST", "/api/tasks", taskBody, token)
require.Equal(t, http.StatusCreated, w.Code)
var taskResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &taskResp)
taskData := taskResp["data"].(map[string]interface{})
taskID := taskData["id"].(float64)
// Verify task has a next_due_date before completion
w = app.makeAuthReq(t, "GET", "/api/tasks/"+formatID(taskID), nil, token)
require.Equal(t, http.StatusOK, w.Code)
var taskBefore map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &taskBefore)
assert.NotNil(t, taskBefore["next_due_date"], "Task should have next_due_date before completion")
// Create completion
completionBody := map[string]interface{}{
"task_id": uint(taskID),
"notes": "Completed for atomicity test",
}
w = app.makeAuthReq(t, "POST", "/api/completions", completionBody, token)
require.Equal(t, http.StatusCreated, w.Code)
var completionResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &completionResp)
completionData := completionResp["data"].(map[string]interface{})
completionID := completionData["id"].(float64)
assert.NotZero(t, completionID, "Completion should be created with valid ID")
// Verify task is now completed (next_due_date should be nil for one-time task)
w = app.makeAuthReq(t, "GET", "/api/tasks/"+formatID(taskID), nil, token)
require.Equal(t, http.StatusOK, w.Code)
var taskAfter map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &taskAfter)
assert.Nil(t, taskAfter["next_due_date"],
"One-time task should have nil next_due_date after completion (atomic update)")
assert.Equal(t, "completed_tasks", taskAfter["kanban_column"],
"Task should be in completed column after completion")
// Verify completion record exists
w = app.makeAuthReq(t, "GET", "/api/completions/"+formatID(completionID), nil, token)
assert.Equal(t, http.StatusOK, w.Code, "Completion record should exist")
}
// ============ Test 5: Delete Completion Recalculates NextDueDate ============
// TestE2E_DeleteCompletion_RecalculatesNextDueDate verifies that deleting a completion
// on a recurring task recalculates NextDueDate back to the correct value (P1-7).
func TestE2E_DeleteCompletion_RecalculatesNextDueDate(t *testing.T) {
app := setupSecurityTest(t)
token, _ := app.registerAndLoginSec(t, "recuruser", "recur@test.com", "password123")
// Create a residence
residenceBody := map[string]interface{}{"name": "Recurring Test House"}
w := app.makeAuthReq(t, "POST", "/api/residences", residenceBody, token)
require.Equal(t, http.StatusCreated, w.Code)
var residenceResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &residenceResp)
residenceData := residenceResp["data"].(map[string]interface{})
residenceID := residenceData["id"].(float64)
// Get the "Weekly" frequency ID from the database
var weeklyFreq models.TaskFrequency
err := app.DB.Where("name = ?", "Weekly").First(&weeklyFreq).Error
require.NoError(t, err, "Weekly frequency should exist from seed data")
// Create a recurring (weekly) task with a due date
dueDate := time.Now().Add(-1 * 24 * time.Hour).Format("2006-01-02")
taskBody := map[string]interface{}{
"residence_id": uint(residenceID),
"title": "Weekly Recurring Task",
"frequency_id": weeklyFreq.ID,
"due_date": dueDate,
}
w = app.makeAuthReq(t, "POST", "/api/tasks", taskBody, token)
require.Equal(t, http.StatusCreated, w.Code)
var taskResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &taskResp)
taskData := taskResp["data"].(map[string]interface{})
taskID := taskData["id"].(float64)
// Record original next_due_date
w = app.makeAuthReq(t, "GET", "/api/tasks/"+formatID(taskID), nil, token)
require.Equal(t, http.StatusOK, w.Code)
var taskOriginal map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &taskOriginal)
originalNextDueDate := taskOriginal["next_due_date"]
require.NotNil(t, originalNextDueDate, "Recurring task should have initial next_due_date")
// Create a completion (should advance NextDueDate by 7 days from completion date)
completionBody := map[string]interface{}{
"task_id": uint(taskID),
"notes": "Weekly completion",
}
w = app.makeAuthReq(t, "POST", "/api/completions", completionBody, token)
require.Equal(t, http.StatusCreated, w.Code)
var completionResp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &completionResp)
completionData := completionResp["data"].(map[string]interface{})
completionID := completionData["id"].(float64)
// Verify NextDueDate advanced
w = app.makeAuthReq(t, "GET", "/api/tasks/"+formatID(taskID), nil, token)
require.Equal(t, http.StatusOK, w.Code)
var taskAfterCompletion map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &taskAfterCompletion)
advancedNextDueDate := taskAfterCompletion["next_due_date"]
assert.NotNil(t, advancedNextDueDate, "Recurring task should still have next_due_date after completion")
assert.NotEqual(t, originalNextDueDate, advancedNextDueDate,
"NextDueDate should have advanced after completion")
// Delete the completion
w = app.makeAuthReq(t, "DELETE", "/api/completions/"+formatID(completionID), nil, token)
require.Equal(t, http.StatusOK, w.Code)
// Verify NextDueDate was recalculated back to original due date
w = app.makeAuthReq(t, "GET", "/api/tasks/"+formatID(taskID), nil, token)
require.Equal(t, http.StatusOK, w.Code)
var taskAfterDelete map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &taskAfterDelete)
restoredNextDueDate := taskAfterDelete["next_due_date"]
// After deleting the only completion, NextDueDate should be restored to the original DueDate
assert.NotNil(t, restoredNextDueDate, "NextDueDate should be restored after deleting the only completion")
assert.Equal(t, originalNextDueDate, restoredNextDueDate,
"NextDueDate should be recalculated back to original due date after completion deletion")
}
// ============ Test 6: Tier Limits Enforced ============
// TestE2E_TierLimits_Enforced verifies that a free-tier user cannot exceed the
// configured property limit.
func TestE2E_TierLimits_Enforced(t *testing.T) {
app := setupSecurityTest(t)
token, userID := app.registerAndLoginSec(t, "tieruser", "tier@test.com", "password123")
// Enable global limitations
app.DB.Where("1=1").Delete(&models.SubscriptionSettings{})
settings := &models.SubscriptionSettings{EnableLimitations: true}
require.NoError(t, app.DB.Create(settings).Error)
// Set free tier limit to 1 property
one := 1
app.DB.Where("tier = ?", models.TierFree).Delete(&models.TierLimits{})
freeLimits := &models.TierLimits{
Tier: models.TierFree,
PropertiesLimit: &one,
}
require.NoError(t, app.DB.Create(freeLimits).Error)
// Ensure user is on free tier
sub, err := app.SubscriptionRepo.GetOrCreate(userID)
require.NoError(t, err)
require.Equal(t, models.TierFree, sub.Tier)
// First residence should succeed
residenceBody := map[string]interface{}{"name": "First Property"}
w := app.makeAuthReq(t, "POST", "/api/residences", residenceBody, token)
require.Equal(t, http.StatusCreated, w.Code, "First residence should be allowed within limit")
// Second residence should be blocked
residenceBody2 := map[string]interface{}{"name": "Second Property (over limit)"}
w = app.makeAuthReq(t, "POST", "/api/residences", residenceBody2, token)
assert.Equal(t, http.StatusForbidden, w.Code,
"Second residence should be blocked by tier limit")
// Verify error response
var errResp map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &errResp)
require.NoError(t, err)
assert.Contains(t, fmt.Sprintf("%v", errResp), "limit",
"Error response should reference the limit")
}
// ============ Test 7: Auth Assertion -- No Panics on Missing User ============
// TestE2E_AuthAssertion_NoPanics verifies that all protected endpoints return
// 401 Unauthorized (not 500 panic) when no auth token is provided.
func TestE2E_AuthAssertion_NoPanics(t *testing.T) {
app := setupSecurityTest(t)
// Make requests to protected endpoints WITHOUT any token.
endpoints := []struct {
name string
method string
path string
}{
{"ListTasks", "GET", "/api/tasks"},
{"CreateTask", "POST", "/api/tasks"},
{"GetTask", "GET", "/api/tasks/1"},
{"ListResidences", "GET", "/api/residences"},
{"CreateResidence", "POST", "/api/residences"},
{"GetResidence", "GET", "/api/residences/1"},
{"ListCompletions", "GET", "/api/completions"},
{"CreateCompletion", "POST", "/api/completions"},
{"ListContractors", "GET", "/api/contractors"},
{"CreateContractor", "POST", "/api/contractors"},
{"GetSubscription", "GET", "/api/subscription/"},
{"SubscriptionStatus", "GET", "/api/subscription/status/"},
{"ProcessPurchase", "POST", "/api/subscription/purchase/"},
{"ListNotifications", "GET", "/api/notifications"},
{"CurrentUser", "GET", "/api/auth/me"},
}
for _, ep := range endpoints {
t.Run(ep.name, func(t *testing.T) {
w := app.makeAuthReq(t, ep.method, ep.path, nil, "")
assert.Equal(t, http.StatusUnauthorized, w.Code,
"Endpoint %s %s should return 401, not panic with 500", ep.method, ep.path)
})
}
// Also test with an invalid token (should be 401, not 500)
t.Run("InvalidToken", func(t *testing.T) {
w := app.makeAuthReq(t, "GET", "/api/tasks", nil, "completely-invalid-token-xyz")
assert.Equal(t, http.StatusUnauthorized, w.Code,
"Invalid token should return 401, not panic")
})
}
// ============ Test 8: Notification Limit Capped ============
// TestE2E_NotificationLimit_Capped verifies that the notification list endpoint
// caps the limit parameter to 200 even if the client requests more.
func TestE2E_NotificationLimit_Capped(t *testing.T) {
app := setupSecurityTest(t)
token, userID := app.registerAndLoginSec(t, "notifuser", "notif@test.com", "password123")
// Create 210 notifications directly in the database
for i := 0; i < 210; i++ {
notification := &models.Notification{
UserID: userID,
NotificationType: models.NotificationTaskCompleted,
Title: fmt.Sprintf("Test Notification %d", i),
Body: fmt.Sprintf("Body for notification %d", i),
}
require.NoError(t, app.DB.Create(notification).Error)
}
// Request with limit=999 (should be capped to 200 by the handler)
w := app.makeAuthReq(t, "GET", "/api/notifications?limit=999", nil, token)
require.Equal(t, http.StatusOK, w.Code)
var notifResp map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &notifResp)
require.NoError(t, err)
count := int(notifResp["count"].(float64))
assert.LessOrEqual(t, count, 200,
"Notification count should be capped at 200 even when requesting limit=999")
results := notifResp["results"].([]interface{})
assert.LessOrEqual(t, len(results), 200,
"Notification results should have at most 200 items")
}

View File

@@ -35,7 +35,9 @@ func AdminAuthMiddleware(cfg *config.Config, adminRepo *repositories.AdminReposi
return func(c echo.Context) error {
var tokenString string
// Get token from Authorization header
// Get token from Authorization header only.
// Query parameter authentication is intentionally not supported
// because tokens in URLs leak into server logs and browser history.
authHeader := c.Request().Header.Get("Authorization")
if authHeader != "" {
// Check Bearer prefix
@@ -45,11 +47,6 @@ func AdminAuthMiddleware(cfg *config.Config, adminRepo *repositories.AdminReposi
}
}
// If no header token, check query parameter (for WebSocket connections)
if tokenString == "" {
tokenString = c.QueryParam("token")
}
if tokenString == "" {
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "Authorization required"})
}
@@ -121,7 +118,10 @@ func RequireSuperAdmin() echo.MiddlewareFunc {
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "Admin authentication required"})
}
adminUser := admin.(*models.AdminUser)
adminUser, ok := admin.(*models.AdminUser)
if !ok {
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "Admin authentication required"})
}
if !adminUser.IsSuperAdmin() {
return c.JSON(http.StatusForbidden, map[string]interface{}{"error": "Super admin privileges required"})
}

View File

@@ -63,7 +63,7 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
// Cache miss - look up token in database
user, err = m.getUserFromDatabase(token)
if err != nil {
log.Debug().Err(err).Str("token", token[:8]+"...").Msg("Token authentication failed")
log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed")
return apperrors.Unauthorized("error.invalid_token")
}
@@ -200,13 +200,18 @@ func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) erro
return m.cache.InvalidateAuthToken(ctx, token)
}
// GetAuthUser retrieves the authenticated user from the Echo context
// GetAuthUser retrieves the authenticated user from the Echo context.
// Returns nil if the context value is missing or not of the expected type.
func GetAuthUser(c echo.Context) *models.User {
user := c.Get(AuthUserKey)
if user == nil {
val := c.Get(AuthUserKey)
if val == nil {
return nil
}
return user.(*models.User)
user, ok := val.(*models.User)
if !ok {
return nil
}
return user
}
// GetAuthToken retrieves the auth token from the Echo context
@@ -226,3 +231,12 @@ func MustGetAuthUser(c echo.Context) (*models.User, error) {
}
return user, nil
}
// truncateToken safely truncates a token string for logging.
// Returns at most the first 8 characters followed by "...".
func truncateToken(token string) string {
if len(token) > 8 {
return token[:8] + "..."
}
return token + "..."
}

View File

@@ -0,0 +1,119 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/models"
)
func TestGetAuthUser_NilContext_ReturnsNil(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// No user set in context
user := GetAuthUser(c)
assert.Nil(t, user)
}
func TestGetAuthUser_WrongType_ReturnsNil(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Set wrong type in context — should NOT panic
c.Set(AuthUserKey, "not-a-user")
user := GetAuthUser(c)
assert.Nil(t, user)
}
func TestGetAuthUser_ValidUser_ReturnsUser(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
expected := &models.User{Username: "testuser"}
c.Set(AuthUserKey, expected)
user := GetAuthUser(c)
require.NotNil(t, user)
assert.Equal(t, "testuser", user.Username)
}
func TestMustGetAuthUser_Nil_Returns401(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
user, err := MustGetAuthUser(c)
assert.Nil(t, user)
assert.Error(t, err)
}
func TestMustGetAuthUser_WrongType_Returns401(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(AuthUserKey, 12345)
user, err := MustGetAuthUser(c)
assert.Nil(t, user)
assert.Error(t, err)
}
func TestTokenTruncation_ShortToken_NoPanic(t *testing.T) {
// Ensure truncateToken does not panic on short tokens
assert.NotPanics(t, func() {
result := truncateToken("ab")
assert.Equal(t, "ab...", result)
})
}
func TestTokenTruncation_EmptyToken_NoPanic(t *testing.T) {
assert.NotPanics(t, func() {
result := truncateToken("")
assert.Equal(t, "...", result)
})
}
func TestTokenTruncation_LongToken_Truncated(t *testing.T) {
result := truncateToken("abcdefghijklmnop")
assert.Equal(t, "abcdefgh...", result)
}
func TestAdminAuth_QueryParamToken_Rejected(t *testing.T) {
// SEC-20: Admin JWT via query parameter must be rejected.
// Tokens in URLs leak into server logs and browser history.
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
mw := AdminAuthMiddleware(cfg, nil)
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "should not reach here")
})
e := echo.New()
// Request with token only in query param, no Authorization header
req := httptest.NewRequest(http.MethodGet, "/admin/test?token=some-jwt-token", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
assert.NoError(t, err) // handler writes JSON directly, no Echo error
assert.Equal(t, http.StatusUnauthorized, rec.Code, "query param token must be rejected")
assert.Contains(t, rec.Body.String(), "Authorization required")
}

View File

@@ -1,10 +1,15 @@
package middleware
import (
"regexp"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
)
// validRequestID matches alphanumeric characters and hyphens, 1-64 chars.
var validRequestID = regexp.MustCompile(`^[a-zA-Z0-9\-]{1,64}$`)
const (
// HeaderXRequestID is the header key for request correlation IDs
HeaderXRequestID = "X-Request-ID"
@@ -17,9 +22,11 @@ const (
func RequestIDMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Use existing request ID from header if present, otherwise generate one
// Use existing request ID from header if present and valid, otherwise generate one.
// Sanitize to alphanumeric + hyphens only (max 64 chars) to prevent
// log injection via control characters or overly long values.
reqID := c.Request().Header.Get(HeaderXRequestID)
if reqID == "" {
if reqID == "" || !validRequestID.MatchString(reqID) {
reqID = uuid.New().String()
}

View File

@@ -0,0 +1,125 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestRequestID_ValidID_Preserved(t *testing.T) {
e := echo.New()
mw := RequestIDMiddleware()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, GetRequestID(c))
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(HeaderXRequestID, "abc-123-def")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
assert.NoError(t, err)
assert.Equal(t, "abc-123-def", rec.Body.String())
assert.Equal(t, "abc-123-def", rec.Header().Get(HeaderXRequestID))
}
func TestRequestID_Empty_GeneratesNew(t *testing.T) {
e := echo.New()
mw := RequestIDMiddleware()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, GetRequestID(c))
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
// No X-Request-ID header
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
assert.NoError(t, err)
// Should be a UUID (36 chars: 8-4-4-4-12)
assert.Len(t, rec.Body.String(), 36)
}
func TestRequestID_ControlChars_Sanitized(t *testing.T) {
// SEC-29: Client-supplied X-Request-ID with control characters must be rejected.
tests := []struct {
name string
inputID string
}{
{"newline injection", "abc\ndef"},
{"carriage return", "abc\rdef"},
{"null byte", "abc\x00def"},
{"tab character", "abc\tdef"},
{"html tags", "abc<script>alert(1)</script>"},
{"spaces", "abc def"},
{"semicolons", "abc;def"},
{"unicode", "abc\u200bdef"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := echo.New()
mw := RequestIDMiddleware()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, GetRequestID(c))
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(HeaderXRequestID, tt.inputID)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
assert.NoError(t, err)
// The malicious ID should be replaced with a generated UUID
assert.NotEqual(t, tt.inputID, rec.Body.String(),
"control chars should be rejected, got original ID back")
assert.Len(t, rec.Body.String(), 36, "should be a generated UUID")
})
}
}
func TestRequestID_TooLong_Sanitized(t *testing.T) {
// SEC-29: X-Request-ID longer than 64 chars should be rejected.
e := echo.New()
mw := RequestIDMiddleware()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, GetRequestID(c))
})
longID := strings.Repeat("a", 65)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(HeaderXRequestID, longID)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
assert.NoError(t, err)
assert.NotEqual(t, longID, rec.Body.String(), "overly long ID should be replaced")
assert.Len(t, rec.Body.String(), 36, "should be a generated UUID")
}
func TestRequestID_MaxLength_Accepted(t *testing.T) {
// Exactly 64 chars of valid characters should be accepted
e := echo.New()
mw := RequestIDMiddleware()
handler := mw(func(c echo.Context) error {
return c.String(http.StatusOK, GetRequestID(c))
})
maxID := strings.Repeat("a", 64)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(HeaderXRequestID, maxID)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
assert.NoError(t, err)
assert.Equal(t, maxID, rec.Body.String(), "64-char valid ID should be accepted")
}

View File

@@ -0,0 +1,19 @@
package middleware
import "strings"
// SanitizeSortColumn validates a user-supplied sort column against an allowlist.
// Returns defaultCol if the input is empty or not in the allowlist.
// This prevents SQL injection via ORDER BY clauses.
func SanitizeSortColumn(input string, allowedCols []string, defaultCol string) string {
input = strings.TrimSpace(input)
if input == "" {
return defaultCol
}
for _, col := range allowedCols {
if strings.EqualFold(input, col) {
return col
}
}
return defaultCol
}

View File

@@ -0,0 +1,59 @@
package middleware
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSanitizeSortColumn_AllowedColumn_Passes(t *testing.T) {
allowed := []string{"created_at", "updated_at", "name"}
result := SanitizeSortColumn("created_at", allowed, "created_at")
assert.Equal(t, "created_at", result)
}
func TestSanitizeSortColumn_CaseInsensitive(t *testing.T) {
allowed := []string{"created_at", "updated_at", "name"}
result := SanitizeSortColumn("Created_At", allowed, "created_at")
assert.Equal(t, "created_at", result)
}
func TestSanitizeSortColumn_SQLInjection_ReturnsDefault(t *testing.T) {
allowed := []string{"created_at", "updated_at", "name"}
tests := []struct {
name string
input string
}{
{"drop table", "created_at; DROP TABLE auth_user; --"},
{"union select", "name UNION SELECT * FROM auth_user"},
{"or 1=1", "name OR 1=1"},
{"semicolon", "created_at;"},
{"subquery", "(SELECT password FROM auth_user LIMIT 1)"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeSortColumn(tt.input, allowed, "created_at")
assert.Equal(t, "created_at", result, "SQL injection attempt should return default")
})
}
}
func TestSanitizeSortColumn_Empty_ReturnsDefault(t *testing.T) {
allowed := []string{"created_at", "updated_at", "name"}
result := SanitizeSortColumn("", allowed, "created_at")
assert.Equal(t, "created_at", result)
}
func TestSanitizeSortColumn_Whitespace_ReturnsDefault(t *testing.T) {
allowed := []string{"created_at", "updated_at", "name"}
result := SanitizeSortColumn(" ", allowed, "created_at")
assert.Equal(t, "created_at", result)
}
func TestSanitizeSortColumn_UnknownColumn_ReturnsDefault(t *testing.T) {
allowed := []string{"created_at", "updated_at", "name"}
result := SanitizeSortColumn("nonexistent_column", allowed, "created_at")
assert.Equal(t, "created_at", result)
}

View File

@@ -79,22 +79,30 @@ func parseTimezone(tz string) *time.Location {
}
// GetUserTimezone retrieves the user's timezone from the Echo context.
// Returns UTC if not set.
// Returns UTC if not set or if the stored value is not a *time.Location.
func GetUserTimezone(c echo.Context) *time.Location {
loc := c.Get(TimezoneKey)
if loc == nil {
val := c.Get(TimezoneKey)
if val == nil {
return time.UTC
}
return loc.(*time.Location)
loc, ok := val.(*time.Location)
if !ok {
return time.UTC
}
return loc
}
// GetUserNow retrieves the timezone-aware "now" time from the Echo context.
// This represents the start of the current day in the user's timezone.
// Returns time.Now().UTC() if not set.
// Returns time.Now().UTC() if not set or if the stored value is not a time.Time.
func GetUserNow(c echo.Context) time.Time {
now := c.Get(UserNowKey)
if now == nil {
val := c.Get(UserNowKey)
if val == nil {
return time.Now().UTC()
}
return now.(time.Time)
now, ok := val.(time.Time)
if !ok {
return time.Now().UTC()
}
return now
}

View File

@@ -52,14 +52,18 @@ func (c *Collector) Collect() SystemStats {
// CPU stats
c.collectCPU(&stats)
// Read Go runtime memory stats once (used by both memory and runtime collectors)
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
// Memory stats (system + Go runtime)
c.collectMemory(&stats)
c.collectMemory(&stats, &memStats)
// Disk stats
c.collectDisk(&stats)
// Go runtime stats
c.collectRuntime(&stats)
c.collectRuntime(&stats, &memStats)
// HTTP stats (API only)
if c.httpCollector != nil {
@@ -77,9 +81,9 @@ func (c *Collector) Collect() SystemStats {
}
func (c *Collector) collectCPU(stats *SystemStats) {
// Get CPU usage percentage (blocks for 1 second to get accurate sample)
// Shorter intervals can give inaccurate readings
if cpuPercent, err := cpu.Percent(time.Second, false); err == nil && len(cpuPercent) > 0 {
// Get CPU usage percentage (blocks for 200ms to sample)
// This is called periodically, so a shorter window is acceptable
if cpuPercent, err := cpu.Percent(200*time.Millisecond, false); err == nil && len(cpuPercent) > 0 {
stats.CPU.UsagePercent = cpuPercent[0]
}
@@ -93,7 +97,7 @@ func (c *Collector) collectCPU(stats *SystemStats) {
}
}
func (c *Collector) collectMemory(stats *SystemStats) {
func (c *Collector) collectMemory(stats *SystemStats, m *runtime.MemStats) {
// System memory
if vmem, err := mem.VirtualMemory(); err == nil {
stats.Memory.UsedBytes = vmem.Used
@@ -101,9 +105,7 @@ func (c *Collector) collectMemory(stats *SystemStats) {
stats.Memory.UsagePercent = vmem.UsedPercent
}
// Go runtime memory
var m runtime.MemStats
runtime.ReadMemStats(&m)
// Go runtime memory (reuses pre-read MemStats)
stats.Memory.HeapAlloc = m.HeapAlloc
stats.Memory.HeapSys = m.HeapSys
stats.Memory.HeapInuse = m.HeapInuse
@@ -119,10 +121,7 @@ func (c *Collector) collectDisk(stats *SystemStats) {
}
}
func (c *Collector) collectRuntime(stats *SystemStats) {
var m runtime.MemStats
runtime.ReadMemStats(&m)
func (c *Collector) collectRuntime(stats *SystemStats, m *runtime.MemStats) {
stats.Runtime.Goroutines = runtime.NumGoroutine()
stats.Runtime.NumGC = m.NumGC
if m.NumGC > 0 {

View File

@@ -17,8 +17,13 @@ var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// Allow connections from admin panel
return true
origin := r.Header.Get("Origin")
if origin == "" {
// Same-origin requests may omit the Origin header
return true
}
// Allow if origin matches the request host
return strings.HasPrefix(origin, "https://"+r.Host) || strings.HasPrefix(origin, "http://"+r.Host)
},
}
@@ -116,6 +121,7 @@ func (h *Handler) WebSocket(c echo.Context) error {
conn, err := upgrader.Upgrade(c.Response().Writer, c.Request(), nil)
if err != nil {
log.Error().Err(err).Msg("Failed to upgrade WebSocket connection")
return err
}
defer conn.Close()
@@ -174,6 +180,7 @@ func (h *Handler) WebSocket(c echo.Context) error {
h.sendStats(conn, &wsMu)
case <-ctx.Done():
return nil
}
}
}

View File

@@ -108,6 +108,10 @@ func (s *Service) Stop() {
close(s.settingsStopCh)
s.collector.Stop()
// Flush and close the log writer's background goroutine
s.logWriter.Close()
log.Info().Str("process", s.process).Msg("Monitoring service stopped")
}

View File

@@ -8,23 +8,56 @@ import (
"github.com/google/uuid"
)
// RedisLogWriter implements io.Writer to capture zerolog output to Redis
const (
// writerChannelSize is the buffer size for the async log write channel.
// Entries beyond this limit are dropped to prevent unbounded memory growth.
writerChannelSize = 256
)
// RedisLogWriter implements io.Writer to capture zerolog output to Redis.
// It uses a single background goroutine with a buffered channel instead of
// spawning a new goroutine per log line, preventing unbounded goroutine growth.
type RedisLogWriter struct {
buffer *LogBuffer
process string
enabled atomic.Bool
ch chan LogEntry
done chan struct{}
}
// NewRedisLogWriter creates a new writer that captures logs to Redis
// NewRedisLogWriter creates a new writer that captures logs to Redis.
// It starts a single background goroutine that drains the buffered channel.
func NewRedisLogWriter(buffer *LogBuffer, process string) *RedisLogWriter {
w := &RedisLogWriter{
buffer: buffer,
process: process,
ch: make(chan LogEntry, writerChannelSize),
done: make(chan struct{}),
}
w.enabled.Store(true) // enabled by default
// Single background goroutine drains the channel
go w.drainLoop()
return w
}
// drainLoop reads entries from the buffered channel and pushes them to Redis.
// It runs in a single goroutine for the lifetime of the writer.
func (w *RedisLogWriter) drainLoop() {
defer close(w.done)
for entry := range w.ch {
_ = w.buffer.Push(entry) // Ignore errors to avoid blocking log output
}
}
// Close shuts down the background goroutine. It should be called during
// graceful shutdown to ensure all buffered entries are flushed.
func (w *RedisLogWriter) Close() {
close(w.ch)
<-w.done // Wait for drain to finish
}
// SetEnabled enables or disables log capture to Redis
func (w *RedisLogWriter) SetEnabled(enabled bool) {
w.enabled.Store(enabled)
@@ -35,8 +68,10 @@ func (w *RedisLogWriter) IsEnabled() bool {
return w.enabled.Load()
}
// Write implements io.Writer interface
// It parses zerolog JSON output and writes to Redis asynchronously
// Write implements io.Writer interface.
// It parses zerolog JSON output and sends it to the buffered channel for
// async Redis writes. If the channel is full, the entry is dropped to
// avoid blocking the caller (back-pressure shedding).
func (w *RedisLogWriter) Write(p []byte) (n int, err error) {
// Skip if monitoring is disabled
if !w.enabled.Load() {
@@ -86,10 +121,14 @@ func (w *RedisLogWriter) Write(p []byte) (n int, err error) {
}
}
// Write to Redis asynchronously to avoid blocking
go func() {
_ = w.buffer.Push(entry) // Ignore errors to avoid blocking log output
}()
// Non-blocking send: drop entries if channel is full rather than
// spawning unbounded goroutines or blocking the logger
select {
case w.ch <- entry:
// Sent successfully
default:
// Channel full — drop this entry to avoid back-pressure on the logger
}
return len(p), nil
}

View File

@@ -117,6 +117,9 @@ func (c *FCMClient) Send(ctx context.Context, tokens []string, title, message st
// Log individual results
for i, result := range fcmResp.Results {
if i >= len(tokens) {
break
}
if result.Error != "" {
log.Error().
Str("token", truncateToken(tokens[i])).

186
internal/push/fcm_test.go Normal file
View File

@@ -0,0 +1,186 @@
package push
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestFCMClient creates an FCMClient pointing at the given test server URL.
func newTestFCMClient(serverURL string) *FCMClient {
return &FCMClient{
serverKey: "test-server-key",
httpClient: http.DefaultClient,
}
}
// serveFCMResponse creates an httptest.Server that returns the given FCMResponse as JSON.
func serveFCMResponse(t *testing.T, resp FCMResponse) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(resp)
require.NoError(t, err)
}))
}
// sendWithEndpoint is a helper that sends an FCM notification using a custom endpoint
// (the test server) instead of the real FCM endpoint. This avoids modifying the
// production code to be testable and instead temporarily overrides the client's HTTP
// transport to redirect requests to our test server.
func sendWithEndpoint(client *FCMClient, server *httptest.Server, ctx context.Context, tokens []string, title, message string, data map[string]string) error {
// Override the HTTP client to redirect all requests to the test server
client.httpClient = server.Client()
// We need to intercept the request and redirect it to our test server.
// Use a custom RoundTripper that rewrites the URL.
originalTransport := server.Client().Transport
client.httpClient.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
// Rewrite the URL to point to the test server
req.URL.Scheme = "http"
req.URL.Host = server.Listener.Addr().String()
if originalTransport != nil {
return originalTransport.RoundTrip(req)
}
return http.DefaultTransport.RoundTrip(req)
})
return client.Send(ctx, tokens, title, message, data)
}
// roundTripFunc is a function that implements http.RoundTripper.
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestFCMSend_MoreResultsThanTokens_NoPanic(t *testing.T) {
// FCM returns 5 results but we only sent 2 tokens.
// Before the bounds check fix, this would panic with index out of range.
fcmResp := FCMResponse{
MulticastID: 12345,
Success: 2,
Failure: 3,
Results: []FCMResult{
{MessageID: "msg1"},
{MessageID: "msg2"},
{Error: "InvalidRegistration"},
{Error: "NotRegistered"},
{Error: "InvalidRegistration"},
},
}
server := serveFCMResponse(t, fcmResp)
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"token-aaa-111", "token-bbb-222"}
// This must not panic
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
assert.NoError(t, err)
}
func TestFCMSend_FewerResultsThanTokens_NoPanic(t *testing.T) {
// FCM returns fewer results than tokens we sent.
// This is also a malformed response but should not panic.
fcmResp := FCMResponse{
MulticastID: 12345,
Success: 1,
Failure: 0,
Results: []FCMResult{
{MessageID: "msg1"},
},
}
server := serveFCMResponse(t, fcmResp)
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"token-aaa-111", "token-bbb-222", "token-ccc-333"}
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
assert.NoError(t, err)
}
func TestFCMSend_EmptyResponse_NoPanic(t *testing.T) {
// FCM returns an empty Results slice.
fcmResp := FCMResponse{
MulticastID: 12345,
Success: 0,
Failure: 0,
Results: []FCMResult{},
}
server := serveFCMResponse(t, fcmResp)
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"token-aaa-111"}
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
// No panic expected. The function returns nil because fcmResp.Success == 0
// and fcmResp.Failure == 0 (the "all failed" check requires Failure > 0).
assert.NoError(t, err)
}
func TestFCMSend_NilResultsSlice_NoPanic(t *testing.T) {
// FCM returns a response with nil Results (e.g., malformed JSON).
fcmResp := FCMResponse{
MulticastID: 12345,
Success: 0,
Failure: 1,
}
server := serveFCMResponse(t, fcmResp)
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"token-aaa-111"}
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
// Should return error because Success == 0 and Failure > 0
assert.Error(t, err)
assert.Contains(t, err.Error(), "all FCM notifications failed")
}
func TestFCMSend_EmptyTokens_ReturnsNil(t *testing.T) {
// Verify the early return for empty tokens.
client := &FCMClient{
serverKey: "test-key",
httpClient: http.DefaultClient,
}
err := client.Send(context.Background(), []string{}, "Test", "Body", nil)
assert.NoError(t, err)
}
func TestFCMSend_ResultsWithErrorsMatchTokens(t *testing.T) {
// Normal case: results count matches tokens count, all with errors.
fcmResp := FCMResponse{
MulticastID: 12345,
Success: 0,
Failure: 2,
Results: []FCMResult{
{Error: "InvalidRegistration"},
{Error: "NotRegistered"},
},
}
server := serveFCMResponse(t, fcmResp)
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"token-aaa-111", "token-bbb-222"}
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "all FCM notifications failed")
}

View File

@@ -63,7 +63,7 @@ func (r *ContractorRepository) FindByUser(userID uint, residenceIDs []uint) ([]m
query = query.Where("residence_id IS NULL AND created_by_id = ?", userID)
}
err := query.Order("is_favorite DESC, name ASC").Find(&contractors).Error
err := query.Order("is_favorite DESC, name ASC").Limit(500).Find(&contractors).Error
return contractors, err
}
@@ -85,18 +85,31 @@ func (r *ContractorRepository) Delete(id uint) error {
Update("is_active", false).Error
}
// ToggleFavorite toggles the favorite status of a contractor
// ToggleFavorite toggles the favorite status of a contractor atomically.
// Uses a single UPDATE with NOT to avoid read-then-write race conditions.
// Only toggles active contractors to prevent toggling soft-deleted records.
func (r *ContractorRepository) ToggleFavorite(id uint) (bool, error) {
var contractor models.Contractor
if err := r.db.First(&contractor, id).Error; err != nil {
return false, err
}
newStatus := !contractor.IsFavorite
err := r.db.Model(&models.Contractor{}).
Where("id = ?", id).
Update("is_favorite", newStatus).Error
var newStatus bool
err := r.db.Transaction(func(tx *gorm.DB) error {
// Atomic toggle: SET is_favorite = NOT is_favorite for active contractors only
result := tx.Model(&models.Contractor{}).
Where("id = ? AND is_active = ?", id, true).
Update("is_favorite", gorm.Expr("NOT is_favorite"))
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
// Read back the new value within the same transaction
var contractor models.Contractor
if err := tx.Select("is_favorite").First(&contractor, id).Error; err != nil {
return err
}
newStatus = contractor.IsFavorite
return nil
})
return newStatus, err
}
@@ -145,6 +158,19 @@ func (r *ContractorRepository) CountByResidence(residenceID uint) (int64, error)
return count, err
}
// CountByResidenceIDs counts all active contractors across multiple residences in a single query.
// Returns the total count of active contractors for the given residence IDs.
func (r *ContractorRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) {
if len(residenceIDs) == 0 {
return 0, nil
}
var count int64
err := r.db.Model(&models.Contractor{}).
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
Count(&count).Error
return count, err
}
// === Specialty Operations ===
// GetAllSpecialties returns all contractor specialties

View File

@@ -0,0 +1,96 @@
package repositories
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/testutil"
)
func TestToggleFavorite_Active_Toggles(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
// Initially is_favorite is false
assert.False(t, contractor.IsFavorite, "contractor should start as not favorite")
// First toggle: false -> true
newStatus, err := repo.ToggleFavorite(contractor.ID)
require.NoError(t, err)
assert.True(t, newStatus, "first toggle should set favorite to true")
// Verify in database
var found models.Contractor
err = db.First(&found, contractor.ID).Error
require.NoError(t, err)
assert.True(t, found.IsFavorite, "database should reflect favorite = true")
// Second toggle: true -> false
newStatus, err = repo.ToggleFavorite(contractor.ID)
require.NoError(t, err)
assert.False(t, newStatus, "second toggle should set favorite to false")
// Verify in database
err = db.First(&found, contractor.ID).Error
require.NoError(t, err)
assert.False(t, found.IsFavorite, "database should reflect favorite = false")
}
func TestToggleFavorite_SoftDeleted_ReturnsError(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Deleted Contractor")
// Soft-delete the contractor
err := db.Model(&models.Contractor{}).
Where("id = ?", contractor.ID).
Update("is_active", false).Error
require.NoError(t, err)
// Toggling a soft-deleted contractor should fail (record not found)
_, err = repo.ToggleFavorite(contractor.ID)
assert.Error(t, err, "toggling a soft-deleted contractor should return an error")
}
func TestToggleFavorite_NonExistent_ReturnsError(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
_, err := repo.ToggleFavorite(99999)
assert.Error(t, err, "toggling a non-existent contractor should return an error")
}
func TestContractorRepository_FindByUser_HasDefaultLimit(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create 510 contractors to exceed the default limit of 500
for i := 0; i < 510; i++ {
c := &models.Contractor{
ResidenceID: &residence.ID,
CreatedByID: user.ID,
Name: fmt.Sprintf("Contractor %d", i+1),
IsActive: true,
}
err := db.Create(c).Error
require.NoError(t, err)
}
contractors, err := repo.FindByUser(user.ID, []uint{residence.ID})
require.NoError(t, err)
assert.Equal(t, 500, len(contractors), "FindByUser should return at most 500 contractors by default")
}

View File

@@ -52,7 +52,8 @@ func (r *DocumentRepository) FindByResidence(residenceID uint) ([]models.Documen
return documents, err
}
// FindByUser finds all documents accessible to a user
// FindByUser finds all documents accessible to a user.
// A default limit of 500 is applied to prevent unbounded result sets.
func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document, error) {
var documents []models.Document
err := r.db.Preload("CreatedBy").
@@ -60,6 +61,7 @@ func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document,
Preload("Images").
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
Order("created_at DESC").
Limit(500).
Find(&documents).Error
return documents, err
}
@@ -89,7 +91,8 @@ func (r *DocumentRepository) FindByUserFiltered(residenceIDs []uint, filter *Doc
query = query.Where("expiry_date IS NOT NULL AND expiry_date > ? AND expiry_date <= ?", now, threshold)
}
if filter.Search != "" {
searchPattern := "%" + filter.Search + "%"
escaped := escapeLikeWildcards(filter.Search)
searchPattern := "%" + escaped + "%"
query = query.Where("(title ILIKE ? OR description ILIKE ?)", searchPattern, searchPattern)
}
}
@@ -169,6 +172,19 @@ func (r *DocumentRepository) CountByResidence(residenceID uint) (int64, error) {
return count, err
}
// CountByResidenceIDs counts all active documents across multiple residences in a single query.
// Returns the total count of active documents for the given residence IDs.
func (r *DocumentRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) {
if len(residenceIDs) == 0 {
return 0, nil
}
var count int64
err := r.db.Model(&models.Document{}).
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
Count(&count).Error
return count, err
}
// FindByIDIncludingInactive finds a document by ID including inactive ones
func (r *DocumentRepository) FindByIDIncludingInactive(id uint, document *models.Document) error {
return r.db.Preload("CreatedBy").Preload("Images").First(document, id).Error

View File

@@ -0,0 +1,38 @@
package repositories
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/testutil"
)
func TestDocumentRepository_FindByUser_HasDefaultLimit(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create 510 documents to exceed the default limit of 500
for i := 0; i < 510; i++ {
doc := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: fmt.Sprintf("Doc %d", i+1),
DocumentType: models.DocumentTypeGeneral,
FileURL: "https://example.com/doc.pdf",
IsActive: true,
}
err := db.Create(doc).Error
require.NoError(t, err)
}
docs, err := repo.FindByUser([]uint{residence.ID})
require.NoError(t, err)
assert.Equal(t, 500, len(docs), "FindByUser should return at most 500 documents by default")
}

View File

@@ -1,6 +1,7 @@
package repositories
import (
"errors"
"time"
"gorm.io/gorm"
@@ -130,18 +131,25 @@ func (r *NotificationRepository) CreatePreferences(prefs *models.NotificationPre
// UpdatePreferences updates notification preferences
func (r *NotificationRepository) UpdatePreferences(prefs *models.NotificationPreference) error {
return r.db.Save(prefs).Error
return r.db.Omit("User").Save(prefs).Error
}
// GetOrCreatePreferences gets or creates notification preferences for a user
// GetOrCreatePreferences gets or creates notification preferences for a user.
// Uses a transaction to avoid TOCTOU race conditions on concurrent requests.
func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.NotificationPreference, error) {
prefs, err := r.FindPreferencesByUser(userID)
if err == nil {
return prefs, nil
}
var prefs models.NotificationPreference
if err == gorm.ErrRecordNotFound {
prefs = &models.NotificationPreference{
err := r.db.Transaction(func(tx *gorm.DB) error {
err := tx.Where("user_id = ?", userID).First(&prefs).Error
if err == nil {
return nil // Found existing preferences
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err // Unexpected error
}
// Record not found -- create with defaults
prefs = models.NotificationPreference{
UserID: userID,
TaskDueSoon: true,
TaskOverdue: true,
@@ -151,17 +159,36 @@ func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.No
WarrantyExpiring: true,
EmailTaskCompleted: true,
}
if err := r.CreatePreferences(prefs); err != nil {
return nil, err
}
return prefs, nil
return tx.Create(&prefs).Error
})
if err != nil {
return nil, err
}
return nil, err
return &prefs, nil
}
// === Device Registration ===
// FindAPNSDeviceByID finds an APNS device by ID
func (r *NotificationRepository) FindAPNSDeviceByID(id uint) (*models.APNSDevice, error) {
var device models.APNSDevice
err := r.db.First(&device, id).Error
if err != nil {
return nil, err
}
return &device, nil
}
// FindGCMDeviceByID finds a GCM device by ID
func (r *NotificationRepository) FindGCMDeviceByID(id uint) (*models.GCMDevice, error) {
var device models.GCMDevice
err := r.db.First(&device, id).Error
if err != nil {
return nil, err
}
return &device, nil
}
// FindAPNSDeviceByToken finds an APNS device by registration token
func (r *NotificationRepository) FindAPNSDeviceByToken(token string) (*models.APNSDevice, error) {
var device models.APNSDevice
@@ -243,12 +270,12 @@ func (r *NotificationRepository) DeactivateGCMDevice(id uint) error {
// GetActiveTokensForUser gets all active push tokens for a user
func (r *NotificationRepository) GetActiveTokensForUser(userID uint) (iosTokens []string, androidTokens []string, err error) {
apnsDevices, err := r.FindAPNSDevicesByUser(userID)
if err != nil && err != gorm.ErrRecordNotFound {
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, err
}
gcmDevices, err := r.FindGCMDevicesByUser(userID)
if err != nil && err != gorm.ErrRecordNotFound {
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, err
}

View File

@@ -0,0 +1,96 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/testutil"
)
func TestGetOrCreatePreferences_New_Creates(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// No preferences exist yet for this user
prefs, err := repo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
require.NotNil(t, prefs)
// Verify defaults were set
assert.Equal(t, user.ID, prefs.UserID)
assert.True(t, prefs.TaskDueSoon)
assert.True(t, prefs.TaskOverdue)
assert.True(t, prefs.TaskCompleted)
assert.True(t, prefs.TaskAssigned)
assert.True(t, prefs.ResidenceShared)
assert.True(t, prefs.WarrantyExpiring)
assert.True(t, prefs.EmailTaskCompleted)
// Verify it was actually persisted
var count int64
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should have exactly one preferences record")
}
func TestGetOrCreatePreferences_AlreadyExists_Returns(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// Create preferences manually first
existingPrefs := &models.NotificationPreference{
UserID: user.ID,
TaskDueSoon: true,
TaskOverdue: true,
TaskCompleted: true,
TaskAssigned: true,
ResidenceShared: true,
WarrantyExpiring: true,
EmailTaskCompleted: true,
}
err := db.Create(existingPrefs).Error
require.NoError(t, err)
require.NotZero(t, existingPrefs.ID)
// GetOrCreatePreferences should return the existing record, not create a new one
prefs, err := repo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
require.NotNil(t, prefs)
// The returned record should have the same ID as the existing one
assert.Equal(t, existingPrefs.ID, prefs.ID, "should return the existing record by ID")
assert.Equal(t, user.ID, prefs.UserID, "should have correct user_id")
// Verify still only one record exists (no duplicate created)
var count int64
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should still have exactly one preferences record")
}
func TestGetOrCreatePreferences_Idempotent(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// Call twice in succession
prefs1, err := repo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
prefs2, err := repo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
// Both should return the same record
assert.Equal(t, prefs1.ID, prefs2.ID)
// Should only have one record
var count int64
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should have exactly one preferences record after two calls")
}

View File

@@ -37,6 +37,84 @@ func (r *ReminderRepository) HasSentReminder(taskID, userID uint, dueDate time.T
return count > 0, nil
}
// ReminderKey uniquely identifies a reminder that may have been sent.
type ReminderKey struct {
TaskID uint
UserID uint
DueDate time.Time
Stage models.ReminderStage
}
// HasSentReminderBatch checks which reminders from the given list have already been sent.
// Returns a set of indices into the input slice that have already been sent.
// This replaces N individual HasSentReminder calls with a single query.
func (r *ReminderRepository) HasSentReminderBatch(keys []ReminderKey) (map[int]bool, error) {
result := make(map[int]bool)
if len(keys) == 0 {
return result, nil
}
// Build a lookup from (task_id, user_id, due_date, stage) -> index
type normalizedKey struct {
TaskID uint
UserID uint
DueDate string
Stage models.ReminderStage
}
keyToIdx := make(map[normalizedKey][]int, len(keys))
// Collect unique task IDs and user IDs for the WHERE clause
taskIDSet := make(map[uint]bool)
userIDSet := make(map[uint]bool)
for i, k := range keys {
taskIDSet[k.TaskID] = true
userIDSet[k.UserID] = true
dueDateOnly := time.Date(k.DueDate.Year(), k.DueDate.Month(), k.DueDate.Day(), 0, 0, 0, 0, time.UTC)
nk := normalizedKey{
TaskID: k.TaskID,
UserID: k.UserID,
DueDate: dueDateOnly.Format("2006-01-02"),
Stage: k.Stage,
}
keyToIdx[nk] = append(keyToIdx[nk], i)
}
taskIDs := make([]uint, 0, len(taskIDSet))
for id := range taskIDSet {
taskIDs = append(taskIDs, id)
}
userIDs := make([]uint, 0, len(userIDSet))
for id := range userIDSet {
userIDs = append(userIDs, id)
}
// Query all matching reminder logs in one query
var logs []models.TaskReminderLog
err := r.db.Where("task_id IN ? AND user_id IN ?", taskIDs, userIDs).
Find(&logs).Error
if err != nil {
return nil, err
}
// Match returned logs against our key set
for _, l := range logs {
dueDateStr := l.DueDate.Format("2006-01-02")
nk := normalizedKey{
TaskID: l.TaskID,
UserID: l.UserID,
DueDate: dueDateStr,
Stage: l.ReminderStage,
}
if indices, ok := keyToIdx[nk]; ok {
for _, idx := range indices {
result[idx] = true
}
}
}
return result, nil
}
// LogReminder records that a reminder was sent.
// Returns the created log entry or an error if the reminder was already sent
// (unique constraint violation).

View File

@@ -6,6 +6,7 @@ import (
"math/big"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/models"
@@ -269,7 +270,9 @@ func (r *ResidenceRepository) GetActiveShareCode(residenceID uint) (*models.Resi
// Check if expired
if shareCode.ExpiresAt != nil && time.Now().UTC().After(*shareCode.ExpiresAt) {
// Auto-deactivate expired code
r.DeactivateShareCode(shareCode.ID)
if err := r.DeactivateShareCode(shareCode.ID); err != nil {
log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate expired share code")
}
return nil, nil
}
@@ -296,9 +299,11 @@ func (r *ResidenceRepository) generateUniqueCode() (string, error) {
// Check if code already exists
var count int64
r.db.Model(&models.ResidenceShareCode{}).
if err := r.db.Model(&models.ResidenceShareCode{}).
Where("code = ? AND is_active = ?", codeStr, true).
Count(&count)
Count(&count).Error; err != nil {
return "", err
}
if count == 0 {
return codeStr, nil

View File

@@ -1,9 +1,11 @@
package repositories
import (
"errors"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/treytartt/casera-api/internal/models"
)
@@ -30,31 +32,37 @@ func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscrip
return &sub, nil
}
// GetOrCreate gets or creates a subscription for a user (defaults to free tier)
// GetOrCreate gets or creates a subscription for a user (defaults to free tier).
// Uses a transaction to avoid TOCTOU race conditions on concurrent requests.
func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) {
sub, err := r.FindByUserID(userID)
if err == nil {
return sub, nil
}
var sub models.UserSubscription
if err == gorm.ErrRecordNotFound {
sub = &models.UserSubscription{
UserID: userID,
Tier: models.TierFree,
err := r.db.Transaction(func(tx *gorm.DB) error {
err := tx.Where("user_id = ?", userID).First(&sub).Error
if err == nil {
return nil // Found existing subscription
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err // Unexpected error
}
// Record not found -- create with free tier defaults
sub = models.UserSubscription{
UserID: userID,
Tier: models.TierFree,
AutoRenew: true,
}
if err := r.db.Create(sub).Error; err != nil {
return nil, err
}
return sub, nil
return tx.Create(&sub).Error
})
if err != nil {
return nil, err
}
return nil, err
return &sub, nil
}
// Update updates a subscription
func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error {
return r.db.Save(sub).Error
return r.db.Omit("User").Save(sub).Error
}
// UpgradeToPro upgrades a user to Pro tier using a transaction with row locking
@@ -63,7 +71,7 @@ func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time,
return r.db.Transaction(func(tx *gorm.DB) error {
// Lock the row for update
var sub models.UserSubscription
if err := tx.Set("gorm:query_option", "FOR UPDATE").
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Where("user_id = ?", userID).First(&sub).Error; err != nil {
return err
}
@@ -86,7 +94,7 @@ func (r *SubscriptionRepository) DowngradeToFree(userID uint) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// Lock the row for update
var sub models.UserSubscription
if err := tx.Set("gorm:query_option", "FOR UPDATE").
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Where("user_id = ?", userID).First(&sub).Error; err != nil {
return err
}
@@ -165,7 +173,7 @@ func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*m
var limits models.TierLimits
err := r.db.Where("tier = ?", tier).First(&limits).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Return defaults
if tier == models.TierFree {
defaults := models.GetDefaultFreeLimits()
@@ -193,7 +201,7 @@ func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, er
var settings models.SubscriptionSettings
err := r.db.First(&settings).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Return default settings (limitations disabled)
return &models.SubscriptionSettings{
EnableLimitations: false,

Some files were not shown because too many files have changed in this diff Show More