Add Stripe billing, free trials, and cross-platform subscription guards

- Stripe integration: add StripeService with checkout sessions, customer
  portal, and webhook handling for subscription lifecycle events.
- Free trials: auto-start configurable trial on first subscription check,
  with admin-controllable duration and enable/disable toggle.
- Cross-platform guard: prevent duplicate subscriptions across iOS, Android,
  and Stripe by checking existing platform before allowing purchase.
- Subscription model: add Stripe fields (customer_id, subscription_id,
  price_id), trial fields (trial_start, trial_end, trial_used), and
  SubscriptionSource/IsTrialActive helpers.
- API: add trial and source fields to status response, update OpenAPI spec.
- Clean up stale migration and audit docs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-03-05 11:36:14 -06:00
parent d5bb123cd0
commit 72db9050f8
35 changed files with 1555 additions and 1120 deletions

View File

@@ -1,170 +0,0 @@
# Phase 4: Gin to Echo Handler Migration Status
## Completed Files
### ✅ auth_handler.go
- **Status**: Fully migrated
- **Methods migrated**: 13 methods
- Login, Register, Logout, CurrentUser, UpdateProfile
- VerifyEmail, ResendVerification
- ForgotPassword, VerifyResetCode, ResetPassword
- AppleSignIn, GoogleSignIn
- **Key changes applied**:
- Import changed from gin to echo
- Added validator import
- All handlers return error
- c.Bind + c.Validate pattern implemented
- c.MustGet → c.Get
- gin.H → map[string]interface{}
- c.Request.Context() → c.Request().Context()
- All c.JSON calls use `return`
## Remaining Files to Migrate
### 🔧 residence_handler.go
- **Status**: Partially migrated (needs cleanup)
- **Methods**: 13 methods
- **Issue**: Sed-based automated migration created syntax errors
- **Next steps**: Manual cleanup needed
### ⏳ task_handler.go
- **Methods**: ~17 methods
- **Complexity**: High (multipart form handling for completions)
- **Special considerations**:
- Has multipart/form-data handling in CreateCompletion
- Multiple lookup endpoints (categories, priorities, frequencies)
### ⏳ contractor_handler.go
- **Methods**: 8 methods
- **Complexity**: Medium
### ⏳ document_handler.go
- **Methods**: 8 methods
- **Complexity**: High (multipart form handling)
- **Special considerations**: File upload in CreateDocument
### ⏳ notification_handler.go
- **Methods**: 9 methods
- **Complexity**: Medium
- **Special considerations**: Query parameters for pagination
### ⏳ subscription_handler.go
- **Status**: Unknown
- **Estimated complexity**: Medium
### ⏳ upload_handler.go
- **Methods**: 4 methods
- **Complexity**: Medium
- **Special considerations**: c.FormFile handling, c.DefaultQuery
### ⏳ user_handler.go
- **Methods**: 3 methods
- **Complexity**: Low
### ⏳ media_handler.go
- **Status**: Unknown
- **Estimated complexity**: Medium
### ⏳ static_data_handler.go
- **Methods**: Unknown
- **Complexity**: Low (likely just lookups)
### ⏳ task_template_handler.go
- **Status**: Unknown
- **Estimated complexity**: Medium
### ⏳ tracking_handler.go
- **Status**: Unknown
- **Estimated complexity**: Low
### ⏳ subscription_webhook_handler.go
- **Status**: Unknown
- **Estimated complexity**: Medium-High (webhook handling)
## Migration Pattern
All handlers must follow these transformations:
```go
// BEFORE (Gin)
func (h *Handler) Method(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req requests.SomeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
result, err := h.service.DoSomething(&req)
if err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
c.JSON(200, result)
}
// AFTER (Echo)
func (h *Handler) Method(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
var req requests.SomeRequest
if err := c.Bind(&req); err != nil {
return c.JSON(400, map[string]interface{}{"error": err.Error()})
}
if err := c.Validate(&req); err != nil {
return c.JSON(400, validator.FormatValidationErrors(err))
}
result, err := h.service.DoSomething(&req)
if err != nil {
return c.JSON(500, map[string]interface{}{"error": err.Error()})
}
return c.JSON(200, result)
}
```
## Critical Context Changes
| Gin | Echo |
|-----|------|
| `c.MustGet()` | `c.Get()` |
| `c.ShouldBindJSON()` | `c.Bind()` + `c.Validate()` |
| `c.JSON(status, data)` | `return c.JSON(status, data)` |
| `c.Query("key")` | `c.QueryParam("key")` |
| `c.DefaultQuery("k", "v")` | Manual: `if v := c.QueryParam("k"); v != "" { } else { v = "default" }` |
| `c.PostForm("field")` | `c.FormValue("field")` |
| `c.GetHeader("X-...")` | `c.Request().Header.Get("X-...")` |
| `c.Request.Context()` | `c.Request().Context()` |
| `c.Status(code)` | `return c.NoContent(code)` |
| `gin.H{...}` | `map[string]interface{}{...}` |
## Multipart Form Handling
For handlers with file uploads (document_handler, task_handler):
```go
// Request parsing
c.Request.ParseMultipartForm(32 << 20) // Same
c.PostForm("field") c.FormValue("field")
c.FormFile("file") // Same
```
## Next Steps
1. Clean up residence_handler.go manually
2. Migrate contractor_handler.go (simpler, good template)
3. Migrate smaller files: user_handler.go, upload_handler.go, notification_handler.go
4. Migrate complex files: task_handler.go, document_handler.go
5. Migrate remaining files
6. Test compilation
7. Update route registration (if not already done in Phase 3)
## Automation Lessons Learned
- Sed-based bulk replacements are error-prone for complex Go code
- Better approach: Manual migration with copy-paste for repetitive patterns
- Python script provided in migrate_handlers.py (not yet tested)
- Best approach: Methodical manual migration with validation at each step

View File

@@ -1,36 +0,0 @@
This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
## Getting Started
First, run the development server:
```bash
npm run dev
# or
yarn dev
# or
pnpm dev
# or
bun dev
```
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
## Learn More
To learn more about Next.js, take a look at the following resources:
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
## Deploy on Vercel
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.

View File

@@ -1,38 +0,0 @@
# Digest 1: cmd/, admin/dto, admin/handlers (first 15 files)
## Systemic Issues (across all admin handlers)
- **SQL Injection via SortBy**: Every admin list handler concatenates `filters.SortBy` directly into GORM `Order()` without allowlist validation
- **Unchecked Count() errors**: Every paginated handler ignores GORM Count error returns
- **Unchecked post-mutation Preload errors**: After Save/Create, handlers reload with Preload but ignore errors
- **`binding` vs `validate` tag mismatch**: Some request DTOs use `binding` (Gin) instead of `validate` (Echo)
- **Direct DB access**: All admin handlers bypass Service layer, accessing `*gorm.DB` directly
- **Unsafe type assertions**: `c.Get(key).(*models.AdminUser)` without comma-ok
## Per-File Highlights
### cmd/api/main.go - App entry point, wires dependencies
### cmd/worker/main.go - Background worker entry point
### admin/handlers/admin_user_handler.go (347 lines)
- N+1 query: `toUserResponse` does 2 extra DB queries per user (residence count, task count)
- Line 64: SortBy SQL injection
- Line 173: Unchecked profile creation error (user created without profile)
### admin/handlers/apple_social_auth_handler.go - CRUD for Apple social auth records
- Same systemic SQL injection and unchecked errors
### admin/handlers/auth_handler.go - Admin login/session management
### admin/handlers/auth_token_handler.go - Auth token CRUD
### admin/handlers/completion_handler.go - Task completion CRUD
### admin/handlers/completion_image_handler.go - Completion image CRUD
### admin/handlers/confirmation_code_handler.go - Email confirmation code CRUD
### admin/handlers/contractor_handler.go - Contractor CRUD
### admin/handlers/dashboard_handler.go - Admin dashboard stats
### admin/handlers/device_handler.go (317 lines)
- Exposes device push tokens (RegistrationID) in API responses
- Lines 293-296: Unchecked Count errors in GetStats
### admin/handlers/document_handler.go (394 lines)
- Lines 176-183: Date parsing errors silently ignored
- Line 379: Precision loss from decimal.Float64() discarded

View File

@@ -1,51 +0,0 @@
# Digest 2: admin/handlers (remaining 15 files)
### admin/handlers/document_image_handler.go (245 lines)
- N+1: toResponse queries DB per image in List
- Same SortBy SQL injection
### admin/handlers/feature_benefit_handler.go (231 lines)
- `binding` tags instead of `validate` - required fields never enforced
### admin/handlers/limitations_handler.go (451 lines)
- Line 37: Unchecked Create error for default settings
- Line 191-197: UpdateTierLimits overwrites ALL fields even for partial updates
### admin/handlers/lookup_handler.go (877 lines)
- **CRITICAL**: Lines 30-32, 50-52, etc.: refreshXxxCache checks `if cache == nil {}` with EMPTY body, then calls cache.CacheXxx() — nil pointer panic when cache is nil
- Line 792: Hardcoded join table name "task_contractor_specialties"
### admin/handlers/notification_handler.go (419 lines)
- Line 351-363: HTML template built by string concatenation with user-provided subject/body — XSS in admin emails
### admin/handlers/notification_prefs_handler.go (347 lines)
- Line 154: Unchecked user lookup — deleted user produces zero-value username/email
### admin/handlers/onboarding_handler.go (343 lines)
- Line 304: Internal error details leaked to client
### admin/handlers/password_reset_code_handler.go (161 lines)
- **BUG**: Line 85: `code.ResetToken[:8] + "..." + code.ResetToken[len-4:]` panics if token < 8 chars
### admin/handlers/promotion_handler.go (304 lines)
- `binding` tags: required fields never enforced
### admin/handlers/residence_handler.go (371 lines)
- Lines 121-122: Unchecked Count errors for task/document counts
### admin/handlers/settings_handler.go (794 lines)
- Line 378: Raw SQL execution from seed files (no parameterization)
- Line 529-793: ClearAllData is destructive with no double-auth check
- Line 536-539: Panic in ClearAllData silently swallowed
### admin/handlers/share_code_handler.go (225 lines)
- Line 155-162: `IsActive` as non-pointer bool — absent field defaults to false, deactivating codes
### admin/handlers/subscription_handler.go (237 lines)
- **BUG**: Line 40-41: JOIN uses "users" but actual table is "auth_user" — query fails on PostgreSQL
### admin/handlers/task_handler.go (401 lines)
- Line 247-296: Admin Create bypasses service layer — no business logic applied
### admin/handlers/task_template_handler.go (347 lines)
- Lines 29-31: Same nil cache panic as lookup_handler.go

View File

@@ -1,30 +0,0 @@
# Digest 3: admin routes, apperrors, config, database, dto/requests, dto/responses (first half)
### admin/routes.go (483 lines)
- Route ordering: users DELETE "/:id" before "/bulk" — "/bulk" matches as id param
- Line 454: Uses os.Getenv instead of Viper config
- Line 460-462: url.Parse failure returns silently, no logging
- Line 467-469: Proxy errors not surfaced (always returns nil)
### apperrors/errors.go (98 lines) - Clean. Error types with Wrap/Unwrap.
### apperrors/handler.go (67 lines) - c.JSON error returns discarded (minor)
### config/config.go (427 lines)
- Line 339: Hardcoded debug secret key "change-me-in-production-secret-key-12345"
- Lines 311-312: Comments say wrong UTC times (says 8PM/9AM, actually 2PM/3PM)
### database/database.go (468 lines)
- **SECURITY**: Line 372-382: Hardcoded bcrypt hash for GoAdmin with password "admin" — migration RESETS password every run
- **SECURITY**: Line 447-463: Hardcoded admin@mycrib.com / admin123 — password reset on every migration
- Line 182+: Multiple db.Exec errors unchecked for index creation
- Line 100-102: WithTransaction coupled to global db variable
### dto/requests/auth.go (66 lines) - LoginRequest min=1 password (intentional for login)
### dto/requests/contractor.go (37 lines) - Rating has no min/max validation bounds
### dto/requests/document.go (47 lines) - ImageURLs no length limit, Description no max length
### dto/requests/residence.go (59 lines) - Bedrooms/SquareFootage accept negative values, ExpiresInHours no validation
### dto/requests/task.go (110 lines) - Rating no bounds, ImageURLs no length limit, CustomIntervalDays no min
### dto/responses/auth.go (190 lines) - Clean. Proper nil checks on Profile.
### dto/responses/contractor.go (131 lines) - TaskCount depends on preloaded Tasks association
### dto/responses/document.go (126 lines) - Line 101: CreatedBy accessed as value type (fragile if changed to pointer)

View File

@@ -1,48 +0,0 @@
# Digest 4: dto/responses (remaining), echohelpers, handlers (first half)
### dto/responses/residence.go (215 lines) - NewResidenceResponse no nil check on param. Owner zero-value if not preloaded.
### dto/responses/task_template.go (135 lines) - No nil check on template param
### dto/responses/task.go (399 lines) - No nil checks on params in factory functions
### dto/responses/user.go (20 lines) - Clean data types
### echohelpers/helpers.go (46 lines) - Clean utilities
### echohelpers/pagination.go (33 lines) - Clean, properly bounded
### handlers/auth_handler.go (379 lines)
- **ARCHITECTURE**: Lines 83, 178, 207, 241, 329, 370: SIX goroutine spawns for email — violates "no goroutines in handlers" rule
- Line 308-312: AppError constructed directly instead of factory function
### handlers/contractor_handler.go (154 lines)
- Line 28+: Unchecked type assertions throughout (7 instances)
- Line 31: Raw err.Error() returned to client
- Line 55: CreateContractor missing c.Validate() call
### handlers/document_handler.go (336 lines)
- Line 37+: Unchecked type assertions (10 instances)
- Line 92-93: Raw error leaked to client
- Line 137: No DocumentType validation — any string accepted
- Lines 187, 217: Missing c.Validate() calls
### handlers/media_handler.go (172 lines)
- **SECURITY**: Line 156-171: resolveFilePath uses filepath.Join with user-influenced data — PATH TRAVERSAL vulnerability. TrimPrefix doesn't sanitize ../
- Line 19-22: Handler accesses repositories directly, bypasses service layer
### handlers/notification_handler.go (200 lines)
- Line 29-40: No upper bound on limit — unbounded query with limit=999999999
- Line 168: Silent default to "ios" platform
### handlers/residence_handler.go (365 lines)
- Line 38+: Unchecked type assertions (14 instances)
- Lines 187, 209, 303: Bind errors silently discarded
- Line 224: JoinWithCode missing c.Validate()
### handlers/static_data_handler.go (152 lines) - Uses interface{} instead of concrete types
### handlers/subscription_handler.go (176 lines) - Lines 97, 150: Missing c.Validate() calls
- Line 159-163: RestoreSubscription doesn't validate receipt/transaction ID presence
### handlers/subscription_webhook_handler.go (821 lines)
- **SECURITY**: Line 190-192: Apple JWS payload decoded WITHOUT signature verification
- **SECURITY**: Line 787-793: VerifyGooglePubSubToken ALWAYS returns true — webhook unauthenticated
- Line 639-643: Subscription duration guessed by string matching product ID
- Line 657, 694: Hardcoded 1-month extension regardless of actual plan
- Line 759, 772: Unchecked type assertions in VerifyAppleSignature
- Line 162: Apple renewal info error silently discarded

View File

@@ -1,45 +0,0 @@
# Digest 5: handlers (remaining), i18n, middleware, models (first half)
### handlers/task_handler.go (440 lines)
- Line 35+: Unchecked type assertions (18 locations)
- Line 42: Fire-and-forget goroutine for UpdateUserTimezone — no error handling, no context
- Lines 112-115, 134-137: Missing c.Validate() calls
- Line 317: 32MB multipart limit with no per-file size check
### handlers/task_template_handler.go (98 lines)
- Line 59: No max length on search query — slow LIKE queries possible
### handlers/tracking_handler.go (46 lines)
- Line 25: Package-level base64 decode error discarded
- Lines 34-36: Fire-and-forget goroutine — violates no-goroutines rule
### handlers/upload_handler.go (93 lines)
- Line 31: User-controlled `category` param passed to storage — potential path traversal
- Line 80: `binding` tag instead of `validate`
- No file type or size validation at handler level
### handlers/user_handler.go (76 lines) - Unchecked type assertions
### i18n/i18n.go (87 lines)
- Line 16: Global Bundle is nil until Init() — NewLocalizer dereferences without nil check
- Line 37: MustParseMessageFileBytes panics on malformed translation files
- Line 83: MustT panics on missing translations
### i18n/middleware.go (127 lines) - Clean
### middleware/admin_auth.go (133 lines)
- **SECURITY**: Line 50: Admin JWT accepted via query param — tokens leak into server/proxy logs
- Line 124: Unchecked type assertion
### middleware/auth.go (229 lines)
- **BUG**: Line 66: `token[:8]` panics if token is fewer than 8 characters
- Line 104: cacheUserID error silently discarded
- Line 209: Unchecked type assertion
### middleware/logger.go (54 lines) - Clean
### middleware/request_id.go (44 lines) - Line 21: Client-supplied X-Request-ID accepted without validation (log injection)
### middleware/timezone.go (101 lines) - Lines 88, 99: Unchecked type assertions
### models/admin.go (64 lines) - Line 38: No max password length check; bcrypt truncates at 72 bytes
### models/base.go (39 lines) - Clean GORM hooks
### models/contractor.go (54 lines) - *float64 mapped to decimal(2,1) — minor precision mismatch

View File

@@ -1,40 +0,0 @@
# Digest 6: models (remaining), monitoring
### models/document.go (100 lines) - Clean
### models/notification.go (141 lines) - Clean
### models/onboarding_email.go (35 lines) - Clean
### models/reminder_log.go (92 lines) - Clean
### models/residence.go (106 lines)
- Lines 65-70: GetAllUsers/HasAccess assumes Owner and Users are preloaded — returns wrong results if not
### models/subscription.go (169 lines)
- Lines 57-65: IsActive()/IsPro() don't account for IsFree admin override field — misleading method names
### models/task_template.go (23 lines) - Clean
### models/task.go (317 lines)
- Lines 189-264: GetKanbanColumnWithTimezone duplicates categorization chain logic — maintenance drift risk
- Lines 158-182: IsDueSoon uses time.Now() internally — non-deterministic, harder to test
### models/user.go (268 lines)
- Line 101: crypto/rand.Read error unchecked (safe in practice since Go 1.20)
- Line 164-172: GenerateConfirmationCode has slight distribution bias (negligible)
### monitoring/buffer.go (166 lines) - Line 75: Corrupted Redis data silently dropped
### monitoring/collector.go (201 lines)
- Line 82: cpu.Percent blocks 1 second per collection
- Line 96-110: ReadMemStats called TWICE per cycle (also in collectRuntime)
### monitoring/handler.go (195 lines)
- **SECURITY**: Line 19-22: WebSocket CheckOrigin always returns true
- **BUG**: Line 117-119: After upgrader.Upgrade fails, execution continues to conn.Close() — nil pointer panic
- **BUG**: Line 177: Missing return after ctx.Done() — goroutine spins
- Lines 183-184: GetAllStats error silently ignored
- Line 192: WriteJSON error unchecked
### monitoring/middleware.go (220 lines) - Clean
### monitoring/models.go (129 lines) - Clean
### monitoring/service.go (196 lines)
- Line 121: Hardcoded primary key 1 for singleton settings row
- Line 191-194: MetricsMiddleware returns nil when httpCollector is nil — caller must nil-check

View File

@@ -1,57 +0,0 @@
# Digest 7: monitoring/writer, notifications, push, repositories
### monitoring/writer.go (96 lines)
- Line 90-92: Unbounded fire-and-forget goroutines for Redis push — no rate limiting
### notifications/reminder_config.go (64 lines) - Clean config data
### notifications/reminder_schedule.go (199 lines)
- Line 112: Integer truncation of float division — DST off-by-one possible
- Line 148-161: Custom itoa reimplements strconv.Itoa
### push/apns.go (209 lines)
- Line 44: Double-negative logic — both Production=false and Sandbox=false defaults to production
### push/client.go (158 lines)
- Line 89-105: SendToAll last-error-wins — cannot tell which platform failed
- Line 150-157: HealthCheck always returns nil — useless health check
### push/fcm.go (140 lines)
- Line 16: Legacy FCM HTTP API (deprecated by Google)
- Line 119-126: If FCM returns fewer results than tokens, index out of bounds panic
### repositories/admin_repo.go (108 lines)
- Line 92: Negative page produces negative offset
### repositories/contractor_repo.go (166 lines)
- **RACE**: Line 89-101: ToggleFavorite read-then-write without transaction
- Line 91: ToggleFavorite doesn't filter is_active — can toggle deleted contractors
### repositories/document_repo.go (201 lines)
- Line 92: LIKE wildcards in user input not escaped
- Line 12: DocumentFilter.ResidenceID field defined but never used
### repositories/notification_repo.go (267 lines)
- **RACE**: Line 137-161: GetOrCreatePreferences race — concurrent calls both create, duplicate key error
- Line 143: Uses == instead of errors.Is for ErrRecordNotFound
### repositories/reminder_repo.go (126 lines)
- Line 115-122: rows.Err() not checked after iteration loop
### repositories/residence_repo.go (344 lines)
- Line 272: DeactivateShareCode error silently ignored
- Line 298-301: Count error unchecked in generateUniqueCode — potential duplicate codes
- Line 125-128: PostgreSQL-specific ON CONFLICT — fails on SQLite in tests
### repositories/subscription_repo.go (257 lines)
- **RACE**: Line 40: GetOrCreate race condition (same as notification_repo)
- Line 66: GORM v1 pattern `gorm:query_option` for FOR UPDATE — may not work in GORM v2
- Line 129: LIKE search on receipt data blobs — inefficient, no index
- Lines 40, 168, 196: Uses == instead of errors.Is
### repositories/task_repo.go (765 lines)
- Line 707-709: DeleteCompletion ignores image deletion error
- Line 62-101: applyFilterOptions applies NO scope when no filter set — could query all tasks
### repositories/task_template_repo.go (124 lines)
- Line 48: LIKE wildcard escape issue
- Line 79-81: Save without Omit could corrupt Category/Frequency lookup data

View File

@@ -1,35 +0,0 @@
# Digest 8: repositories (remaining), router, services (first half)
### repositories/user_repo.go - Standard GORM CRUD
### repositories/webhook_event_repo.go - Webhook event storage
### router/router.go - Route registration wiring
### services/apple_auth.go - Apple Sign In JWT validation
### services/auth_service.go - Token management, password hashing, email verification
### services/cache_service.go - Redis caching for lookups
### services/contractor_service.go - Contractor CRUD via repository
### services/document_service.go - Document management
### services/email_service.go - SMTP email sending
### services/google_auth.go - Google OAuth token validation
### services/iap_validation.go - Apple/Google receipt validation
### services/notification_service.go - Push notifications, preferences
### services/onboarding_email_service.go (371 lines)
- **ARCHITECTURE**: Direct *gorm.DB access — bypasses repository layer entirely
- Line 43-46: HasSentEmail ignores Count error — could send duplicate emails
- Line 128-133: GetEmailStats ignores 4 Count errors
- Line 170: Raw SQL references "auth_user" table
- Line 354: Delete error silently ignored
### services/pdf_service.go (179 lines)
- **BUG**: Line 131-133: Byte-level truncation of title — breaks multi-byte UTF-8 (CJK, emoji)
### services/residence_service.go (648 lines)
- Line 155: TODO comment — subscription tier limit check commented out (free tier bypass)
- Line 447-450: Empty if block — DeactivateShareCode error completely ignored
- Line 625: Status only set for in-progress tasks; all others have empty string

View File

@@ -1,49 +0,0 @@
# Digest 9: services (remaining), task package, testutil, validator, worker, pkg
### services/storage_service.go (184 lines)
- Line 75: UUID truncated to 8 chars — increased collision risk
- **SECURITY**: Line 137-138: filepath.Abs errors ignored — path traversal check could be bypassed
### services/subscription_service.go (659 lines)
- **PERFORMANCE**: Line 186-204: N+1 queries in getUserUsage — 3 queries per residence
- **SECURITY/BUSINESS**: Line 371: Apple validation failure grants 1-month free Pro
- **SECURITY/BUSINESS**: Line 381: Apple validation not configured grants 1-year free Pro
- **SECURITY/BUSINESS**: Line 429, 449: Same for Google — errors/misconfiguration grant free Pro
- Line 7: Uses stdlib "log" instead of zerolog
### services/task_button_types.go (85 lines) - Clean, uses predicates correctly
### services/task_service.go (1092 lines)
- **DATA INTEGRITY**: Line 601: If task update fails after completion creation, error only logged not returned — stale NextDueDate/InProgress
- Line 735: Goroutine in QuickComplete (service method) — inconsistent with synchronous CreateCompletion
- Line 773: Unbounded goroutine creation per user for notifications
- Line 790: Fail-open email notification on error (intentional but risky)
- **SECURITY**: Line 857-862: resolveImageFilePath has NO path traversal validation
### services/task_template_service.go (70 lines) - Errors returned raw (not wrapped with apperrors)
### services/user_service.go (88 lines) - Returns nil instead of empty slice (JSON null vs [])
### task/categorization/chain.go (359 lines) - Clean chain-of-responsibility
### task/predicates/predicates.go (217 lines)
- Line 210: IsRecurring requires Frequency preloaded — returns false without it
### task/scopes/scopes.go (270 lines)
- Line 118: ScopeOverdue doesn't exclude InProgress — differs from categorization chain
### task/task.go (261 lines) - Clean facade re-exports
### testutil/testutil.go (359 lines)
- Line 86: json.Marshal error ignored in MakeRequest
- Line 92: http.NewRequest error ignored
### validator/validator.go (103 lines) - Clean
### worker/jobs/email_jobs.go (118 lines) - Clean
### worker/jobs/handler.go (810 lines)
- Lines 95-106, 193-204: Direct DB access bypasses repository layer
- Line 627-635: Raw SQL with fmt.Sprintf (not currently user-supplied but fragile)
- Line 154, 251: O(N*M) lookup instead of map
### worker/scheduler.go (240 lines)
- Line 200-212: Cron schedules at fixed UTC times may conflict with smart reminder system — potential duplicate notifications
### pkg/utils/logger.go (132 lines) - Panic recovery bypasses apperrors.HTTPErrorHandler

View File

@@ -1,417 +0,0 @@
# Gin to Echo Framework Migration Guide
This document outlines the migration of the MyCrib Go API from Gin to Echo v4 with direct go-playground/validator integration.
## Overview
| Aspect | Before | After |
|--------|--------|-------|
| Framework | Gin v1.10 | Echo v4.11 |
| Validation | Gin's binding wrapper | Direct go-playground/validator |
| Validation tags | `binding:"..."` | `validate:"..."` |
| Error format | Inconsistent | Structured field-level |
## Scope
- **56 files** requiring modification
- **110+ routes** across public and admin APIs
- **45 handlers** (17 core + 28 admin)
- **5 middleware** files
---
## API Mapping Reference
### Context Methods
| Gin | Echo | Notes |
|-----|------|-------|
| `c.ShouldBindJSON(&req)` | `c.Bind(&req)` | Bind only, no validation |
| `c.ShouldBindJSON(&req)` | `c.Validate(&req)` | Validation only (call after Bind) |
| `c.Param("id")` | `c.Param("id")` | Same |
| `c.Query("name")` | `c.QueryParam("name")` | Different method name |
| `c.DefaultQuery("k","v")` | Custom helper | No built-in equivalent |
| `c.PostForm("field")` | `c.FormValue("field")` | Different method name |
| `c.FormFile("file")` | `c.FormFile("file")` | Same |
| `c.Get(key)` | `c.Get(key)` | Same |
| `c.Set(key, val)` | `c.Set(key, val)` | Same |
| `c.MustGet(key)` | `c.Get(key)` | No MustGet, add nil check |
| `c.JSON(status, obj)` | `return c.JSON(status, obj)` | Must return |
| `c.AbortWithStatusJSON()` | `return c.JSON()` | Return-based flow |
| `c.Status(200)` | `return c.NoContent(200)` | Different method |
| `c.GetHeader("X-...")` | `c.Request().Header.Get()` | Access via Request |
| `c.ClientIP()` | `c.RealIP()` | Different method name |
| `c.File(path)` | `return c.File(path)` | Must return |
| `gin.H{...}` | `echo.Map{...}` | Or `map[string]any{}` |
### Handler Signature
```go
// Gin - void return
func (h *Handler) Method(c *gin.Context) {
// ...
c.JSON(200, response)
}
// Echo - error return (MUST return)
func (h *Handler) Method(c echo.Context) error {
// ...
return c.JSON(200, response)
}
```
### Middleware Signature
```go
// Gin
func MyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// before
c.Next()
// after
}
}
// Echo
func MyMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// before
err := next(c)
// after
return err
}
}
}
```
---
## Validation Changes
### Tag Migration
Change all `binding:` tags to `validate:` tags:
```go
// Before
type LoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=8"`
}
// After
type LoginRequest struct {
Email string `json:"email" validate:"required,email"`
Password string `json:"password" validate:"required,min=8"`
}
```
### Supported Validation Tags
| Tag | Description |
|-----|-------------|
| `required` | Field must be present and non-zero |
| `required_without=Field` | Required if other field is empty |
| `omitempty` | Skip validation if empty |
| `email` | Must be valid email format |
| `min=N` | Minimum length/value |
| `max=N` | Maximum length/value |
| `len=N` | Exact length |
| `oneof=a b c` | Must be one of listed values |
| `url` | Must be valid URL |
| `uuid` | Must be valid UUID |
### New Error Response Format
```json
{
"error": "Validation failed",
"fields": {
"email": {
"message": "Must be a valid email address",
"tag": "email"
},
"password": {
"message": "Must be at least 8 characters",
"tag": "min"
}
}
}
```
**Mobile clients must update** error parsing to handle the `fields` object.
---
## Handler Migration Pattern
### Before (Gin)
```go
func (h *AuthHandler) Login(c *gin.Context) {
var req requests.LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, responses.ErrorResponse{
Error: "Invalid request body",
})
return
}
user := c.MustGet(middleware.AuthUserKey).(*models.User)
response, err := h.authService.Login(&req)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, response)
}
```
### After (Echo)
```go
func (h *AuthHandler) Login(c echo.Context) error {
var req requests.LoginRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, responses.ErrorResponse{
Error: "Invalid request body",
})
}
if err := c.Validate(&req); err != nil {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
user := c.Get(middleware.AuthUserKey).(*models.User)
response, err := h.authService.Login(&req)
if err != nil {
return c.JSON(http.StatusUnauthorized, echo.Map{"error": err.Error()})
}
return c.JSON(http.StatusOK, response)
}
```
---
## Middleware Migration Examples
### Auth Middleware
```go
// Before (Gin)
func (m *AuthMiddleware) TokenAuth() gin.HandlerFunc {
return func(c *gin.Context) {
token, err := extractToken(c)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
user, err := m.validateToken(token)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
return
}
c.Set(AuthUserKey, user)
c.Next()
}
}
// After (Echo)
func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
token, err := extractToken(c)
if err != nil {
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
}
user, err := m.validateToken(token)
if err != nil {
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Invalid token"})
}
c.Set(AuthUserKey, user)
return next(c)
}
}
}
```
### Timezone Middleware
```go
// Before (Gin)
func TimezoneMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
tzName := c.GetHeader(TimezoneHeader)
loc := parseTimezone(tzName)
c.Set(TimezoneKey, loc)
c.Next()
}
}
// After (Echo)
func TimezoneMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
tzName := c.Request().Header.Get(TimezoneHeader)
loc := parseTimezone(tzName)
c.Set(TimezoneKey, loc)
return next(c)
}
}
}
```
---
## Router Setup
### Before (Gin)
```go
func SetupRouter(deps *Dependencies) *gin.Engine {
r := gin.New()
r.Use(gin.Recovery())
r.Use(gin.Logger())
r.Use(cors.New(cors.Config{...}))
api := r.Group("/api")
api.POST("/auth/login/", authHandler.Login)
protected := api.Group("")
protected.Use(authMiddleware.TokenAuth())
protected.GET("/residences/", residenceHandler.List)
return r
}
```
### After (Echo)
```go
func SetupRouter(deps *Dependencies) *echo.Echo {
e := echo.New()
e.HideBanner = true
e.Validator = validator.NewCustomValidator()
// Trailing slash handling
e.Pre(middleware.AddTrailingSlash())
e.Use(middleware.Recover())
e.Use(middleware.Logger())
e.Use(middleware.CORSWithConfig(middleware.CORSConfig{...}))
api := e.Group("/api")
api.POST("/auth/login/", authHandler.Login)
protected := api.Group("")
protected.Use(authMiddleware.TokenAuth())
protected.GET("/residences/", residenceHandler.List)
return e
}
```
---
## Helper Functions
### DefaultQuery Helper
```go
// internal/echohelpers/helpers.go
func DefaultQuery(c echo.Context, key, defaultValue string) string {
if val := c.QueryParam(key); val != "" {
return val
}
return defaultValue
}
```
### Safe Context Get
```go
// internal/middleware/helpers.go
func GetAuthUser(c echo.Context) *models.User {
val := c.Get(AuthUserKey)
if val == nil {
return nil
}
user, ok := val.(*models.User)
if !ok {
return nil
}
return user
}
func MustGetAuthUser(c echo.Context) (*models.User, error) {
user := GetAuthUser(c)
if user == nil {
return nil, echo.NewHTTPError(http.StatusUnauthorized, "Authentication required")
}
return user, nil
}
```
---
## Testing Changes
### Test Setup
```go
// Before (Gin)
func SetupTestRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
return gin.New()
}
// After (Echo)
func SetupTestRouter() *echo.Echo {
e := echo.New()
e.Validator = validator.NewCustomValidator()
return e
}
```
### Making Test Requests
```go
// Before (Gin)
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/auth/login/", body)
router.ServeHTTP(w, req)
// After (Echo) - Same pattern works
rec := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/auth/login/", body)
e.ServeHTTP(rec, req)
```
---
## Important Notes
1. **All handlers must return error** - Echo uses return-based flow
2. **Trailing slashes** - Use `middleware.AddTrailingSlash()` to maintain API compatibility
3. **Type assertions** - Always add nil checks when using `c.Get()`
4. **CORS** - Use Echo's built-in CORS middleware
5. **Bind vs Validate** - Echo separates these; call both for full validation
---
## Files Modified
| Category | Count |
|----------|-------|
| New files | 2 (`validator/`, `echohelpers/`) |
| DTOs | 6 |
| Middleware | 5 |
| Core handlers | 14 |
| Admin handlers | 28 |
| Router | 2 |
| Tests | 6 |
| **Total** | **63** |

View File

@@ -2350,6 +2350,121 @@ paths:
'401': '401':
$ref: '#/components/responses/Unauthorized' $ref: '#/components/responses/Unauthorized'
/subscription/checkout/:
post:
tags: [Subscriptions]
operationId: createCheckoutSession
summary: Create a Stripe Checkout session for web subscription purchase
security:
- tokenAuth: []
requestBody:
required: true
content:
application/json:
schema:
type: object
required: [price_id, success_url, cancel_url]
properties:
price_id:
type: string
description: Stripe Price ID
success_url:
type: string
format: uri
cancel_url:
type: string
format: uri
responses:
'200':
description: Checkout session created
content:
application/json:
schema:
type: object
properties:
checkout_url:
type: string
format: uri
'400':
$ref: '#/components/responses/Error'
'401':
$ref: '#/components/responses/Unauthorized'
'409':
description: Already subscribed on another platform
content:
application/json:
schema:
type: object
properties:
error:
type: string
existing_platform:
type: string
message:
type: string
/subscription/portal/:
post:
tags: [Subscriptions]
operationId: createPortalSession
summary: Create a Stripe Customer Portal session for managing web subscriptions
security:
- tokenAuth: []
requestBody:
required: true
content:
application/json:
schema:
type: object
required: [return_url]
properties:
return_url:
type: string
format: uri
responses:
'200':
description: Portal session created
content:
application/json:
schema:
type: object
properties:
portal_url:
type: string
format: uri
'400':
$ref: '#/components/responses/Error'
'401':
$ref: '#/components/responses/Unauthorized'
/subscription/webhook/stripe/:
post:
tags: [Subscriptions]
operationId: handleStripeWebhook
summary: Handle Stripe webhook events (server-to-server)
description: |
Receives Stripe webhook events for subscription lifecycle management.
Verifies the webhook signature using the configured signing secret.
No auth token required — uses Stripe signature verification.
requestBody:
required: true
content:
application/json:
schema:
type: object
responses:
'200':
description: Webhook processed successfully
content:
application/json:
schema:
type: object
properties:
received:
type: boolean
'400':
$ref: '#/components/responses/Error'
# =========================================================================== # ===========================================================================
# Uploads # Uploads
# =========================================================================== # ===========================================================================
@@ -4434,6 +4549,12 @@ components:
SubscriptionStatusResponse: SubscriptionStatusResponse:
type: object type: object
properties: properties:
tier:
type: string
description: 'Subscription tier (free or pro)'
is_active:
type: boolean
description: Whether the subscription is currently active
subscribed_at: subscribed_at:
type: string type: string
format: date-time format: date-time
@@ -4444,6 +4565,20 @@ components:
nullable: true nullable: true
auto_renew: auto_renew:
type: boolean type: boolean
trial_start:
type: string
format: date-time
nullable: true
trial_end:
type: string
format: date-time
nullable: true
trial_active:
type: boolean
subscription_source:
type: string
nullable: true
description: 'Platform source of the active subscription (ios, android, stripe, or null)'
usage: usage:
$ref: '#/components/schemas/UsageResponse' $ref: '#/components/schemas/UsageResponse'
limits: limits:

1
go.mod
View File

@@ -71,6 +71,7 @@ require (
github.com/spf13/afero v1.14.0 // indirect github.com/spf13/afero v1.14.0 // indirect
github.com/spf13/cast v1.10.0 // indirect github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect github.com/spf13/pflag v1.0.10 // indirect
github.com/stripe/stripe-go/v81 v81.4.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect github.com/tklauser/numcpus v0.6.1 // indirect

7
go.sum
View File

@@ -156,6 +156,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
@@ -188,6 +190,7 @@ golang.org/x/crypto v0.0.0-20170512130425-ab89591268e0/go.mod h1:6SG95UA2DQfeDnf
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
@@ -196,7 +199,9 @@ golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwE
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -206,8 +211,10 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=

View File

@@ -1,37 +0,0 @@
# Go Backend Hardening Audit Report
## Audit Sources
- 9 mapper agents (100% file coverage)
- 8 specialized domain auditors (parallel)
- 1 cross-cutting deep audit (parallel)
- Total source files: 136 (excluding 27 test files)
---
## CRITICAL — Will crash or lose data
## BUG — Incorrect behavior
## SILENT FAILURE — Error swallowed or ignored
## RACE CONDITION — Concurrency issue
## LOGIC ERROR — Code doesn't match intent
## PERFORMANCE — Unnecessary cost
## SECURITY — Vulnerability or exposure
## AUTHORIZATION — Access control gap
## DATA INTEGRITY — GORM / database issue
## API CONTRACT — Request/response issue
## ARCHITECTURE — Layer or pattern violation
## FRAGILE — Works now but will break easily
---
## Summary

View File

@@ -234,11 +234,13 @@ type UpdateNotificationRequest struct {
// SubscriptionFilters holds subscription-specific filter parameters // SubscriptionFilters holds subscription-specific filter parameters
type SubscriptionFilters struct { type SubscriptionFilters struct {
PaginationParams PaginationParams
UserID *uint `form:"user_id"` UserID *uint `form:"user_id"`
Tier *string `form:"tier"` Tier *string `form:"tier"`
Platform *string `form:"platform"` Platform *string `form:"platform"`
AutoRenew *bool `form:"auto_renew"` AutoRenew *bool `form:"auto_renew"`
Active *bool `form:"active"` Active *bool `form:"active"`
HasStripe *bool `form:"has_stripe"`
TrialActive *bool `form:"trial_active"`
} }
// UpdateSubscriptionRequest for updating a subscription // UpdateSubscriptionRequest for updating a subscription
@@ -250,6 +252,14 @@ type UpdateSubscriptionRequest struct {
SubscribedAt *string `json:"subscribed_at"` SubscribedAt *string `json:"subscribed_at"`
ExpiresAt *string `json:"expires_at"` ExpiresAt *string `json:"expires_at"`
CancelledAt *string `json:"cancelled_at"` CancelledAt *string `json:"cancelled_at"`
// Stripe fields
StripeCustomerID *string `json:"stripe_customer_id"`
StripeSubscriptionID *string `json:"stripe_subscription_id"`
StripePriceID *string `json:"stripe_price_id"`
// Trial fields
TrialStart *string `json:"trial_start"`
TrialEnd *string `json:"trial_end"`
TrialUsed *bool `json:"trial_used"`
} }
// CreateResidenceRequest for creating a new residence // CreateResidenceRequest for creating a new residence

View File

@@ -264,7 +264,16 @@ type SubscriptionResponse struct {
SubscribedAt *string `json:"subscribed_at,omitempty"` SubscribedAt *string `json:"subscribed_at,omitempty"`
ExpiresAt *string `json:"expires_at,omitempty"` ExpiresAt *string `json:"expires_at,omitempty"`
CancelledAt *string `json:"cancelled_at,omitempty"` CancelledAt *string `json:"cancelled_at,omitempty"`
CreatedAt string `json:"created_at"` // Stripe fields
StripeCustomerID *string `json:"stripe_customer_id,omitempty"`
StripeSubscriptionID *string `json:"stripe_subscription_id,omitempty"`
StripePriceID *string `json:"stripe_price_id,omitempty"`
// Trial fields
TrialStart *string `json:"trial_start,omitempty"`
TrialEnd *string `json:"trial_end,omitempty"`
TrialUsed bool `json:"trial_used"`
TrialActive bool `json:"trial_active"`
CreatedAt string `json:"created_at"`
} }
// SubscriptionDetailResponse includes more details for single subscription view // SubscriptionDetailResponse includes more details for single subscription view

View File

@@ -30,6 +30,8 @@ func NewAdminSettingsHandler(db *gorm.DB) *AdminSettingsHandler {
type SettingsResponse struct { type SettingsResponse struct {
EnableLimitations bool `json:"enable_limitations"` EnableLimitations bool `json:"enable_limitations"`
EnableMonitoring bool `json:"enable_monitoring"` EnableMonitoring bool `json:"enable_monitoring"`
TrialEnabled bool `json:"trial_enabled"`
TrialDurationDays int `json:"trial_duration_days"`
} }
// GetSettings handles GET /api/admin/settings // GetSettings handles GET /api/admin/settings
@@ -38,7 +40,13 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
if err := h.db.First(&settings, 1).Error; err != nil { if err := h.db.First(&settings, 1).Error; err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
// Create default settings // Create default settings
settings = models.SubscriptionSettings{ID: 1, EnableLimitations: false, EnableMonitoring: true} settings = models.SubscriptionSettings{
ID: 1,
EnableLimitations: false,
EnableMonitoring: true,
TrialEnabled: true,
TrialDurationDays: 14,
}
h.db.Create(&settings) h.db.Create(&settings)
} else { } else {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"}) return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
@@ -48,6 +56,8 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
return c.JSON(http.StatusOK, SettingsResponse{ return c.JSON(http.StatusOK, SettingsResponse{
EnableLimitations: settings.EnableLimitations, EnableLimitations: settings.EnableLimitations,
EnableMonitoring: settings.EnableMonitoring, EnableMonitoring: settings.EnableMonitoring,
TrialEnabled: settings.TrialEnabled,
TrialDurationDays: settings.TrialDurationDays,
}) })
} }
@@ -55,6 +65,8 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
type UpdateSettingsRequest struct { type UpdateSettingsRequest struct {
EnableLimitations *bool `json:"enable_limitations"` EnableLimitations *bool `json:"enable_limitations"`
EnableMonitoring *bool `json:"enable_monitoring"` EnableMonitoring *bool `json:"enable_monitoring"`
TrialEnabled *bool `json:"trial_enabled"`
TrialDurationDays *int `json:"trial_duration_days"`
} }
// UpdateSettings handles PUT /api/admin/settings // UpdateSettings handles PUT /api/admin/settings
@@ -67,7 +79,12 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
var settings models.SubscriptionSettings var settings models.SubscriptionSettings
if err := h.db.First(&settings, 1).Error; err != nil { if err := h.db.First(&settings, 1).Error; err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
settings = models.SubscriptionSettings{ID: 1, EnableMonitoring: true} settings = models.SubscriptionSettings{
ID: 1,
EnableMonitoring: true,
TrialEnabled: true,
TrialDurationDays: 14,
}
} else { } else {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"}) return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
} }
@@ -81,6 +98,14 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
settings.EnableMonitoring = *req.EnableMonitoring settings.EnableMonitoring = *req.EnableMonitoring
} }
if req.TrialEnabled != nil {
settings.TrialEnabled = *req.TrialEnabled
}
if req.TrialDurationDays != nil {
settings.TrialDurationDays = *req.TrialDurationDays
}
if err := h.db.Save(&settings).Error; err != nil { if err := h.db.Save(&settings).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update settings"}) return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update settings"})
} }
@@ -88,6 +113,8 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
return c.JSON(http.StatusOK, SettingsResponse{ return c.JSON(http.StatusOK, SettingsResponse{
EnableLimitations: settings.EnableLimitations, EnableLimitations: settings.EnableLimitations,
EnableMonitoring: settings.EnableMonitoring, EnableMonitoring: settings.EnableMonitoring,
TrialEnabled: settings.TrialEnabled,
TrialDurationDays: settings.TrialDurationDays,
}) })
} }

View File

@@ -3,6 +3,7 @@ package handlers
import ( import (
"net/http" "net/http"
"strconv" "strconv"
"time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"gorm.io/gorm" "gorm.io/gorm"
@@ -61,6 +62,20 @@ func (h *AdminSubscriptionHandler) List(c echo.Context) error {
query = query.Where("expires_at IS NOT NULL AND expires_at <= NOW()") query = query.Where("expires_at IS NOT NULL AND expires_at <= NOW()")
} }
} }
if filters.HasStripe != nil {
if *filters.HasStripe {
query = query.Where("stripe_subscription_id IS NOT NULL")
} else {
query = query.Where("stripe_subscription_id IS NULL")
}
}
if filters.TrialActive != nil {
if *filters.TrialActive {
query = query.Where("trial_end IS NOT NULL AND trial_end > NOW()")
} else {
query = query.Where("trial_end IS NULL OR trial_end <= NOW()")
}
}
// Get total count // Get total count
query.Count(&total) query.Count(&total)
@@ -137,6 +152,32 @@ func (h *AdminSubscriptionHandler) Update(c echo.Context) error {
if req.IsFree != nil { if req.IsFree != nil {
subscription.IsFree = *req.IsFree subscription.IsFree = *req.IsFree
} }
if req.StripeCustomerID != nil {
subscription.StripeCustomerID = req.StripeCustomerID
}
if req.StripeSubscriptionID != nil {
subscription.StripeSubscriptionID = req.StripeSubscriptionID
}
if req.StripePriceID != nil {
subscription.StripePriceID = req.StripePriceID
}
if req.TrialStart != nil {
parsed, err := time.Parse(time.RFC3339, *req.TrialStart)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid trial_start format, expected RFC3339"})
}
subscription.TrialStart = &parsed
}
if req.TrialEnd != nil {
parsed, err := time.Parse(time.RFC3339, *req.TrialEnd)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid trial_end format, expected RFC3339"})
}
subscription.TrialEnd = &parsed
}
if req.TrialUsed != nil {
subscription.TrialUsed = *req.TrialUsed
}
if err := h.db.Save(&subscription).Error; err != nil { if err := h.db.Save(&subscription).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update subscription"}) return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update subscription"})
@@ -184,29 +225,39 @@ func (h *AdminSubscriptionHandler) GetByUser(c echo.Context) error {
// GetStats handles GET /api/admin/subscriptions/stats // GetStats handles GET /api/admin/subscriptions/stats
func (h *AdminSubscriptionHandler) GetStats(c echo.Context) error { func (h *AdminSubscriptionHandler) GetStats(c echo.Context) error {
var total, free, premium, pro int64 var total, free, premium, pro int64
var stripeSubscribers, activeTrials int64
h.db.Model(&models.UserSubscription{}).Count(&total) h.db.Model(&models.UserSubscription{}).Count(&total)
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "free").Count(&free) h.db.Model(&models.UserSubscription{}).Where("tier = ?", "free").Count(&free)
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "premium").Count(&premium) h.db.Model(&models.UserSubscription{}).Where("tier = ?", "premium").Count(&premium)
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "pro").Count(&pro) h.db.Model(&models.UserSubscription{}).Where("tier = ?", "pro").Count(&pro)
h.db.Model(&models.UserSubscription{}).Where("stripe_subscription_id IS NOT NULL AND tier = ?", "pro").Count(&stripeSubscribers)
h.db.Model(&models.UserSubscription{}).Where("trial_end IS NOT NULL AND trial_end > NOW()").Count(&activeTrials)
return c.JSON(http.StatusOK, map[string]interface{}{ return c.JSON(http.StatusOK, map[string]interface{}{
"total": total, "total": total,
"free": free, "free": free,
"premium": premium, "premium": premium,
"pro": pro, "pro": pro,
"stripe_subscribers": stripeSubscribers,
"active_trials": activeTrials,
}) })
} }
func (h *AdminSubscriptionHandler) toSubscriptionResponse(sub *models.UserSubscription) dto.SubscriptionResponse { func (h *AdminSubscriptionHandler) toSubscriptionResponse(sub *models.UserSubscription) dto.SubscriptionResponse {
response := dto.SubscriptionResponse{ response := dto.SubscriptionResponse{
ID: sub.ID, ID: sub.ID,
UserID: sub.UserID, UserID: sub.UserID,
Tier: string(sub.Tier), Tier: string(sub.Tier),
Platform: sub.Platform, Platform: sub.Platform,
AutoRenew: sub.AutoRenew, AutoRenew: sub.AutoRenew,
IsFree: sub.IsFree, IsFree: sub.IsFree,
CreatedAt: sub.CreatedAt.Format("2006-01-02T15:04:05Z"), StripeCustomerID: sub.StripeCustomerID,
StripeSubscriptionID: sub.StripeSubscriptionID,
StripePriceID: sub.StripePriceID,
TrialUsed: sub.TrialUsed,
TrialActive: sub.IsTrialActive(),
CreatedAt: sub.CreatedAt.Format("2006-01-02T15:04:05Z"),
} }
if sub.User.ID != 0 { if sub.User.ID != 0 {
@@ -225,6 +276,14 @@ func (h *AdminSubscriptionHandler) toSubscriptionResponse(sub *models.UserSubscr
cancelledAt := sub.CancelledAt.Format("2006-01-02T15:04:05Z") cancelledAt := sub.CancelledAt.Format("2006-01-02T15:04:05Z")
response.CancelledAt = &cancelledAt response.CancelledAt = &cancelledAt
} }
if sub.TrialStart != nil {
trialStart := sub.TrialStart.Format(time.RFC3339)
response.TrialStart = &trialStart
}
if sub.TrialEnd != nil {
trialEnd := sub.TrialEnd.Format(time.RFC3339)
response.TrialEnd = &trialEnd
}
return response return response
} }

View File

@@ -25,6 +25,7 @@ type Config struct {
GoogleAuth GoogleAuthConfig GoogleAuth GoogleAuthConfig
AppleIAP AppleIAPConfig AppleIAP AppleIAPConfig
GoogleIAP GoogleIAPConfig GoogleIAP GoogleIAPConfig
Stripe StripeConfig
Features FeatureFlags Features FeatureFlags
} }
@@ -104,6 +105,14 @@ type GoogleIAPConfig struct {
PackageName string // Android package name (e.g., com.tt.casera) PackageName string // Android package name (e.g., com.tt.casera)
} }
// StripeConfig holds Stripe payment configuration
type StripeConfig struct {
SecretKey string // Stripe secret API key
WebhookSecret string // Stripe webhook endpoint signing secret
PriceMonthly string // Stripe Price ID for monthly Pro subscription
PriceYearly string // Stripe Price ID for yearly Pro subscription
}
type WorkerConfig struct { type WorkerConfig struct {
// Scheduled job times (UTC) // Scheduled job times (UTC)
TaskReminderHour int TaskReminderHour int
@@ -248,6 +257,12 @@ func Load() (*Config, error) {
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"), ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"), PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
}, },
Stripe: StripeConfig{
SecretKey: viper.GetString("STRIPE_SECRET_KEY"),
WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"),
PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"),
PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"),
},
Features: FeatureFlags{ Features: FeatureFlags{
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"), PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"), EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),

View File

@@ -240,7 +240,7 @@ func TestSubscriptionHandler_NoAuth_Returns401(t *testing.T) {
contractorRepo := repositories.NewContractorRepository(db) contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db) documentRepo := repositories.NewDocumentRepository(db)
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo) subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
handler := NewSubscriptionHandler(subscriptionService) handler := NewSubscriptionHandler(subscriptionService, nil)
e := testutil.SetupTestRouter() e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware // Register routes WITHOUT auth middleware

View File

@@ -13,11 +13,15 @@ import (
// SubscriptionHandler handles subscription-related HTTP requests // SubscriptionHandler handles subscription-related HTTP requests
type SubscriptionHandler struct { type SubscriptionHandler struct {
subscriptionService *services.SubscriptionService subscriptionService *services.SubscriptionService
stripeService *services.StripeService
} }
// NewSubscriptionHandler creates a new subscription handler // NewSubscriptionHandler creates a new subscription handler
func NewSubscriptionHandler(subscriptionService *services.SubscriptionService) *SubscriptionHandler { func NewSubscriptionHandler(subscriptionService *services.SubscriptionService, stripeService *services.StripeService) *SubscriptionHandler {
return &SubscriptionHandler{subscriptionService: subscriptionService} return &SubscriptionHandler{
subscriptionService: subscriptionService,
stripeService: stripeService,
}
} }
// GetSubscription handles GET /api/subscription/ // GetSubscription handles GET /api/subscription/
@@ -194,3 +198,82 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
"subscription": subscription, "subscription": subscription,
}) })
} }
// CreateCheckoutSession handles POST /api/subscription/checkout/
// Creates a Stripe Checkout Session for web subscription purchases
func (h *SubscriptionHandler) CreateCheckoutSession(c echo.Context) error {
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
if h.stripeService == nil {
return apperrors.BadRequest("error.stripe_not_configured")
}
// Check if already Pro from another platform
alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(user.ID, "stripe")
if err != nil {
return err
}
if alreadyPro {
return c.JSON(http.StatusConflict, map[string]interface{}{
"error": "error.already_subscribed_other_platform",
"existing_platform": existingPlatform,
"message": "You already have an active Pro subscription via " + existingPlatform + ". Manage it there to avoid double billing.",
})
}
var req struct {
PriceID string `json:"price_id" validate:"required"`
SuccessURL string `json:"success_url" validate:"required,url"`
CancelURL string `json:"cancel_url" validate:"required,url"`
}
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
sessionURL, err := h.stripeService.CreateCheckoutSession(user.ID, req.PriceID, req.SuccessURL, req.CancelURL)
if err != nil {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{
"checkout_url": sessionURL,
})
}
// CreatePortalSession handles POST /api/subscription/portal/
// Creates a Stripe Customer Portal session for managing web subscriptions
func (h *SubscriptionHandler) CreatePortalSession(c echo.Context) error {
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
if h.stripeService == nil {
return apperrors.BadRequest("error.stripe_not_configured")
}
var req struct {
ReturnURL string `json:"return_url" validate:"required,url"`
}
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
portalURL, err := h.stripeService.CreatePortalSession(user.ID, req.ReturnURL)
if err != nil {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{
"portal_url": portalURL,
})
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/treytartt/casera-api/internal/config" "github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/models" "github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories" "github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
) )
// SubscriptionWebhookHandler handles subscription webhook callbacks // SubscriptionWebhookHandler handles subscription webhook callbacks
@@ -28,6 +29,7 @@ type SubscriptionWebhookHandler struct {
userRepo *repositories.UserRepository userRepo *repositories.UserRepository
webhookEventRepo *repositories.WebhookEventRepository webhookEventRepo *repositories.WebhookEventRepository
appleRootCerts []*x509.Certificate appleRootCerts []*x509.Certificate
stripeService *services.StripeService
enabled bool enabled bool
} }
@@ -46,6 +48,11 @@ func NewSubscriptionWebhookHandler(
} }
} }
// SetStripeService sets the Stripe service for webhook handling
func (h *SubscriptionWebhookHandler) SetStripeService(stripeService *services.StripeService) {
h.stripeService = stripeService
}
// ==================== // ====================
// Apple App Store Server Notifications v2 // Apple App Store Server Notifications v2
// ==================== // ====================
@@ -377,38 +384,30 @@ func (h *SubscriptionWebhookHandler) handleAppleFailedToRenew(userID uint, tx *A
} }
func (h *SubscriptionWebhookHandler) handleAppleExpired(userID uint, tx *AppleTransactionInfo) error { func (h *SubscriptionWebhookHandler) handleAppleExpired(userID uint, tx *AppleTransactionInfo) error {
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil { if err := h.safeDowngradeToFree(userID, "Apple expired"); err != nil {
return err return err
} }
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription expired, downgraded to free")
return nil return nil
} }
func (h *SubscriptionWebhookHandler) handleAppleRefund(userID uint, tx *AppleTransactionInfo) error { func (h *SubscriptionWebhookHandler) handleAppleRefund(userID uint, tx *AppleTransactionInfo) error {
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil { if err := h.safeDowngradeToFree(userID, "Apple refund"); err != nil {
return err return err
} }
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User got refund, downgraded to free")
return nil return nil
} }
func (h *SubscriptionWebhookHandler) handleAppleRevoke(userID uint, tx *AppleTransactionInfo) error { func (h *SubscriptionWebhookHandler) handleAppleRevoke(userID uint, tx *AppleTransactionInfo) error {
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil { if err := h.safeDowngradeToFree(userID, "Apple revoke"); err != nil {
return err return err
} }
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription revoked, downgraded to free")
return nil return nil
} }
func (h *SubscriptionWebhookHandler) handleAppleGracePeriodExpired(userID uint, tx *AppleTransactionInfo) error { func (h *SubscriptionWebhookHandler) handleAppleGracePeriodExpired(userID uint, tx *AppleTransactionInfo) error {
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil { if err := h.safeDowngradeToFree(userID, "Apple grace period expired"); err != nil {
return err return err
} }
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User grace period expired, downgraded to free")
return nil return nil
} }
@@ -705,22 +704,16 @@ func (h *SubscriptionWebhookHandler) handleGoogleRestarted(userID uint, notifica
} }
func (h *SubscriptionWebhookHandler) handleGoogleRevoked(userID uint, notification *GoogleSubscriptionNotification) error { func (h *SubscriptionWebhookHandler) handleGoogleRevoked(userID uint, notification *GoogleSubscriptionNotification) error {
// Subscription revoked - immediate downgrade if err := h.safeDowngradeToFree(userID, "Google revoke"); err != nil {
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
return err return err
} }
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription revoked")
return nil return nil
} }
func (h *SubscriptionWebhookHandler) handleGoogleExpired(userID uint, notification *GoogleSubscriptionNotification) error { func (h *SubscriptionWebhookHandler) handleGoogleExpired(userID uint, notification *GoogleSubscriptionNotification) error {
// Subscription expired if err := h.safeDowngradeToFree(userID, "Google expired"); err != nil {
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
return err return err
} }
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription expired")
return nil return nil
} }
@@ -730,6 +723,88 @@ func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notificatio
return nil return nil
} }
// ====================
// Multi-Source Downgrade Safety
// ====================
// safeDowngradeToFree checks if the user has active subscriptions from other sources
// before downgrading to free. If another source is still active, skip the downgrade.
func (h *SubscriptionWebhookHandler) safeDowngradeToFree(userID uint, reason string) error {
sub, err := h.subscriptionRepo.FindByUserID(userID)
if err != nil {
log.Warn().Err(err).Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Could not find subscription for multi-source check, proceeding with downgrade")
return h.subscriptionRepo.DowngradeToFree(userID)
}
// Check if Stripe subscription is still active
if sub.HasStripeSubscription() && sub.Platform != models.PlatformStripe {
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Stripe subscription")
return nil
}
// Check if Apple subscription is still active (for Google/Stripe webhooks)
if sub.HasAppleSubscription() && sub.Platform != models.PlatformIOS {
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Apple subscription")
return nil
}
// Check if Google subscription is still active (for Apple/Stripe webhooks)
if sub.HasGoogleSubscription() && sub.Platform != models.PlatformAndroid {
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Google subscription")
return nil
}
// Check if trial is still active
if sub.IsTrialActive() {
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active trial")
return nil
}
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
return err
}
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: User downgraded to free (no other active sources)")
return nil
}
// ====================
// Stripe Webhooks
// ====================
// HandleStripeWebhook handles POST /api/subscription/webhook/stripe/
func (h *SubscriptionWebhookHandler) HandleStripeWebhook(c echo.Context) error {
if !h.enabled {
log.Info().Msg("Stripe Webhook: webhooks disabled by feature flag")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
}
if h.stripeService == nil {
log.Warn().Msg("Stripe Webhook: Stripe service not configured")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "not_configured"})
}
body, err := io.ReadAll(c.Request().Body)
if err != nil {
log.Error().Err(err).Msg("Stripe Webhook: Failed to read body")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
}
signature := c.Request().Header.Get("Stripe-Signature")
if signature == "" {
log.Warn().Msg("Stripe Webhook: Missing Stripe-Signature header")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "missing signature"})
}
if err := h.stripeService.HandleWebhookEvent(body, signature); err != nil {
log.Error().Err(err).Msg("Stripe Webhook: Failed to process webhook")
// Still return 200 to prevent Stripe from retrying on business logic errors
// Only return error for signature verification failures
if strings.Contains(err.Error(), "signature") {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid signature"})
}
}
return c.JSON(http.StatusOK, map[string]interface{}{"status": "received"})
}
// ==================== // ====================
// Signature Verification (Optional but Recommended) // Signature Verification (Optional but Recommended)
// ==================== // ====================

View File

@@ -98,6 +98,10 @@ var specEndpointsKMPSkips = map[routeKey]bool{
{Method: "POST", Path: "/notifications/devices/"}: true, // KMP uses /notifications/devices/register/ {Method: "POST", Path: "/notifications/devices/"}: true, // KMP uses /notifications/devices/register/
{Method: "POST", Path: "/notifications/devices/unregister/"}: true, // KMP uses DELETE on device ID {Method: "POST", Path: "/notifications/devices/unregister/"}: true, // KMP uses DELETE on device ID
{Method: "PATCH", Path: "/notifications/preferences/"}: true, // KMP uses PUT {Method: "PATCH", Path: "/notifications/preferences/"}: true, // KMP uses PUT
// Stripe web-only and server-to-server endpoints — not implemented in mobile KMP
{Method: "POST", Path: "/subscription/checkout/"}: true, // Web-only (Stripe Checkout)
{Method: "POST", Path: "/subscription/portal/"}: true, // Web-only (Stripe Customer Portal)
{Method: "POST", Path: "/subscription/webhook/stripe/"}: true, // Server-to-server (Stripe webhook)
} }
// kmpRouteAliases maps KMP paths to their canonical spec paths. // kmpRouteAliases maps KMP paths to their canonical spec paths.

View File

@@ -76,7 +76,7 @@ func setupSecurityTest(t *testing.T) *SecurityTestApp {
taskHandler := handlers.NewTaskHandler(taskService, nil) taskHandler := handlers.NewTaskHandler(taskService, nil)
contractorHandler := handlers.NewContractorHandler(services.NewContractorService(contractorRepo, residenceRepo)) contractorHandler := handlers.NewContractorHandler(services.NewContractorService(contractorRepo, residenceRepo))
notificationHandler := handlers.NewNotificationHandler(notificationService) notificationHandler := handlers.NewNotificationHandler(notificationService)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
// Create router with real middleware // Create router with real middleware
e := echo.New() e := echo.New()

View File

@@ -64,7 +64,7 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
// Create handlers // Create handlers
authHandler := handlers.NewAuthHandler(authService, nil, nil) authHandler := handlers.NewAuthHandler(authService, nil, nil)
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true) residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
// Create router // Create router
e := echo.New() e := echo.New()

View File

@@ -12,11 +12,20 @@ const (
TierPro SubscriptionTier = "pro" TierPro SubscriptionTier = "pro"
) )
// SubscriptionPlatform constants
const (
PlatformIOS = "ios"
PlatformAndroid = "android"
PlatformStripe = "stripe"
)
// SubscriptionSettings represents the subscription_subscriptionsettings table (singleton) // SubscriptionSettings represents the subscription_subscriptionsettings table (singleton)
type SubscriptionSettings struct { type SubscriptionSettings struct {
ID uint `gorm:"primaryKey" json:"id"` ID uint `gorm:"primaryKey" json:"id"`
EnableLimitations bool `gorm:"column:enable_limitations;default:false" json:"enable_limitations"` EnableLimitations bool `gorm:"column:enable_limitations;default:false" json:"enable_limitations"`
EnableMonitoring bool `gorm:"column:enable_monitoring;default:true" json:"enable_monitoring"` EnableMonitoring bool `gorm:"column:enable_monitoring;default:true" json:"enable_monitoring"`
TrialEnabled bool `gorm:"column:trial_enabled;default:true" json:"trial_enabled"`
TrialDurationDays int `gorm:"column:trial_duration_days;default:14" json:"trial_duration_days"`
} }
// TableName returns the table name for GORM // TableName returns the table name for GORM
@@ -31,18 +40,28 @@ type UserSubscription struct {
User User `gorm:"foreignKey:UserID" json:"-"` User User `gorm:"foreignKey:UserID" json:"-"`
Tier SubscriptionTier `gorm:"column:tier;size:10;default:'free'" json:"tier"` Tier SubscriptionTier `gorm:"column:tier;size:10;default:'free'" json:"tier"`
// In-App Purchase data // In-App Purchase data (Apple / Google)
AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"` AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"`
GooglePurchaseToken *string `gorm:"column:google_purchase_token;type:text" json:"-"` GooglePurchaseToken *string `gorm:"column:google_purchase_token;type:text" json:"-"`
// Stripe data (web subscriptions)
StripeCustomerID *string `gorm:"column:stripe_customer_id;size:255" json:"-"`
StripeSubscriptionID *string `gorm:"column:stripe_subscription_id;size:255" json:"-"`
StripePriceID *string `gorm:"column:stripe_price_id;size:255" json:"-"`
// Subscription dates // Subscription dates
SubscribedAt *time.Time `gorm:"column:subscribed_at" json:"subscribed_at"` SubscribedAt *time.Time `gorm:"column:subscribed_at" json:"subscribed_at"`
ExpiresAt *time.Time `gorm:"column:expires_at" json:"expires_at"` ExpiresAt *time.Time `gorm:"column:expires_at" json:"expires_at"`
AutoRenew bool `gorm:"column:auto_renew;default:true" json:"auto_renew"` AutoRenew bool `gorm:"column:auto_renew;default:true" json:"auto_renew"`
// Trial
TrialStart *time.Time `gorm:"column:trial_start" json:"trial_start"`
TrialEnd *time.Time `gorm:"column:trial_end" json:"trial_end"`
TrialUsed bool `gorm:"column:trial_used;default:false" json:"trial_used"`
// Tracking // Tracking
CancelledAt *time.Time `gorm:"column:cancelled_at" json:"cancelled_at"` CancelledAt *time.Time `gorm:"column:cancelled_at" json:"cancelled_at"`
Platform string `gorm:"column:platform;size:10" json:"platform"` // ios, android Platform string `gorm:"column:platform;size:10" json:"platform"` // ios, android, stripe
// Admin override - bypasses all limitations regardless of global settings // Admin override - bypasses all limitations regardless of global settings
IsFree bool `gorm:"column:is_free;default:false" json:"is_free"` IsFree bool `gorm:"column:is_free;default:false" json:"is_free"`
@@ -53,8 +72,11 @@ func (UserSubscription) TableName() string {
return "subscription_usersubscription" return "subscription_usersubscription"
} }
// IsActive returns true if the subscription is active (pro tier and not expired) // IsActive returns true if the subscription is active (pro tier and not expired, or trial active)
func (s *UserSubscription) IsActive() bool { func (s *UserSubscription) IsActive() bool {
if s.IsTrialActive() {
return true
}
if s.Tier != TierPro { if s.Tier != TierPro {
return false return false
} }
@@ -64,9 +86,37 @@ func (s *UserSubscription) IsActive() bool {
return true return true
} }
// IsPro returns true if the user has a pro subscription // IsPro returns true if the user has a pro subscription or active trial
func (s *UserSubscription) IsPro() bool { func (s *UserSubscription) IsPro() bool {
return s.Tier == TierPro && s.IsActive() return s.IsActive()
}
// IsTrialActive returns true if the user has an active, unexpired trial
func (s *UserSubscription) IsTrialActive() bool {
if s.TrialEnd == nil {
return false
}
return time.Now().UTC().Before(*s.TrialEnd)
}
// HasStripeSubscription returns true if the user has Stripe subscription data
func (s *UserSubscription) HasStripeSubscription() bool {
return s.StripeSubscriptionID != nil && *s.StripeSubscriptionID != ""
}
// HasAppleSubscription returns true if the user has Apple receipt data
func (s *UserSubscription) HasAppleSubscription() bool {
return s.AppleReceiptData != nil && *s.AppleReceiptData != ""
}
// HasGoogleSubscription returns true if the user has Google purchase token
func (s *UserSubscription) HasGoogleSubscription() bool {
return s.GooglePurchaseToken != nil && *s.GooglePurchaseToken != ""
}
// SubscriptionSource returns the platform that the active subscription came from
func (s *UserSubscription) SubscriptionSource() string {
return s.Platform
} }
// UpgradeTrigger represents the subscription_upgradetrigger table // UpgradeTrigger represents the subscription_upgradetrigger table

View File

@@ -0,0 +1,187 @@
package models
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestIsTrialActive(t *testing.T) {
now := time.Now().UTC()
future := now.Add(24 * time.Hour)
past := now.Add(-24 * time.Hour)
tests := []struct {
name string
sub *UserSubscription
expected bool
}{
{
name: "trial_end in future returns true",
sub: &UserSubscription{TrialEnd: &future},
expected: true,
},
{
name: "trial_end in past returns false",
sub: &UserSubscription{TrialEnd: &past},
expected: false,
},
{
name: "trial_end nil returns false",
sub: &UserSubscription{TrialEnd: nil},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.sub.IsTrialActive()
assert.Equal(t, tt.expected, result)
})
}
}
func TestIsPro(t *testing.T) {
now := time.Now().UTC()
future := now.Add(24 * time.Hour)
past := now.Add(-24 * time.Hour)
tests := []struct {
name string
sub *UserSubscription
expected bool
}{
{
name: "tier=pro, expires_at in future returns true",
sub: &UserSubscription{
Tier: TierPro,
ExpiresAt: &future,
},
expected: true,
},
{
name: "tier=pro, expires_at in past returns false",
sub: &UserSubscription{
Tier: TierPro,
ExpiresAt: &past,
},
expected: false,
},
{
name: "tier=free, trial active returns true",
sub: &UserSubscription{
Tier: TierFree,
TrialEnd: &future,
},
expected: true,
},
{
name: "tier=free, trial expired returns false",
sub: &UserSubscription{
Tier: TierFree,
TrialEnd: &past,
},
expected: false,
},
{
name: "tier=free, no trial returns false",
sub: &UserSubscription{
Tier: TierFree,
TrialEnd: nil,
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.sub.IsPro()
assert.Equal(t, tt.expected, result)
})
}
}
func TestHasSubscriptionHelpers(t *testing.T) {
empty := ""
validStripeID := "sub_1234567890"
validReceipt := "MIIT..."
validToken := "google-purchase-token-123"
tests := []struct {
name string
sub *UserSubscription
method string
expected bool
}{
// HasStripeSubscription
{
name: "HasStripeSubscription with nil returns false",
sub: &UserSubscription{StripeSubscriptionID: nil},
method: "stripe",
expected: false,
},
{
name: "HasStripeSubscription with empty string returns false",
sub: &UserSubscription{StripeSubscriptionID: &empty},
method: "stripe",
expected: false,
},
{
name: "HasStripeSubscription with valid ID returns true",
sub: &UserSubscription{StripeSubscriptionID: &validStripeID},
method: "stripe",
expected: true,
},
// HasAppleSubscription
{
name: "HasAppleSubscription with nil returns false",
sub: &UserSubscription{AppleReceiptData: nil},
method: "apple",
expected: false,
},
{
name: "HasAppleSubscription with empty string returns false",
sub: &UserSubscription{AppleReceiptData: &empty},
method: "apple",
expected: false,
},
{
name: "HasAppleSubscription with valid receipt returns true",
sub: &UserSubscription{AppleReceiptData: &validReceipt},
method: "apple",
expected: true,
},
// HasGoogleSubscription
{
name: "HasGoogleSubscription with nil returns false",
sub: &UserSubscription{GooglePurchaseToken: nil},
method: "google",
expected: false,
},
{
name: "HasGoogleSubscription with empty string returns false",
sub: &UserSubscription{GooglePurchaseToken: &empty},
method: "google",
expected: false,
},
{
name: "HasGoogleSubscription with valid token returns true",
sub: &UserSubscription{GooglePurchaseToken: &validToken},
method: "google",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result bool
switch tt.method {
case "stripe":
result = tt.sub.HasStripeSubscription()
case "apple":
result = tt.sub.HasAppleSubscription()
case "google":
result = tt.sub.HasGoogleSubscription()
}
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -262,3 +262,59 @@ func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.P
} }
return &promotion, nil return &promotion, nil
} }
// === Stripe Lookups ===
// FindByStripeCustomerID finds a subscription by Stripe customer ID
func (r *SubscriptionRepository) FindByStripeCustomerID(customerID string) (*models.UserSubscription, error) {
var sub models.UserSubscription
err := r.db.Where("stripe_customer_id = ?", customerID).First(&sub).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// FindByStripeSubscriptionID finds a subscription by Stripe subscription ID
func (r *SubscriptionRepository) FindByStripeSubscriptionID(subscriptionID string) (*models.UserSubscription, error) {
var sub models.UserSubscription
err := r.db.Where("stripe_subscription_id = ?", subscriptionID).First(&sub).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// UpdateStripeData updates all three Stripe fields (customer, subscription, price) in one call
func (r *SubscriptionRepository) UpdateStripeData(userID uint, customerID, subscriptionID, priceID string) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"stripe_customer_id": customerID,
"stripe_subscription_id": subscriptionID,
"stripe_price_id": priceID,
}).Error
}
// ClearStripeData clears the Stripe subscription and price IDs (customer ID stays for portal access)
func (r *SubscriptionRepository) ClearStripeData(userID uint) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"stripe_subscription_id": nil,
"stripe_price_id": nil,
}).Error
}
// === Trial Management ===
// SetTrialDates sets the trial start, end, and marks trial as used
func (r *SubscriptionRepository) SetTrialDates(userID uint, trialStart, trialEnd time.Time) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"trial_start": trialStart,
"trial_end": trialEnd,
"trial_used": true,
}).Error
}
// UpdateExpiresAt updates the expires_at field for a user's subscription
func (r *SubscriptionRepository) UpdateExpiresAt(userID uint, expiresAt time.Time) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).
Update("expires_at", expiresAt).Error
}

View File

@@ -2,9 +2,11 @@ package repositories
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/models" "github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/testutil" "github.com/treytartt/casera-api/internal/testutil"
@@ -77,3 +79,150 @@ func TestGetOrCreate_Idempotent(t *testing.T) {
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count) db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should have exactly one subscription record after two calls") assert.Equal(t, int64(1), count, "should have exactly one subscription record after two calls")
} }
func TestFindByStripeCustomerID(t *testing.T) {
tests := []struct {
name string
customerID string
seedID string
wantErr bool
}{
{
name: "finds existing subscription by stripe customer ID",
customerID: "cus_test123",
seedID: "cus_test123",
wantErr: false,
},
{
name: "returns error for unknown stripe customer ID",
customerID: "cus_unknown999",
seedID: "cus_test456",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierFree,
StripeCustomerID: &tt.seedID,
}
err := db.Create(sub).Error
require.NoError(t, err)
found, err := repo.FindByStripeCustomerID(tt.customerID)
if tt.wantErr {
assert.Error(t, err)
assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
} else {
require.NoError(t, err)
require.NotNil(t, found)
assert.Equal(t, user.ID, found.UserID)
assert.Equal(t, tt.seedID, *found.StripeCustomerID)
}
})
}
}
func TestUpdateStripeData(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierFree,
}
err := db.Create(sub).Error
require.NoError(t, err)
// Update all three Stripe fields
customerID := "cus_abc123"
subscriptionID := "sub_xyz789"
priceID := "price_monthly"
err = repo.UpdateStripeData(user.ID, customerID, subscriptionID, priceID)
require.NoError(t, err)
// Verify all three fields are set
var updated models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&updated).Error
require.NoError(t, err)
require.NotNil(t, updated.StripeCustomerID)
require.NotNil(t, updated.StripeSubscriptionID)
require.NotNil(t, updated.StripePriceID)
assert.Equal(t, customerID, *updated.StripeCustomerID)
assert.Equal(t, subscriptionID, *updated.StripeSubscriptionID)
assert.Equal(t, priceID, *updated.StripePriceID)
// Now call ClearStripeData
err = repo.ClearStripeData(user.ID)
require.NoError(t, err)
// Verify subscription_id and price_id are cleared, customer_id preserved
var cleared models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&cleared).Error
require.NoError(t, err)
require.NotNil(t, cleared.StripeCustomerID, "customer_id should be preserved after ClearStripeData")
assert.Equal(t, customerID, *cleared.StripeCustomerID)
assert.Nil(t, cleared.StripeSubscriptionID, "subscription_id should be nil after ClearStripeData")
assert.Nil(t, cleared.StripePriceID, "price_id should be nil after ClearStripeData")
}
func TestSetTrialDates(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierFree,
TrialUsed: false,
}
err := db.Create(sub).Error
require.NoError(t, err)
trialStart := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
trialEnd := time.Date(2026, 3, 15, 0, 0, 0, 0, time.UTC)
err = repo.SetTrialDates(user.ID, trialStart, trialEnd)
require.NoError(t, err)
var updated models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&updated).Error
require.NoError(t, err)
require.NotNil(t, updated.TrialStart)
require.NotNil(t, updated.TrialEnd)
assert.True(t, updated.TrialUsed, "trial_used should be set to true")
assert.WithinDuration(t, trialStart, *updated.TrialStart, time.Second, "trial_start should match")
assert.WithinDuration(t, trialEnd, *updated.TrialEnd, time.Second, "trial_end should match")
}
func TestUpdateExpiresAt(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierPro,
}
err := db.Create(sub).Error
require.NoError(t, err)
newExpiry := time.Date(2027, 6, 15, 12, 0, 0, 0, time.UTC)
err = repo.UpdateExpiresAt(user.ID, newExpiry)
require.NoError(t, err)
var updated models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&updated).Error
require.NoError(t, err)
require.NotNil(t, updated.ExpiresAt)
assert.WithinDuration(t, newExpiry, *updated.ExpiresAt, time.Second, "expires_at should be updated")
}

View File

@@ -58,7 +58,13 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
e.Use(custommiddleware.RequestIDMiddleware()) e.Use(custommiddleware.RequestIDMiddleware())
e.Use(utils.EchoRecovery()) e.Use(utils.EchoRecovery())
e.Use(custommiddleware.StructuredLogger()) e.Use(custommiddleware.StructuredLogger())
e.Use(middleware.BodyLimit("1M")) // 1MB default for JSON payloads e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
Limit: "1M", // 1MB default for JSON payloads
Skipper: func(c echo.Context) bool {
// Allow larger payloads for webhook endpoints (Apple/Google/Stripe notifications)
return strings.HasPrefix(c.Request().URL.Path, "/api/subscription/webhook")
},
}))
e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{ e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
Skipper: func(c echo.Context) bool { Skipper: func(c echo.Context) bool {
@@ -143,11 +149,15 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
residenceService.SetSubscriptionService(subscriptionService) // Wire up subscription service for tier limit enforcement residenceService.SetSubscriptionService(subscriptionService) // Wire up subscription service for tier limit enforcement
taskTemplateService := services.NewTaskTemplateService(taskTemplateRepo) taskTemplateService := services.NewTaskTemplateService(taskTemplateRepo)
// Initialize Stripe service
stripeService := services.NewStripeService(subscriptionRepo, userRepo)
// Initialize webhook event repo for deduplication // Initialize webhook event repo for deduplication
webhookEventRepo := repositories.NewWebhookEventRepository(deps.DB) webhookEventRepo := repositories.NewWebhookEventRepository(deps.DB)
// Initialize webhook handler for Apple/Google subscription notifications // Initialize webhook handler for Apple/Google/Stripe subscription notifications
subscriptionWebhookHandler := handlers.NewSubscriptionWebhookHandler(subscriptionRepo, userRepo, webhookEventRepo, cfg.Features.WebhooksEnabled) subscriptionWebhookHandler := handlers.NewSubscriptionWebhookHandler(subscriptionRepo, userRepo, webhookEventRepo, cfg.Features.WebhooksEnabled)
subscriptionWebhookHandler.SetStripeService(stripeService)
// Initialize middleware // Initialize middleware
authMiddleware := custommiddleware.NewAuthMiddleware(deps.DB, deps.Cache) authMiddleware := custommiddleware.NewAuthMiddleware(deps.DB, deps.Cache)
@@ -166,7 +176,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
contractorHandler := handlers.NewContractorHandler(contractorService) contractorHandler := handlers.NewContractorHandler(contractorService)
documentHandler := handlers.NewDocumentHandler(documentService, deps.StorageService) documentHandler := handlers.NewDocumentHandler(documentService, deps.StorageService)
notificationHandler := handlers.NewNotificationHandler(notificationService) notificationHandler := handlers.NewNotificationHandler(notificationService)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, stripeService)
staticDataHandler := handlers.NewStaticDataHandler(residenceService, taskService, contractorService, taskTemplateService, deps.Cache) staticDataHandler := handlers.NewStaticDataHandler(residenceService, taskService, contractorService, taskTemplateService, deps.Cache)
taskTemplateHandler := handlers.NewTaskTemplateHandler(taskTemplateService) taskTemplateHandler := handlers.NewTaskTemplateHandler(taskTemplateService)
@@ -458,6 +468,8 @@ func setupSubscriptionRoutes(api *echo.Group, subscriptionHandler *handlers.Subs
subscription.POST("/purchase/", subscriptionHandler.ProcessPurchase) subscription.POST("/purchase/", subscriptionHandler.ProcessPurchase)
subscription.POST("/cancel/", subscriptionHandler.CancelSubscription) subscription.POST("/cancel/", subscriptionHandler.CancelSubscription)
subscription.POST("/restore/", subscriptionHandler.RestoreSubscription) subscription.POST("/restore/", subscriptionHandler.RestoreSubscription)
subscription.POST("/checkout/", subscriptionHandler.CreateCheckoutSession)
subscription.POST("/portal/", subscriptionHandler.CreatePortalSession)
} }
} }
@@ -499,6 +511,7 @@ func setupWebhookRoutes(api *echo.Group, webhookHandler *handlers.SubscriptionWe
{ {
webhooks.POST("/apple/", webhookHandler.HandleAppleWebhook) webhooks.POST("/apple/", webhookHandler.HandleAppleWebhook)
webhooks.POST("/google/", webhookHandler.HandleGoogleWebhook) webhooks.POST("/google/", webhookHandler.HandleGoogleWebhook)
webhooks.POST("/stripe/", webhookHandler.HandleStripeWebhook)
} }
} }

View File

@@ -0,0 +1,456 @@
package services
import (
"encoding/json"
"fmt"
"os"
"time"
"github.com/rs/zerolog/log"
"github.com/stripe/stripe-go/v81"
portalsession "github.com/stripe/stripe-go/v81/billingportal/session"
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
"github.com/stripe/stripe-go/v81/customer"
"github.com/stripe/stripe-go/v81/webhook"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
)
// StripeService handles Stripe checkout, portal, and webhook processing
// for web-based subscription purchases.
type StripeService struct {
subscriptionRepo *repositories.SubscriptionRepository
userRepo *repositories.UserRepository
webhookSecret string
}
// NewStripeService creates a new Stripe service. It initializes the global
// Stripe API key from the STRIPE_SECRET_KEY environment variable. If the key
// is not set, a warning is logged but the service is still returned (matching
// the pattern used by the Apple/Google IAP clients).
func NewStripeService(
subscriptionRepo *repositories.SubscriptionRepository,
userRepo *repositories.UserRepository,
) *StripeService {
key := os.Getenv("STRIPE_SECRET_KEY")
if key == "" {
log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work")
} else {
stripe.Key = key
log.Info().Msg("Stripe API key configured")
}
webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")
if webhookSecret == "" {
log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail")
}
return &StripeService{
subscriptionRepo: subscriptionRepo,
userRepo: userRepo,
webhookSecret: webhookSecret,
}
}
// CreateCheckoutSession creates a Stripe Checkout Session for a web subscription purchase.
// It ensures the user has a Stripe customer record and configures the session with a trial
// period if the user has not used their trial yet.
func (s *StripeService) CreateCheckoutSession(userID uint, priceID string, successURL string, cancelURL string) (string, error) {
// Get or create the user's subscription record
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return "", apperrors.Internal(err)
}
// Get the user's email for the Stripe customer
user, err := s.userRepo.FindByID(userID)
if err != nil {
return "", apperrors.Internal(err)
}
// Get or create a Stripe customer
stripeCustomerID, err := s.getOrCreateStripeCustomer(sub, user)
if err != nil {
return "", apperrors.Internal(err)
}
// Build the checkout session parameters
params := &stripe.CheckoutSessionParams{
Customer: stripe.String(stripeCustomerID),
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
SuccessURL: stripe.String(successURL),
CancelURL: stripe.String(cancelURL),
ClientReferenceID: stripe.String(fmt.Sprintf("%d", userID)),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(priceID),
Quantity: stripe.Int64(1),
},
},
}
// Offer a trial period if the user has not used their trial yet
if !sub.TrialUsed {
trialDays, err := s.getTrialDays()
if err != nil {
log.Warn().Err(err).Msg("Failed to get trial duration from settings, skipping trial")
} else if trialDays > 0 {
params.SubscriptionData = &stripe.CheckoutSessionSubscriptionDataParams{
TrialPeriodDays: stripe.Int64(int64(trialDays)),
}
}
}
session, err := checkoutsession.New(params)
if err != nil {
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to create Stripe checkout session")
return "", apperrors.Internal(err)
}
log.Info().
Uint("user_id", userID).
Str("session_id", session.ID).
Str("price_id", priceID).
Msg("Stripe checkout session created")
return session.URL, nil
}
// CreatePortalSession creates a Stripe Customer Portal session so the user
// can manage their subscription (cancel, change plan, update payment method).
func (s *StripeService) CreatePortalSession(userID uint, returnURL string) (string, error) {
sub, err := s.subscriptionRepo.FindByUserID(userID)
if err != nil {
return "", apperrors.NotFound("error.subscription_not_found")
}
if sub.StripeCustomerID == nil || *sub.StripeCustomerID == "" {
return "", apperrors.BadRequest("error.no_stripe_customer")
}
params := &stripe.BillingPortalSessionParams{
Customer: sub.StripeCustomerID,
ReturnURL: stripe.String(returnURL),
}
session, err := portalsession.New(params)
if err != nil {
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to create Stripe portal session")
return "", apperrors.Internal(err)
}
return session.URL, nil
}
// HandleWebhookEvent verifies and processes a Stripe webhook event.
// It handles checkout completion, subscription lifecycle changes, and invoice events.
func (s *StripeService) HandleWebhookEvent(payload []byte, signature string) error {
event, err := webhook.ConstructEvent(payload, signature, s.webhookSecret)
if err != nil {
log.Warn().Err(err).Msg("Stripe webhook signature verification failed")
return apperrors.BadRequest("error.invalid_webhook_signature")
}
log.Info().
Str("event_type", string(event.Type)).
Str("event_id", event.ID).
Msg("Processing Stripe webhook event")
switch event.Type {
case "checkout.session.completed":
return s.handleCheckoutCompleted(event)
case "customer.subscription.updated":
return s.handleSubscriptionUpdated(event)
case "customer.subscription.deleted":
return s.handleSubscriptionDeleted(event)
case "invoice.paid":
return s.handleInvoicePaid(event)
case "invoice.payment_failed":
return s.handleInvoicePaymentFailed(event)
default:
log.Debug().Str("event_type", string(event.Type)).Msg("Unhandled Stripe webhook event type")
return nil
}
}
// handleCheckoutCompleted processes a successful checkout session. It links the Stripe
// customer and subscription to the user's record and upgrades them to Pro.
func (s *StripeService) handleCheckoutCompleted(event stripe.Event) error {
var session stripe.CheckoutSession
if err := json.Unmarshal(event.Data.Raw, &session); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal checkout session from webhook")
return apperrors.Internal(err)
}
// Extract the user ID from client_reference_id
var userID uint
if _, err := fmt.Sscanf(session.ClientReferenceID, "%d", &userID); err != nil {
log.Error().Str("client_reference_id", session.ClientReferenceID).Msg("Invalid client_reference_id in checkout session")
return apperrors.BadRequest("error.invalid_client_reference_id")
}
// Get or create the subscription record
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return apperrors.Internal(err)
}
// Save Stripe customer and subscription IDs
if session.Customer != nil {
sub.StripeCustomerID = &session.Customer.ID
}
if session.Subscription != nil {
sub.StripeSubscriptionID = &session.Subscription.ID
}
if err := s.subscriptionRepo.Update(sub); err != nil {
return apperrors.Internal(err)
}
// Upgrade to Pro. Use a far-future expiry for now; the invoice.paid event
// will set the real period_end once the first invoice is finalized.
expiresAt := time.Now().UTC().AddDate(1, 0, 0)
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, models.PlatformStripe); err != nil {
return apperrors.Internal(err)
}
customerID := ""
if session.Customer != nil {
customerID = session.Customer.ID
}
subscriptionID := ""
if session.Subscription != nil {
subscriptionID = session.Subscription.ID
}
log.Info().
Uint("user_id", userID).
Str("stripe_customer_id", customerID).
Str("stripe_subscription_id", subscriptionID).
Msg("Checkout completed, user upgraded to Pro")
// TODO: Send push notification to user's devices when subscription activates
return nil
}
// handleSubscriptionUpdated processes subscription status changes. It upgrades or
// downgrades the user depending on the subscription's current status.
func (s *StripeService) handleSubscriptionUpdated(event stripe.Event) error {
var subscription stripe.Subscription
if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal subscription from webhook")
return apperrors.Internal(err)
}
sub, err := s.findSubscriptionByStripeID(subscription.ID)
if err != nil {
return err
}
switch subscription.Status {
case stripe.SubscriptionStatusActive, stripe.SubscriptionStatusTrialing:
// Subscription is healthy, ensure user is Pro
expiresAt := time.Unix(subscription.CurrentPeriodEnd, 0).UTC()
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
return apperrors.Internal(err)
}
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("Stripe subscription active")
case stripe.SubscriptionStatusPastDue:
log.Warn().Uint("user_id", sub.UserID).Msg("Stripe subscription past due, waiting for retry")
// Don't downgrade yet; Stripe will retry the payment automatically.
case stripe.SubscriptionStatusCanceled, stripe.SubscriptionStatusUnpaid:
// Check if the user has active subscriptions from other sources before downgrading
if s.isActiveFromOtherSources(sub) {
log.Info().
Uint("user_id", sub.UserID).
Str("status", string(subscription.Status)).
Msg("Stripe subscription ended but user has other active sources, keeping Pro")
return nil
}
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
return apperrors.Internal(err)
}
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("User downgraded to Free after Stripe subscription ended")
}
return nil
}
// handleSubscriptionDeleted processes a subscription that has been fully cancelled
// and is no longer active. It downgrades the user unless they have active subscriptions
// from other sources (Apple, Google).
func (s *StripeService) handleSubscriptionDeleted(event stripe.Event) error {
var subscription stripe.Subscription
if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal subscription from webhook")
return apperrors.Internal(err)
}
sub, err := s.findSubscriptionByStripeID(subscription.ID)
if err != nil {
return err
}
// Check multi-source before downgrading
if s.isActiveFromOtherSources(sub) {
log.Info().
Uint("user_id", sub.UserID).
Msg("Stripe subscription deleted but user has other active sources, keeping Pro")
return nil
}
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
return apperrors.Internal(err)
}
log.Info().Uint("user_id", sub.UserID).Msg("User downgraded to Free after Stripe subscription deleted")
// TODO: Send push notification to user's devices about subscription ending
return nil
}
// handleInvoicePaid processes a successful invoice payment. It updates the subscription
// expiry to the current billing period's end date and ensures the user is on Pro.
func (s *StripeService) handleInvoicePaid(event stripe.Event) error {
var invoice stripe.Invoice
if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal invoice from webhook")
return apperrors.Internal(err)
}
// Only process subscription invoices
if invoice.Subscription == nil {
return nil
}
sub, err := s.findSubscriptionByStripeID(invoice.Subscription.ID)
if err != nil {
return err
}
// Update expiry from the invoice's period end
expiresAt := time.Unix(invoice.PeriodEnd, 0).UTC()
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
return apperrors.Internal(err)
}
log.Info().
Uint("user_id", sub.UserID).
Time("expires_at", expiresAt).
Msg("Invoice paid, subscription renewed")
return nil
}
// handleInvoicePaymentFailed logs a warning when a payment fails. We do not downgrade
// the user here because Stripe will automatically retry the payment according to its
// Smart Retries schedule.
func (s *StripeService) handleInvoicePaymentFailed(event stripe.Event) error {
var invoice stripe.Invoice
if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal invoice from webhook")
return apperrors.Internal(err)
}
if invoice.Subscription == nil {
return nil
}
sub, err := s.findSubscriptionByStripeID(invoice.Subscription.ID)
if err != nil {
// If we can't find the subscription, just log and return
log.Warn().Str("stripe_subscription_id", invoice.Subscription.ID).Msg("Invoice payment failed for unknown subscription")
return nil
}
log.Warn().
Uint("user_id", sub.UserID).
Str("invoice_id", invoice.ID).
Msg("Stripe invoice payment failed, Stripe will retry automatically")
return nil
}
// isActiveFromOtherSources checks if the user has active subscriptions from Apple or Google
// that should prevent a downgrade when the Stripe subscription ends.
func (s *StripeService) isActiveFromOtherSources(sub *models.UserSubscription) bool {
now := time.Now().UTC()
// Check Apple subscription
if sub.HasAppleSubscription() && sub.Tier == models.TierPro && sub.ExpiresAt != nil && now.Before(*sub.ExpiresAt) && sub.Platform != models.PlatformStripe {
return true
}
// Check Google subscription
if sub.HasGoogleSubscription() && sub.Tier == models.TierPro && sub.ExpiresAt != nil && now.Before(*sub.ExpiresAt) && sub.Platform != models.PlatformStripe {
return true
}
// Check active trial
if sub.IsTrialActive() {
return true
}
return false
}
// getOrCreateStripeCustomer returns the existing Stripe customer ID from the subscription
// record, or creates a new Stripe customer and persists the ID.
func (s *StripeService) getOrCreateStripeCustomer(sub *models.UserSubscription, user *models.User) (string, error) {
// If we already have a Stripe customer, return it
if sub.StripeCustomerID != nil && *sub.StripeCustomerID != "" {
return *sub.StripeCustomerID, nil
}
// Create a new Stripe customer
params := &stripe.CustomerParams{
Email: stripe.String(user.Email),
Name: stripe.String(user.GetFullName()),
}
params.AddMetadata("casera_user_id", fmt.Sprintf("%d", user.ID))
c, err := customer.New(params)
if err != nil {
return "", fmt.Errorf("failed to create Stripe customer: %w", err)
}
// Save the customer ID to the subscription record
sub.StripeCustomerID = &c.ID
if err := s.subscriptionRepo.Update(sub); err != nil {
return "", fmt.Errorf("failed to save Stripe customer ID: %w", err)
}
log.Info().
Uint("user_id", user.ID).
Str("stripe_customer_id", c.ID).
Msg("Created new Stripe customer")
return c.ID, nil
}
// findSubscriptionByStripeID looks up a UserSubscription by its Stripe subscription ID.
func (s *StripeService) findSubscriptionByStripeID(stripeSubID string) (*models.UserSubscription, error) {
sub, err := s.subscriptionRepo.FindByStripeSubscriptionID(stripeSubID)
if err != nil {
log.Warn().Str("stripe_subscription_id", stripeSubID).Err(err).Msg("Subscription not found for Stripe ID")
return nil, apperrors.NotFound("error.subscription_not_found")
}
return sub, nil
}
// getTrialDays reads the trial duration from SubscriptionSettings.
func (s *StripeService) getTrialDays() (int, error) {
settings, err := s.subscriptionRepo.GetSettings()
if err != nil {
return 0, err
}
if !settings.TrialEnabled {
return 0, nil
}
return settings.TrialDurationDays, nil
}

View File

@@ -118,6 +118,20 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Auto-start trial for new users who have never had a trial
if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled {
now := time.Now().UTC()
trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour)
if err := s.subscriptionRepo.SetTrialDates(userID, now, trialEnd); err != nil {
return nil, apperrors.Internal(err)
}
// Re-fetch after starting trial so response reflects the new state
sub, err = s.subscriptionRepo.FindByUserID(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
}
// Get all tier limits and build a map // Get all tier limits and build a map
allLimits, err := s.subscriptionRepo.GetAllTierLimits() allLimits, err := s.subscriptionRepo.GetAllTierLimits()
if err != nil { if err != nil {
@@ -154,6 +168,8 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
// Build flattened response (KMM expects subscription fields at top level) // Build flattened response (KMM expects subscription fields at top level)
resp := &SubscriptionStatusResponse{ resp := &SubscriptionStatusResponse{
Tier: string(sub.Tier),
IsActive: sub.IsActive(),
AutoRenew: sub.AutoRenew, AutoRenew: sub.AutoRenew,
Limits: limitsMap, Limits: limitsMap,
Usage: usage, Usage: usage,
@@ -170,6 +186,18 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
resp.ExpiresAt = &t resp.ExpiresAt = &t
} }
// Populate trial fields
if sub.TrialStart != nil {
t := sub.TrialStart.Format("2006-01-02T15:04:05Z")
resp.TrialStart = &t
}
if sub.TrialEnd != nil {
t := sub.TrialEnd.Format("2006-01-02T15:04:05Z")
resp.TrialEnd = &t
}
resp.TrialActive = sub.IsTrialActive()
resp.SubscriptionSource = sub.SubscriptionSource()
return resp, nil return resp, nil
} }
@@ -449,28 +477,48 @@ func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResp
return s.GetSubscription(userID) return s.GetSubscription(userID)
} }
// IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription
// from a different platform than the one being requested. Returns (conflict, existingPlatform, error).
func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(userID uint, requestedPlatform string) (bool, string, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return false, "", apperrors.Internal(err)
}
if !sub.IsPro() {
return false, "", nil
}
if sub.Platform == requestedPlatform {
return false, "", nil
}
return true, sub.Platform, nil
}
// === Response Types === // === Response Types ===
// SubscriptionResponse represents a subscription in API response // SubscriptionResponse represents a subscription in API response
type SubscriptionResponse struct { type SubscriptionResponse struct {
Tier string `json:"tier"` Tier string `json:"tier"`
SubscribedAt *string `json:"subscribed_at"` SubscribedAt *string `json:"subscribed_at"`
ExpiresAt *string `json:"expires_at"` ExpiresAt *string `json:"expires_at"`
AutoRenew bool `json:"auto_renew"` AutoRenew bool `json:"auto_renew"`
CancelledAt *string `json:"cancelled_at"` CancelledAt *string `json:"cancelled_at"`
Platform string `json:"platform"` Platform string `json:"platform"`
IsActive bool `json:"is_active"` IsActive bool `json:"is_active"`
IsPro bool `json:"is_pro"` IsPro bool `json:"is_pro"`
TrialActive bool `json:"trial_active"`
SubscriptionSource string `json:"subscription_source"`
} }
// NewSubscriptionResponse creates a SubscriptionResponse from a model // NewSubscriptionResponse creates a SubscriptionResponse from a model
func NewSubscriptionResponse(s *models.UserSubscription) *SubscriptionResponse { func NewSubscriptionResponse(s *models.UserSubscription) *SubscriptionResponse {
resp := &SubscriptionResponse{ resp := &SubscriptionResponse{
Tier: string(s.Tier), Tier: string(s.Tier),
AutoRenew: s.AutoRenew, AutoRenew: s.AutoRenew,
Platform: s.Platform, Platform: s.Platform,
IsActive: s.IsActive(), IsActive: s.IsActive(),
IsPro: s.IsPro(), IsPro: s.IsPro(),
TrialActive: s.IsTrialActive(),
SubscriptionSource: s.SubscriptionSource(),
} }
if s.SubscribedAt != nil { if s.SubscribedAt != nil {
t := s.SubscribedAt.Format("2006-01-02T15:04:05Z") t := s.SubscribedAt.Format("2006-01-02T15:04:05Z")
@@ -536,11 +584,23 @@ func NewTierLimitsClientResponse(l *models.TierLimits) *TierLimitsClientResponse
// SubscriptionStatusResponse represents full subscription status // SubscriptionStatusResponse represents full subscription status
// Fields are flattened to match KMM client expectations // Fields are flattened to match KMM client expectations
type SubscriptionStatusResponse struct { type SubscriptionStatusResponse struct {
// Tier and active status
Tier string `json:"tier"`
IsActive bool `json:"is_active"`
// Flattened subscription fields (KMM expects these at top level) // Flattened subscription fields (KMM expects these at top level)
SubscribedAt *string `json:"subscribed_at"` SubscribedAt *string `json:"subscribed_at"`
ExpiresAt *string `json:"expires_at"` ExpiresAt *string `json:"expires_at"`
AutoRenew bool `json:"auto_renew"` AutoRenew bool `json:"auto_renew"`
// Trial fields
TrialStart *string `json:"trial_start,omitempty"`
TrialEnd *string `json:"trial_end,omitempty"`
TrialActive bool `json:"trial_active"`
// Subscription source
SubscriptionSource string `json:"subscription_source"`
// Other fields // Other fields
Usage *UsageResponse `json:"usage"` Usage *UsageResponse `json:"usage"`
Limits map[string]*TierLimitsClientResponse `json:"limits"` Limits map[string]*TierLimitsClientResponse `json:"limits"`
@@ -638,5 +698,5 @@ type ProcessPurchaseRequest struct {
TransactionID string `json:"transaction_id"` // iOS StoreKit 2 transaction ID TransactionID string `json:"transaction_id"` // iOS StoreKit 2 transaction ID
PurchaseToken string `json:"purchase_token"` // Android PurchaseToken string `json:"purchase_token"` // Android
ProductID string `json:"product_id"` // Android (optional, helps identify subscription) ProductID string `json:"product_id"` // Android (optional, helps identify subscription)
Platform string `json:"platform" validate:"required,oneof=ios android"` Platform string `json:"platform" validate:"required,oneof=ios android stripe"`
} }

View File

@@ -2,6 +2,7 @@ package services
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -179,3 +180,94 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier") assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
} }
func TestIsAlreadyProFromOtherPlatform(t *testing.T) {
future := time.Now().UTC().Add(30 * 24 * time.Hour)
tests := []struct {
name string
tier models.SubscriptionTier
platform string
expiresAt *time.Time
trialEnd *time.Time
requestedPlatform string
wantConflict bool
wantPlatform string
}{
{
name: "free user returns no conflict",
tier: models.TierFree,
platform: "",
expiresAt: nil,
trialEnd: nil,
requestedPlatform: "stripe",
wantConflict: false,
wantPlatform: "",
},
{
name: "pro from ios, requesting ios returns no conflict (same platform)",
tier: models.TierPro,
platform: "ios",
expiresAt: &future,
trialEnd: nil,
requestedPlatform: "ios",
wantConflict: false,
wantPlatform: "",
},
{
name: "pro from ios, requesting stripe returns conflict",
tier: models.TierPro,
platform: "ios",
expiresAt: &future,
trialEnd: nil,
requestedPlatform: "stripe",
wantConflict: true,
wantPlatform: "ios",
},
{
name: "pro from stripe, requesting android returns conflict",
tier: models.TierPro,
platform: "stripe",
expiresAt: &future,
trialEnd: nil,
requestedPlatform: "android",
wantConflict: true,
wantPlatform: "stripe",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
svc := &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
}
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: tt.tier,
Platform: tt.platform,
ExpiresAt: tt.expiresAt,
TrialEnd: tt.trialEnd,
}
err := db.Create(sub).Error
require.NoError(t, err)
conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(user.ID, tt.requestedPlatform)
require.NoError(t, err)
assert.Equal(t, tt.wantConflict, conflict)
assert.Equal(t, tt.wantPlatform, existingPlatform)
})
}
}