Add Stripe billing, free trials, and cross-platform subscription guards
- Stripe integration: add StripeService with checkout sessions, customer portal, and webhook handling for subscription lifecycle events. - Free trials: auto-start configurable trial on first subscription check, with admin-controllable duration and enable/disable toggle. - Cross-platform guard: prevent duplicate subscriptions across iOS, Android, and Stripe by checking existing platform before allowing purchase. - Subscription model: add Stripe fields (customer_id, subscription_id, price_id), trial fields (trial_start, trial_end, trial_used), and SubscriptionSource/IsTrialActive helpers. - API: add trial and source fields to status response, update OpenAPI spec. - Clean up stale migration and audit docs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,170 +0,0 @@
|
||||
# Phase 4: Gin to Echo Handler Migration Status
|
||||
|
||||
## Completed Files
|
||||
|
||||
### ✅ auth_handler.go
|
||||
- **Status**: Fully migrated
|
||||
- **Methods migrated**: 13 methods
|
||||
- Login, Register, Logout, CurrentUser, UpdateProfile
|
||||
- VerifyEmail, ResendVerification
|
||||
- ForgotPassword, VerifyResetCode, ResetPassword
|
||||
- AppleSignIn, GoogleSignIn
|
||||
- **Key changes applied**:
|
||||
- Import changed from gin to echo
|
||||
- Added validator import
|
||||
- All handlers return error
|
||||
- c.Bind + c.Validate pattern implemented
|
||||
- c.MustGet → c.Get
|
||||
- gin.H → map[string]interface{}
|
||||
- c.Request.Context() → c.Request().Context()
|
||||
- All c.JSON calls use `return`
|
||||
|
||||
## Remaining Files to Migrate
|
||||
|
||||
### 🔧 residence_handler.go
|
||||
- **Status**: Partially migrated (needs cleanup)
|
||||
- **Methods**: 13 methods
|
||||
- **Issue**: Sed-based automated migration created syntax errors
|
||||
- **Next steps**: Manual cleanup needed
|
||||
|
||||
### ⏳ task_handler.go
|
||||
- **Methods**: ~17 methods
|
||||
- **Complexity**: High (multipart form handling for completions)
|
||||
- **Special considerations**:
|
||||
- Has multipart/form-data handling in CreateCompletion
|
||||
- Multiple lookup endpoints (categories, priorities, frequencies)
|
||||
|
||||
### ⏳ contractor_handler.go
|
||||
- **Methods**: 8 methods
|
||||
- **Complexity**: Medium
|
||||
|
||||
### ⏳ document_handler.go
|
||||
- **Methods**: 8 methods
|
||||
- **Complexity**: High (multipart form handling)
|
||||
- **Special considerations**: File upload in CreateDocument
|
||||
|
||||
### ⏳ notification_handler.go
|
||||
- **Methods**: 9 methods
|
||||
- **Complexity**: Medium
|
||||
- **Special considerations**: Query parameters for pagination
|
||||
|
||||
### ⏳ subscription_handler.go
|
||||
- **Status**: Unknown
|
||||
- **Estimated complexity**: Medium
|
||||
|
||||
### ⏳ upload_handler.go
|
||||
- **Methods**: 4 methods
|
||||
- **Complexity**: Medium
|
||||
- **Special considerations**: c.FormFile handling, c.DefaultQuery
|
||||
|
||||
### ⏳ user_handler.go
|
||||
- **Methods**: 3 methods
|
||||
- **Complexity**: Low
|
||||
|
||||
### ⏳ media_handler.go
|
||||
- **Status**: Unknown
|
||||
- **Estimated complexity**: Medium
|
||||
|
||||
### ⏳ static_data_handler.go
|
||||
- **Methods**: Unknown
|
||||
- **Complexity**: Low (likely just lookups)
|
||||
|
||||
### ⏳ task_template_handler.go
|
||||
- **Status**: Unknown
|
||||
- **Estimated complexity**: Medium
|
||||
|
||||
### ⏳ tracking_handler.go
|
||||
- **Status**: Unknown
|
||||
- **Estimated complexity**: Low
|
||||
|
||||
### ⏳ subscription_webhook_handler.go
|
||||
- **Status**: Unknown
|
||||
- **Estimated complexity**: Medium-High (webhook handling)
|
||||
|
||||
## Migration Pattern
|
||||
|
||||
All handlers must follow these transformations:
|
||||
|
||||
```go
|
||||
// BEFORE (Gin)
|
||||
func (h *Handler) Method(c *gin.Context) {
|
||||
user := c.MustGet(middleware.AuthUserKey).(*models.User)
|
||||
|
||||
var req requests.SomeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.service.DoSomething(&req)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, result)
|
||||
}
|
||||
|
||||
// AFTER (Echo)
|
||||
func (h *Handler) Method(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
|
||||
var req requests.SomeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(400, map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return c.JSON(400, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
result, err := h.service.DoSomething(&req)
|
||||
if err != nil {
|
||||
return c.JSON(500, map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(200, result)
|
||||
}
|
||||
```
|
||||
|
||||
## Critical Context Changes
|
||||
|
||||
| Gin | Echo |
|
||||
|-----|------|
|
||||
| `c.MustGet()` | `c.Get()` |
|
||||
| `c.ShouldBindJSON()` | `c.Bind()` + `c.Validate()` |
|
||||
| `c.JSON(status, data)` | `return c.JSON(status, data)` |
|
||||
| `c.Query("key")` | `c.QueryParam("key")` |
|
||||
| `c.DefaultQuery("k", "v")` | Manual: `if v := c.QueryParam("k"); v != "" { } else { v = "default" }` |
|
||||
| `c.PostForm("field")` | `c.FormValue("field")` |
|
||||
| `c.GetHeader("X-...")` | `c.Request().Header.Get("X-...")` |
|
||||
| `c.Request.Context()` | `c.Request().Context()` |
|
||||
| `c.Status(code)` | `return c.NoContent(code)` |
|
||||
| `gin.H{...}` | `map[string]interface{}{...}` |
|
||||
|
||||
## Multipart Form Handling
|
||||
|
||||
For handlers with file uploads (document_handler, task_handler):
|
||||
|
||||
```go
|
||||
// Request parsing
|
||||
c.Request.ParseMultipartForm(32 << 20) // Same
|
||||
c.PostForm("field") → c.FormValue("field")
|
||||
c.FormFile("file") // Same
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Clean up residence_handler.go manually
|
||||
2. Migrate contractor_handler.go (simpler, good template)
|
||||
3. Migrate smaller files: user_handler.go, upload_handler.go, notification_handler.go
|
||||
4. Migrate complex files: task_handler.go, document_handler.go
|
||||
5. Migrate remaining files
|
||||
6. Test compilation
|
||||
7. Update route registration (if not already done in Phase 3)
|
||||
|
||||
## Automation Lessons Learned
|
||||
|
||||
- Sed-based bulk replacements are error-prone for complex Go code
|
||||
- Better approach: Manual migration with copy-paste for repetitive patterns
|
||||
- Python script provided in migrate_handlers.py (not yet tested)
|
||||
- Best approach: Methodical manual migration with validation at each step
|
||||
@@ -1,36 +0,0 @@
|
||||
This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
|
||||
|
||||
## Getting Started
|
||||
|
||||
First, run the development server:
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
# or
|
||||
yarn dev
|
||||
# or
|
||||
pnpm dev
|
||||
# or
|
||||
bun dev
|
||||
```
|
||||
|
||||
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
|
||||
|
||||
You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
|
||||
|
||||
This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
|
||||
|
||||
## Learn More
|
||||
|
||||
To learn more about Next.js, take a look at the following resources:
|
||||
|
||||
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
|
||||
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
|
||||
|
||||
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
|
||||
|
||||
## Deploy on Vercel
|
||||
|
||||
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
|
||||
|
||||
Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.
|
||||
@@ -1,38 +0,0 @@
|
||||
# Digest 1: cmd/, admin/dto, admin/handlers (first 15 files)
|
||||
|
||||
## Systemic Issues (across all admin handlers)
|
||||
- **SQL Injection via SortBy**: Every admin list handler concatenates `filters.SortBy` directly into GORM `Order()` without allowlist validation
|
||||
- **Unchecked Count() errors**: Every paginated handler ignores GORM Count error returns
|
||||
- **Unchecked post-mutation Preload errors**: After Save/Create, handlers reload with Preload but ignore errors
|
||||
- **`binding` vs `validate` tag mismatch**: Some request DTOs use `binding` (Gin) instead of `validate` (Echo)
|
||||
- **Direct DB access**: All admin handlers bypass Service layer, accessing `*gorm.DB` directly
|
||||
- **Unsafe type assertions**: `c.Get(key).(*models.AdminUser)` without comma-ok
|
||||
|
||||
## Per-File Highlights
|
||||
|
||||
### cmd/api/main.go - App entry point, wires dependencies
|
||||
### cmd/worker/main.go - Background worker entry point
|
||||
|
||||
### admin/handlers/admin_user_handler.go (347 lines)
|
||||
- N+1 query: `toUserResponse` does 2 extra DB queries per user (residence count, task count)
|
||||
- Line 64: SortBy SQL injection
|
||||
- Line 173: Unchecked profile creation error (user created without profile)
|
||||
|
||||
### admin/handlers/apple_social_auth_handler.go - CRUD for Apple social auth records
|
||||
- Same systemic SQL injection and unchecked errors
|
||||
|
||||
### admin/handlers/auth_handler.go - Admin login/session management
|
||||
### admin/handlers/auth_token_handler.go - Auth token CRUD
|
||||
### admin/handlers/completion_handler.go - Task completion CRUD
|
||||
### admin/handlers/completion_image_handler.go - Completion image CRUD
|
||||
### admin/handlers/confirmation_code_handler.go - Email confirmation code CRUD
|
||||
### admin/handlers/contractor_handler.go - Contractor CRUD
|
||||
### admin/handlers/dashboard_handler.go - Admin dashboard stats
|
||||
|
||||
### admin/handlers/device_handler.go (317 lines)
|
||||
- Exposes device push tokens (RegistrationID) in API responses
|
||||
- Lines 293-296: Unchecked Count errors in GetStats
|
||||
|
||||
### admin/handlers/document_handler.go (394 lines)
|
||||
- Lines 176-183: Date parsing errors silently ignored
|
||||
- Line 379: Precision loss from decimal.Float64() discarded
|
||||
@@ -1,51 +0,0 @@
|
||||
# Digest 2: admin/handlers (remaining 15 files)
|
||||
|
||||
### admin/handlers/document_image_handler.go (245 lines)
|
||||
- N+1: toResponse queries DB per image in List
|
||||
- Same SortBy SQL injection
|
||||
|
||||
### admin/handlers/feature_benefit_handler.go (231 lines)
|
||||
- `binding` tags instead of `validate` - required fields never enforced
|
||||
|
||||
### admin/handlers/limitations_handler.go (451 lines)
|
||||
- Line 37: Unchecked Create error for default settings
|
||||
- Line 191-197: UpdateTierLimits overwrites ALL fields even for partial updates
|
||||
|
||||
### admin/handlers/lookup_handler.go (877 lines)
|
||||
- **CRITICAL**: Lines 30-32, 50-52, etc.: refreshXxxCache checks `if cache == nil {}` with EMPTY body, then calls cache.CacheXxx() — nil pointer panic when cache is nil
|
||||
- Line 792: Hardcoded join table name "task_contractor_specialties"
|
||||
|
||||
### admin/handlers/notification_handler.go (419 lines)
|
||||
- Line 351-363: HTML template built by string concatenation with user-provided subject/body — XSS in admin emails
|
||||
|
||||
### admin/handlers/notification_prefs_handler.go (347 lines)
|
||||
- Line 154: Unchecked user lookup — deleted user produces zero-value username/email
|
||||
|
||||
### admin/handlers/onboarding_handler.go (343 lines)
|
||||
- Line 304: Internal error details leaked to client
|
||||
|
||||
### admin/handlers/password_reset_code_handler.go (161 lines)
|
||||
- **BUG**: Line 85: `code.ResetToken[:8] + "..." + code.ResetToken[len-4:]` panics if token < 8 chars
|
||||
|
||||
### admin/handlers/promotion_handler.go (304 lines)
|
||||
- `binding` tags: required fields never enforced
|
||||
|
||||
### admin/handlers/residence_handler.go (371 lines)
|
||||
- Lines 121-122: Unchecked Count errors for task/document counts
|
||||
|
||||
### admin/handlers/settings_handler.go (794 lines)
|
||||
- Line 378: Raw SQL execution from seed files (no parameterization)
|
||||
- Line 529-793: ClearAllData is destructive with no double-auth check
|
||||
- Line 536-539: Panic in ClearAllData silently swallowed
|
||||
|
||||
### admin/handlers/share_code_handler.go (225 lines)
|
||||
- Line 155-162: `IsActive` as non-pointer bool — absent field defaults to false, deactivating codes
|
||||
|
||||
### admin/handlers/subscription_handler.go (237 lines)
|
||||
- **BUG**: Line 40-41: JOIN uses "users" but actual table is "auth_user" — query fails on PostgreSQL
|
||||
|
||||
### admin/handlers/task_handler.go (401 lines)
|
||||
- Line 247-296: Admin Create bypasses service layer — no business logic applied
|
||||
|
||||
### admin/handlers/task_template_handler.go (347 lines)
|
||||
- Lines 29-31: Same nil cache panic as lookup_handler.go
|
||||
@@ -1,30 +0,0 @@
|
||||
# Digest 3: admin routes, apperrors, config, database, dto/requests, dto/responses (first half)
|
||||
|
||||
### admin/routes.go (483 lines)
|
||||
- Route ordering: users DELETE "/:id" before "/bulk" — "/bulk" matches as id param
|
||||
- Line 454: Uses os.Getenv instead of Viper config
|
||||
- Line 460-462: url.Parse failure returns silently, no logging
|
||||
- Line 467-469: Proxy errors not surfaced (always returns nil)
|
||||
|
||||
### apperrors/errors.go (98 lines) - Clean. Error types with Wrap/Unwrap.
|
||||
### apperrors/handler.go (67 lines) - c.JSON error returns discarded (minor)
|
||||
|
||||
### config/config.go (427 lines)
|
||||
- Line 339: Hardcoded debug secret key "change-me-in-production-secret-key-12345"
|
||||
- Lines 311-312: Comments say wrong UTC times (says 8PM/9AM, actually 2PM/3PM)
|
||||
|
||||
### database/database.go (468 lines)
|
||||
- **SECURITY**: Line 372-382: Hardcoded bcrypt hash for GoAdmin with password "admin" — migration RESETS password every run
|
||||
- **SECURITY**: Line 447-463: Hardcoded admin@mycrib.com / admin123 — password reset on every migration
|
||||
- Line 182+: Multiple db.Exec errors unchecked for index creation
|
||||
- Line 100-102: WithTransaction coupled to global db variable
|
||||
|
||||
### dto/requests/auth.go (66 lines) - LoginRequest min=1 password (intentional for login)
|
||||
### dto/requests/contractor.go (37 lines) - Rating has no min/max validation bounds
|
||||
### dto/requests/document.go (47 lines) - ImageURLs no length limit, Description no max length
|
||||
### dto/requests/residence.go (59 lines) - Bedrooms/SquareFootage accept negative values, ExpiresInHours no validation
|
||||
### dto/requests/task.go (110 lines) - Rating no bounds, ImageURLs no length limit, CustomIntervalDays no min
|
||||
|
||||
### dto/responses/auth.go (190 lines) - Clean. Proper nil checks on Profile.
|
||||
### dto/responses/contractor.go (131 lines) - TaskCount depends on preloaded Tasks association
|
||||
### dto/responses/document.go (126 lines) - Line 101: CreatedBy accessed as value type (fragile if changed to pointer)
|
||||
@@ -1,48 +0,0 @@
|
||||
# Digest 4: dto/responses (remaining), echohelpers, handlers (first half)
|
||||
|
||||
### dto/responses/residence.go (215 lines) - NewResidenceResponse no nil check on param. Owner zero-value if not preloaded.
|
||||
### dto/responses/task_template.go (135 lines) - No nil check on template param
|
||||
### dto/responses/task.go (399 lines) - No nil checks on params in factory functions
|
||||
### dto/responses/user.go (20 lines) - Clean data types
|
||||
### echohelpers/helpers.go (46 lines) - Clean utilities
|
||||
### echohelpers/pagination.go (33 lines) - Clean, properly bounded
|
||||
|
||||
### handlers/auth_handler.go (379 lines)
|
||||
- **ARCHITECTURE**: Lines 83, 178, 207, 241, 329, 370: SIX goroutine spawns for email — violates "no goroutines in handlers" rule
|
||||
- Line 308-312: AppError constructed directly instead of factory function
|
||||
|
||||
### handlers/contractor_handler.go (154 lines)
|
||||
- Line 28+: Unchecked type assertions throughout (7 instances)
|
||||
- Line 31: Raw err.Error() returned to client
|
||||
- Line 55: CreateContractor missing c.Validate() call
|
||||
|
||||
### handlers/document_handler.go (336 lines)
|
||||
- Line 37+: Unchecked type assertions (10 instances)
|
||||
- Line 92-93: Raw error leaked to client
|
||||
- Line 137: No DocumentType validation — any string accepted
|
||||
- Lines 187, 217: Missing c.Validate() calls
|
||||
|
||||
### handlers/media_handler.go (172 lines)
|
||||
- **SECURITY**: Line 156-171: resolveFilePath uses filepath.Join with user-influenced data — PATH TRAVERSAL vulnerability. TrimPrefix doesn't sanitize ../
|
||||
- Line 19-22: Handler accesses repositories directly, bypasses service layer
|
||||
|
||||
### handlers/notification_handler.go (200 lines)
|
||||
- Line 29-40: No upper bound on limit — unbounded query with limit=999999999
|
||||
- Line 168: Silent default to "ios" platform
|
||||
|
||||
### handlers/residence_handler.go (365 lines)
|
||||
- Line 38+: Unchecked type assertions (14 instances)
|
||||
- Lines 187, 209, 303: Bind errors silently discarded
|
||||
- Line 224: JoinWithCode missing c.Validate()
|
||||
|
||||
### handlers/static_data_handler.go (152 lines) - Uses interface{} instead of concrete types
|
||||
### handlers/subscription_handler.go (176 lines) - Lines 97, 150: Missing c.Validate() calls
|
||||
- Line 159-163: RestoreSubscription doesn't validate receipt/transaction ID presence
|
||||
|
||||
### handlers/subscription_webhook_handler.go (821 lines)
|
||||
- **SECURITY**: Line 190-192: Apple JWS payload decoded WITHOUT signature verification
|
||||
- **SECURITY**: Line 787-793: VerifyGooglePubSubToken ALWAYS returns true — webhook unauthenticated
|
||||
- Line 639-643: Subscription duration guessed by string matching product ID
|
||||
- Line 657, 694: Hardcoded 1-month extension regardless of actual plan
|
||||
- Line 759, 772: Unchecked type assertions in VerifyAppleSignature
|
||||
- Line 162: Apple renewal info error silently discarded
|
||||
@@ -1,45 +0,0 @@
|
||||
# Digest 5: handlers (remaining), i18n, middleware, models (first half)
|
||||
|
||||
### handlers/task_handler.go (440 lines)
|
||||
- Line 35+: Unchecked type assertions (18 locations)
|
||||
- Line 42: Fire-and-forget goroutine for UpdateUserTimezone — no error handling, no context
|
||||
- Lines 112-115, 134-137: Missing c.Validate() calls
|
||||
- Line 317: 32MB multipart limit with no per-file size check
|
||||
|
||||
### handlers/task_template_handler.go (98 lines)
|
||||
- Line 59: No max length on search query — slow LIKE queries possible
|
||||
|
||||
### handlers/tracking_handler.go (46 lines)
|
||||
- Line 25: Package-level base64 decode error discarded
|
||||
- Lines 34-36: Fire-and-forget goroutine — violates no-goroutines rule
|
||||
|
||||
### handlers/upload_handler.go (93 lines)
|
||||
- Line 31: User-controlled `category` param passed to storage — potential path traversal
|
||||
- Line 80: `binding` tag instead of `validate`
|
||||
- No file type or size validation at handler level
|
||||
|
||||
### handlers/user_handler.go (76 lines) - Unchecked type assertions
|
||||
|
||||
### i18n/i18n.go (87 lines)
|
||||
- Line 16: Global Bundle is nil until Init() — NewLocalizer dereferences without nil check
|
||||
- Line 37: MustParseMessageFileBytes panics on malformed translation files
|
||||
- Line 83: MustT panics on missing translations
|
||||
|
||||
### i18n/middleware.go (127 lines) - Clean
|
||||
|
||||
### middleware/admin_auth.go (133 lines)
|
||||
- **SECURITY**: Line 50: Admin JWT accepted via query param — tokens leak into server/proxy logs
|
||||
- Line 124: Unchecked type assertion
|
||||
|
||||
### middleware/auth.go (229 lines)
|
||||
- **BUG**: Line 66: `token[:8]` panics if token is fewer than 8 characters
|
||||
- Line 104: cacheUserID error silently discarded
|
||||
- Line 209: Unchecked type assertion
|
||||
|
||||
### middleware/logger.go (54 lines) - Clean
|
||||
### middleware/request_id.go (44 lines) - Line 21: Client-supplied X-Request-ID accepted without validation (log injection)
|
||||
### middleware/timezone.go (101 lines) - Lines 88, 99: Unchecked type assertions
|
||||
|
||||
### models/admin.go (64 lines) - Line 38: No max password length check; bcrypt truncates at 72 bytes
|
||||
### models/base.go (39 lines) - Clean GORM hooks
|
||||
### models/contractor.go (54 lines) - *float64 mapped to decimal(2,1) — minor precision mismatch
|
||||
@@ -1,40 +0,0 @@
|
||||
# Digest 6: models (remaining), monitoring
|
||||
|
||||
### models/document.go (100 lines) - Clean
|
||||
### models/notification.go (141 lines) - Clean
|
||||
### models/onboarding_email.go (35 lines) - Clean
|
||||
### models/reminder_log.go (92 lines) - Clean
|
||||
|
||||
### models/residence.go (106 lines)
|
||||
- Lines 65-70: GetAllUsers/HasAccess assumes Owner and Users are preloaded — returns wrong results if not
|
||||
|
||||
### models/subscription.go (169 lines)
|
||||
- Lines 57-65: IsActive()/IsPro() don't account for IsFree admin override field — misleading method names
|
||||
|
||||
### models/task_template.go (23 lines) - Clean
|
||||
### models/task.go (317 lines)
|
||||
- Lines 189-264: GetKanbanColumnWithTimezone duplicates categorization chain logic — maintenance drift risk
|
||||
- Lines 158-182: IsDueSoon uses time.Now() internally — non-deterministic, harder to test
|
||||
|
||||
### models/user.go (268 lines)
|
||||
- Line 101: crypto/rand.Read error unchecked (safe in practice since Go 1.20)
|
||||
- Line 164-172: GenerateConfirmationCode has slight distribution bias (negligible)
|
||||
|
||||
### monitoring/buffer.go (166 lines) - Line 75: Corrupted Redis data silently dropped
|
||||
### monitoring/collector.go (201 lines)
|
||||
- Line 82: cpu.Percent blocks 1 second per collection
|
||||
- Line 96-110: ReadMemStats called TWICE per cycle (also in collectRuntime)
|
||||
|
||||
### monitoring/handler.go (195 lines)
|
||||
- **SECURITY**: Line 19-22: WebSocket CheckOrigin always returns true
|
||||
- **BUG**: Line 117-119: After upgrader.Upgrade fails, execution continues to conn.Close() — nil pointer panic
|
||||
- **BUG**: Line 177: Missing return after ctx.Done() — goroutine spins
|
||||
- Lines 183-184: GetAllStats error silently ignored
|
||||
- Line 192: WriteJSON error unchecked
|
||||
|
||||
### monitoring/middleware.go (220 lines) - Clean
|
||||
### monitoring/models.go (129 lines) - Clean
|
||||
|
||||
### monitoring/service.go (196 lines)
|
||||
- Line 121: Hardcoded primary key 1 for singleton settings row
|
||||
- Line 191-194: MetricsMiddleware returns nil when httpCollector is nil — caller must nil-check
|
||||
@@ -1,57 +0,0 @@
|
||||
# Digest 7: monitoring/writer, notifications, push, repositories
|
||||
|
||||
### monitoring/writer.go (96 lines)
|
||||
- Line 90-92: Unbounded fire-and-forget goroutines for Redis push — no rate limiting
|
||||
|
||||
### notifications/reminder_config.go (64 lines) - Clean config data
|
||||
### notifications/reminder_schedule.go (199 lines)
|
||||
- Line 112: Integer truncation of float division — DST off-by-one possible
|
||||
- Line 148-161: Custom itoa reimplements strconv.Itoa
|
||||
|
||||
### push/apns.go (209 lines)
|
||||
- Line 44: Double-negative logic — both Production=false and Sandbox=false defaults to production
|
||||
|
||||
### push/client.go (158 lines)
|
||||
- Line 89-105: SendToAll last-error-wins — cannot tell which platform failed
|
||||
- Line 150-157: HealthCheck always returns nil — useless health check
|
||||
|
||||
### push/fcm.go (140 lines)
|
||||
- Line 16: Legacy FCM HTTP API (deprecated by Google)
|
||||
- Line 119-126: If FCM returns fewer results than tokens, index out of bounds panic
|
||||
|
||||
### repositories/admin_repo.go (108 lines)
|
||||
- Line 92: Negative page produces negative offset
|
||||
|
||||
### repositories/contractor_repo.go (166 lines)
|
||||
- **RACE**: Line 89-101: ToggleFavorite read-then-write without transaction
|
||||
- Line 91: ToggleFavorite doesn't filter is_active — can toggle deleted contractors
|
||||
|
||||
### repositories/document_repo.go (201 lines)
|
||||
- Line 92: LIKE wildcards in user input not escaped
|
||||
- Line 12: DocumentFilter.ResidenceID field defined but never used
|
||||
|
||||
### repositories/notification_repo.go (267 lines)
|
||||
- **RACE**: Line 137-161: GetOrCreatePreferences race — concurrent calls both create, duplicate key error
|
||||
- Line 143: Uses == instead of errors.Is for ErrRecordNotFound
|
||||
|
||||
### repositories/reminder_repo.go (126 lines)
|
||||
- Line 115-122: rows.Err() not checked after iteration loop
|
||||
|
||||
### repositories/residence_repo.go (344 lines)
|
||||
- Line 272: DeactivateShareCode error silently ignored
|
||||
- Line 298-301: Count error unchecked in generateUniqueCode — potential duplicate codes
|
||||
- Line 125-128: PostgreSQL-specific ON CONFLICT — fails on SQLite in tests
|
||||
|
||||
### repositories/subscription_repo.go (257 lines)
|
||||
- **RACE**: Line 40: GetOrCreate race condition (same as notification_repo)
|
||||
- Line 66: GORM v1 pattern `gorm:query_option` for FOR UPDATE — may not work in GORM v2
|
||||
- Line 129: LIKE search on receipt data blobs — inefficient, no index
|
||||
- Lines 40, 168, 196: Uses == instead of errors.Is
|
||||
|
||||
### repositories/task_repo.go (765 lines)
|
||||
- Line 707-709: DeleteCompletion ignores image deletion error
|
||||
- Line 62-101: applyFilterOptions applies NO scope when no filter set — could query all tasks
|
||||
|
||||
### repositories/task_template_repo.go (124 lines)
|
||||
- Line 48: LIKE wildcard escape issue
|
||||
- Line 79-81: Save without Omit could corrupt Category/Frequency lookup data
|
||||
@@ -1,35 +0,0 @@
|
||||
# Digest 8: repositories (remaining), router, services (first half)
|
||||
|
||||
### repositories/user_repo.go - Standard GORM CRUD
|
||||
### repositories/webhook_event_repo.go - Webhook event storage
|
||||
|
||||
### router/router.go - Route registration wiring
|
||||
|
||||
### services/apple_auth.go - Apple Sign In JWT validation
|
||||
### services/auth_service.go - Token management, password hashing, email verification
|
||||
|
||||
### services/cache_service.go - Redis caching for lookups
|
||||
### services/contractor_service.go - Contractor CRUD via repository
|
||||
|
||||
### services/document_service.go - Document management
|
||||
### services/email_service.go - SMTP email sending
|
||||
|
||||
### services/google_auth.go - Google OAuth token validation
|
||||
### services/iap_validation.go - Apple/Google receipt validation
|
||||
|
||||
### services/notification_service.go - Push notifications, preferences
|
||||
|
||||
### services/onboarding_email_service.go (371 lines)
|
||||
- **ARCHITECTURE**: Direct *gorm.DB access — bypasses repository layer entirely
|
||||
- Line 43-46: HasSentEmail ignores Count error — could send duplicate emails
|
||||
- Line 128-133: GetEmailStats ignores 4 Count errors
|
||||
- Line 170: Raw SQL references "auth_user" table
|
||||
- Line 354: Delete error silently ignored
|
||||
|
||||
### services/pdf_service.go (179 lines)
|
||||
- **BUG**: Line 131-133: Byte-level truncation of title — breaks multi-byte UTF-8 (CJK, emoji)
|
||||
|
||||
### services/residence_service.go (648 lines)
|
||||
- Line 155: TODO comment — subscription tier limit check commented out (free tier bypass)
|
||||
- Line 447-450: Empty if block — DeactivateShareCode error completely ignored
|
||||
- Line 625: Status only set for in-progress tasks; all others have empty string
|
||||
@@ -1,49 +0,0 @@
|
||||
# Digest 9: services (remaining), task package, testutil, validator, worker, pkg
|
||||
|
||||
### services/storage_service.go (184 lines)
|
||||
- Line 75: UUID truncated to 8 chars — increased collision risk
|
||||
- **SECURITY**: Line 137-138: filepath.Abs errors ignored — path traversal check could be bypassed
|
||||
|
||||
### services/subscription_service.go (659 lines)
|
||||
- **PERFORMANCE**: Line 186-204: N+1 queries in getUserUsage — 3 queries per residence
|
||||
- **SECURITY/BUSINESS**: Line 371: Apple validation failure grants 1-month free Pro
|
||||
- **SECURITY/BUSINESS**: Line 381: Apple validation not configured grants 1-year free Pro
|
||||
- **SECURITY/BUSINESS**: Line 429, 449: Same for Google — errors/misconfiguration grant free Pro
|
||||
- Line 7: Uses stdlib "log" instead of zerolog
|
||||
|
||||
### services/task_button_types.go (85 lines) - Clean, uses predicates correctly
|
||||
### services/task_service.go (1092 lines)
|
||||
- **DATA INTEGRITY**: Line 601: If task update fails after completion creation, error only logged not returned — stale NextDueDate/InProgress
|
||||
- Line 735: Goroutine in QuickComplete (service method) — inconsistent with synchronous CreateCompletion
|
||||
- Line 773: Unbounded goroutine creation per user for notifications
|
||||
- Line 790: Fail-open email notification on error (intentional but risky)
|
||||
- **SECURITY**: Line 857-862: resolveImageFilePath has NO path traversal validation
|
||||
|
||||
### services/task_template_service.go (70 lines) - Errors returned raw (not wrapped with apperrors)
|
||||
### services/user_service.go (88 lines) - Returns nil instead of empty slice (JSON null vs [])
|
||||
|
||||
### task/categorization/chain.go (359 lines) - Clean chain-of-responsibility
|
||||
### task/predicates/predicates.go (217 lines)
|
||||
- Line 210: IsRecurring requires Frequency preloaded — returns false without it
|
||||
|
||||
### task/scopes/scopes.go (270 lines)
|
||||
- Line 118: ScopeOverdue doesn't exclude InProgress — differs from categorization chain
|
||||
|
||||
### task/task.go (261 lines) - Clean facade re-exports
|
||||
|
||||
### testutil/testutil.go (359 lines)
|
||||
- Line 86: json.Marshal error ignored in MakeRequest
|
||||
- Line 92: http.NewRequest error ignored
|
||||
|
||||
### validator/validator.go (103 lines) - Clean
|
||||
|
||||
### worker/jobs/email_jobs.go (118 lines) - Clean
|
||||
### worker/jobs/handler.go (810 lines)
|
||||
- Lines 95-106, 193-204: Direct DB access bypasses repository layer
|
||||
- Line 627-635: Raw SQL with fmt.Sprintf (not currently user-supplied but fragile)
|
||||
- Line 154, 251: O(N*M) lookup instead of map
|
||||
|
||||
### worker/scheduler.go (240 lines)
|
||||
- Line 200-212: Cron schedules at fixed UTC times may conflict with smart reminder system — potential duplicate notifications
|
||||
|
||||
### pkg/utils/logger.go (132 lines) - Panic recovery bypasses apperrors.HTTPErrorHandler
|
||||
@@ -1,417 +0,0 @@
|
||||
# Gin to Echo Framework Migration Guide
|
||||
|
||||
This document outlines the migration of the MyCrib Go API from Gin to Echo v4 with direct go-playground/validator integration.
|
||||
|
||||
## Overview
|
||||
|
||||
| Aspect | Before | After |
|
||||
|--------|--------|-------|
|
||||
| Framework | Gin v1.10 | Echo v4.11 |
|
||||
| Validation | Gin's binding wrapper | Direct go-playground/validator |
|
||||
| Validation tags | `binding:"..."` | `validate:"..."` |
|
||||
| Error format | Inconsistent | Structured field-level |
|
||||
|
||||
## Scope
|
||||
|
||||
- **56 files** requiring modification
|
||||
- **110+ routes** across public and admin APIs
|
||||
- **45 handlers** (17 core + 28 admin)
|
||||
- **5 middleware** files
|
||||
|
||||
---
|
||||
|
||||
## API Mapping Reference
|
||||
|
||||
### Context Methods
|
||||
|
||||
| Gin | Echo | Notes |
|
||||
|-----|------|-------|
|
||||
| `c.ShouldBindJSON(&req)` | `c.Bind(&req)` | Bind only, no validation |
|
||||
| `c.ShouldBindJSON(&req)` | `c.Validate(&req)` | Validation only (call after Bind) |
|
||||
| `c.Param("id")` | `c.Param("id")` | Same |
|
||||
| `c.Query("name")` | `c.QueryParam("name")` | Different method name |
|
||||
| `c.DefaultQuery("k","v")` | Custom helper | No built-in equivalent |
|
||||
| `c.PostForm("field")` | `c.FormValue("field")` | Different method name |
|
||||
| `c.FormFile("file")` | `c.FormFile("file")` | Same |
|
||||
| `c.Get(key)` | `c.Get(key)` | Same |
|
||||
| `c.Set(key, val)` | `c.Set(key, val)` | Same |
|
||||
| `c.MustGet(key)` | `c.Get(key)` | No MustGet, add nil check |
|
||||
| `c.JSON(status, obj)` | `return c.JSON(status, obj)` | Must return |
|
||||
| `c.AbortWithStatusJSON()` | `return c.JSON()` | Return-based flow |
|
||||
| `c.Status(200)` | `return c.NoContent(200)` | Different method |
|
||||
| `c.GetHeader("X-...")` | `c.Request().Header.Get()` | Access via Request |
|
||||
| `c.ClientIP()` | `c.RealIP()` | Different method name |
|
||||
| `c.File(path)` | `return c.File(path)` | Must return |
|
||||
| `gin.H{...}` | `echo.Map{...}` | Or `map[string]any{}` |
|
||||
|
||||
### Handler Signature
|
||||
|
||||
```go
|
||||
// Gin - void return
|
||||
func (h *Handler) Method(c *gin.Context) {
|
||||
// ...
|
||||
c.JSON(200, response)
|
||||
}
|
||||
|
||||
// Echo - error return (MUST return)
|
||||
func (h *Handler) Method(c echo.Context) error {
|
||||
// ...
|
||||
return c.JSON(200, response)
|
||||
}
|
||||
```
|
||||
|
||||
### Middleware Signature
|
||||
|
||||
```go
|
||||
// Gin
|
||||
func MyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// before
|
||||
c.Next()
|
||||
// after
|
||||
}
|
||||
}
|
||||
|
||||
// Echo
|
||||
func MyMiddleware() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// before
|
||||
err := next(c)
|
||||
// after
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Validation Changes
|
||||
|
||||
### Tag Migration
|
||||
|
||||
Change all `binding:` tags to `validate:` tags:
|
||||
|
||||
```go
|
||||
// Before
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=8"`
|
||||
}
|
||||
|
||||
// After
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
Password string `json:"password" validate:"required,min=8"`
|
||||
}
|
||||
```
|
||||
|
||||
### Supported Validation Tags
|
||||
|
||||
| Tag | Description |
|
||||
|-----|-------------|
|
||||
| `required` | Field must be present and non-zero |
|
||||
| `required_without=Field` | Required if other field is empty |
|
||||
| `omitempty` | Skip validation if empty |
|
||||
| `email` | Must be valid email format |
|
||||
| `min=N` | Minimum length/value |
|
||||
| `max=N` | Maximum length/value |
|
||||
| `len=N` | Exact length |
|
||||
| `oneof=a b c` | Must be one of listed values |
|
||||
| `url` | Must be valid URL |
|
||||
| `uuid` | Must be valid UUID |
|
||||
|
||||
### New Error Response Format
|
||||
|
||||
```json
|
||||
{
|
||||
"error": "Validation failed",
|
||||
"fields": {
|
||||
"email": {
|
||||
"message": "Must be a valid email address",
|
||||
"tag": "email"
|
||||
},
|
||||
"password": {
|
||||
"message": "Must be at least 8 characters",
|
||||
"tag": "min"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Mobile clients must update** error parsing to handle the `fields` object.
|
||||
|
||||
---
|
||||
|
||||
## Handler Migration Pattern
|
||||
|
||||
### Before (Gin)
|
||||
|
||||
```go
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req requests.LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, responses.ErrorResponse{
|
||||
Error: "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user := c.MustGet(middleware.AuthUserKey).(*models.User)
|
||||
|
||||
response, err := h.authService.Login(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
```
|
||||
|
||||
### After (Echo)
|
||||
|
||||
```go
|
||||
func (h *AuthHandler) Login(c echo.Context) error {
|
||||
var req requests.LoginRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, responses.ErrorResponse{
|
||||
Error: "Invalid request body",
|
||||
})
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
|
||||
response, err := h.authService.Login(&req)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusUnauthorized, echo.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Middleware Migration Examples
|
||||
|
||||
### Auth Middleware
|
||||
|
||||
```go
|
||||
// Before (Gin)
|
||||
func (m *AuthMiddleware) TokenAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token, err := extractToken(c)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
user, err := m.validateToken(token)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
return
|
||||
}
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// After (Echo)
|
||||
func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
token, err := extractToken(c)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
||||
}
|
||||
user, err := m.validateToken(token)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Invalid token"})
|
||||
}
|
||||
c.Set(AuthUserKey, user)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Timezone Middleware
|
||||
|
||||
```go
|
||||
// Before (Gin)
|
||||
func TimezoneMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tzName := c.GetHeader(TimezoneHeader)
|
||||
loc := parseTimezone(tzName)
|
||||
c.Set(TimezoneKey, loc)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// After (Echo)
|
||||
func TimezoneMiddleware() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
tzName := c.Request().Header.Get(TimezoneHeader)
|
||||
loc := parseTimezone(tzName)
|
||||
c.Set(TimezoneKey, loc)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Router Setup
|
||||
|
||||
### Before (Gin)
|
||||
|
||||
```go
|
||||
func SetupRouter(deps *Dependencies) *gin.Engine {
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(gin.Logger())
|
||||
r.Use(cors.New(cors.Config{...}))
|
||||
|
||||
api := r.Group("/api")
|
||||
api.POST("/auth/login/", authHandler.Login)
|
||||
|
||||
protected := api.Group("")
|
||||
protected.Use(authMiddleware.TokenAuth())
|
||||
protected.GET("/residences/", residenceHandler.List)
|
||||
|
||||
return r
|
||||
}
|
||||
```
|
||||
|
||||
### After (Echo)
|
||||
|
||||
```go
|
||||
func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
e := echo.New()
|
||||
e.HideBanner = true
|
||||
e.Validator = validator.NewCustomValidator()
|
||||
|
||||
// Trailing slash handling
|
||||
e.Pre(middleware.AddTrailingSlash())
|
||||
|
||||
e.Use(middleware.Recover())
|
||||
e.Use(middleware.Logger())
|
||||
e.Use(middleware.CORSWithConfig(middleware.CORSConfig{...}))
|
||||
|
||||
api := e.Group("/api")
|
||||
api.POST("/auth/login/", authHandler.Login)
|
||||
|
||||
protected := api.Group("")
|
||||
protected.Use(authMiddleware.TokenAuth())
|
||||
protected.GET("/residences/", residenceHandler.List)
|
||||
|
||||
return e
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Helper Functions
|
||||
|
||||
### DefaultQuery Helper
|
||||
|
||||
```go
|
||||
// internal/echohelpers/helpers.go
|
||||
func DefaultQuery(c echo.Context, key, defaultValue string) string {
|
||||
if val := c.QueryParam(key); val != "" {
|
||||
return val
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
```
|
||||
|
||||
### Safe Context Get
|
||||
|
||||
```go
|
||||
// internal/middleware/helpers.go
|
||||
func GetAuthUser(c echo.Context) *models.User {
|
||||
val := c.Get(AuthUserKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
user, ok := val.(*models.User)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func MustGetAuthUser(c echo.Context) (*models.User, error) {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return nil, echo.NewHTTPError(http.StatusUnauthorized, "Authentication required")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing Changes
|
||||
|
||||
### Test Setup
|
||||
|
||||
```go
|
||||
// Before (Gin)
|
||||
func SetupTestRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
return gin.New()
|
||||
}
|
||||
|
||||
// After (Echo)
|
||||
func SetupTestRouter() *echo.Echo {
|
||||
e := echo.New()
|
||||
e.Validator = validator.NewCustomValidator()
|
||||
return e
|
||||
}
|
||||
```
|
||||
|
||||
### Making Test Requests
|
||||
|
||||
```go
|
||||
// Before (Gin)
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/auth/login/", body)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// After (Echo) - Same pattern works
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/auth/login/", body)
|
||||
e.ServeHTTP(rec, req)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Important Notes
|
||||
|
||||
1. **All handlers must return error** - Echo uses return-based flow
|
||||
2. **Trailing slashes** - Use `middleware.AddTrailingSlash()` to maintain API compatibility
|
||||
3. **Type assertions** - Always add nil checks when using `c.Get()`
|
||||
4. **CORS** - Use Echo's built-in CORS middleware
|
||||
5. **Bind vs Validate** - Echo separates these; call both for full validation
|
||||
|
||||
---
|
||||
|
||||
## Files Modified
|
||||
|
||||
| Category | Count |
|
||||
|----------|-------|
|
||||
| New files | 2 (`validator/`, `echohelpers/`) |
|
||||
| DTOs | 6 |
|
||||
| Middleware | 5 |
|
||||
| Core handlers | 14 |
|
||||
| Admin handlers | 28 |
|
||||
| Router | 2 |
|
||||
| Tests | 6 |
|
||||
| **Total** | **63** |
|
||||
@@ -2350,6 +2350,121 @@ paths:
|
||||
'401':
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
|
||||
/subscription/checkout/:
|
||||
post:
|
||||
tags: [Subscriptions]
|
||||
operationId: createCheckoutSession
|
||||
summary: Create a Stripe Checkout session for web subscription purchase
|
||||
security:
|
||||
- tokenAuth: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required: [price_id, success_url, cancel_url]
|
||||
properties:
|
||||
price_id:
|
||||
type: string
|
||||
description: Stripe Price ID
|
||||
success_url:
|
||||
type: string
|
||||
format: uri
|
||||
cancel_url:
|
||||
type: string
|
||||
format: uri
|
||||
responses:
|
||||
'200':
|
||||
description: Checkout session created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
checkout_url:
|
||||
type: string
|
||||
format: uri
|
||||
'400':
|
||||
$ref: '#/components/responses/Error'
|
||||
'401':
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
'409':
|
||||
description: Already subscribed on another platform
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
existing_platform:
|
||||
type: string
|
||||
message:
|
||||
type: string
|
||||
|
||||
/subscription/portal/:
|
||||
post:
|
||||
tags: [Subscriptions]
|
||||
operationId: createPortalSession
|
||||
summary: Create a Stripe Customer Portal session for managing web subscriptions
|
||||
security:
|
||||
- tokenAuth: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required: [return_url]
|
||||
properties:
|
||||
return_url:
|
||||
type: string
|
||||
format: uri
|
||||
responses:
|
||||
'200':
|
||||
description: Portal session created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
portal_url:
|
||||
type: string
|
||||
format: uri
|
||||
'400':
|
||||
$ref: '#/components/responses/Error'
|
||||
'401':
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
|
||||
/subscription/webhook/stripe/:
|
||||
post:
|
||||
tags: [Subscriptions]
|
||||
operationId: handleStripeWebhook
|
||||
summary: Handle Stripe webhook events (server-to-server)
|
||||
description: |
|
||||
Receives Stripe webhook events for subscription lifecycle management.
|
||||
Verifies the webhook signature using the configured signing secret.
|
||||
No auth token required — uses Stripe signature verification.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
responses:
|
||||
'200':
|
||||
description: Webhook processed successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
received:
|
||||
type: boolean
|
||||
'400':
|
||||
$ref: '#/components/responses/Error'
|
||||
|
||||
# ===========================================================================
|
||||
# Uploads
|
||||
# ===========================================================================
|
||||
@@ -4434,6 +4549,12 @@ components:
|
||||
SubscriptionStatusResponse:
|
||||
type: object
|
||||
properties:
|
||||
tier:
|
||||
type: string
|
||||
description: 'Subscription tier (free or pro)'
|
||||
is_active:
|
||||
type: boolean
|
||||
description: Whether the subscription is currently active
|
||||
subscribed_at:
|
||||
type: string
|
||||
format: date-time
|
||||
@@ -4444,6 +4565,20 @@ components:
|
||||
nullable: true
|
||||
auto_renew:
|
||||
type: boolean
|
||||
trial_start:
|
||||
type: string
|
||||
format: date-time
|
||||
nullable: true
|
||||
trial_end:
|
||||
type: string
|
||||
format: date-time
|
||||
nullable: true
|
||||
trial_active:
|
||||
type: boolean
|
||||
subscription_source:
|
||||
type: string
|
||||
nullable: true
|
||||
description: 'Platform source of the active subscription (ios, android, stripe, or null)'
|
||||
usage:
|
||||
$ref: '#/components/schemas/UsageResponse'
|
||||
limits:
|
||||
|
||||
1
go.mod
1
go.mod
@@ -71,6 +71,7 @@ require (
|
||||
github.com/spf13/afero v1.14.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/stripe/stripe-go/v81 v81.4.0 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
|
||||
7
go.sum
7
go.sum
@@ -156,6 +156,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
|
||||
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
@@ -188,6 +190,7 @@ golang.org/x/crypto v0.0.0-20170512130425-ab89591268e0/go.mod h1:6SG95UA2DQfeDnf
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
@@ -196,7 +199,9 @@ golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwE
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -206,8 +211,10 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
# Go Backend Hardening Audit Report
|
||||
|
||||
## Audit Sources
|
||||
- 9 mapper agents (100% file coverage)
|
||||
- 8 specialized domain auditors (parallel)
|
||||
- 1 cross-cutting deep audit (parallel)
|
||||
- Total source files: 136 (excluding 27 test files)
|
||||
|
||||
---
|
||||
|
||||
## CRITICAL — Will crash or lose data
|
||||
|
||||
## BUG — Incorrect behavior
|
||||
|
||||
## SILENT FAILURE — Error swallowed or ignored
|
||||
|
||||
## RACE CONDITION — Concurrency issue
|
||||
|
||||
## LOGIC ERROR — Code doesn't match intent
|
||||
|
||||
## PERFORMANCE — Unnecessary cost
|
||||
|
||||
## SECURITY — Vulnerability or exposure
|
||||
|
||||
## AUTHORIZATION — Access control gap
|
||||
|
||||
## DATA INTEGRITY — GORM / database issue
|
||||
|
||||
## API CONTRACT — Request/response issue
|
||||
|
||||
## ARCHITECTURE — Layer or pattern violation
|
||||
|
||||
## FRAGILE — Works now but will break easily
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
@@ -234,11 +234,13 @@ type UpdateNotificationRequest struct {
|
||||
// SubscriptionFilters holds subscription-specific filter parameters
|
||||
type SubscriptionFilters struct {
|
||||
PaginationParams
|
||||
UserID *uint `form:"user_id"`
|
||||
Tier *string `form:"tier"`
|
||||
Platform *string `form:"platform"`
|
||||
AutoRenew *bool `form:"auto_renew"`
|
||||
Active *bool `form:"active"`
|
||||
UserID *uint `form:"user_id"`
|
||||
Tier *string `form:"tier"`
|
||||
Platform *string `form:"platform"`
|
||||
AutoRenew *bool `form:"auto_renew"`
|
||||
Active *bool `form:"active"`
|
||||
HasStripe *bool `form:"has_stripe"`
|
||||
TrialActive *bool `form:"trial_active"`
|
||||
}
|
||||
|
||||
// UpdateSubscriptionRequest for updating a subscription
|
||||
@@ -250,6 +252,14 @@ type UpdateSubscriptionRequest struct {
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
CancelledAt *string `json:"cancelled_at"`
|
||||
// Stripe fields
|
||||
StripeCustomerID *string `json:"stripe_customer_id"`
|
||||
StripeSubscriptionID *string `json:"stripe_subscription_id"`
|
||||
StripePriceID *string `json:"stripe_price_id"`
|
||||
// Trial fields
|
||||
TrialStart *string `json:"trial_start"`
|
||||
TrialEnd *string `json:"trial_end"`
|
||||
TrialUsed *bool `json:"trial_used"`
|
||||
}
|
||||
|
||||
// CreateResidenceRequest for creating a new residence
|
||||
|
||||
@@ -264,7 +264,16 @@ type SubscriptionResponse struct {
|
||||
SubscribedAt *string `json:"subscribed_at,omitempty"`
|
||||
ExpiresAt *string `json:"expires_at,omitempty"`
|
||||
CancelledAt *string `json:"cancelled_at,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
// Stripe fields
|
||||
StripeCustomerID *string `json:"stripe_customer_id,omitempty"`
|
||||
StripeSubscriptionID *string `json:"stripe_subscription_id,omitempty"`
|
||||
StripePriceID *string `json:"stripe_price_id,omitempty"`
|
||||
// Trial fields
|
||||
TrialStart *string `json:"trial_start,omitempty"`
|
||||
TrialEnd *string `json:"trial_end,omitempty"`
|
||||
TrialUsed bool `json:"trial_used"`
|
||||
TrialActive bool `json:"trial_active"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// SubscriptionDetailResponse includes more details for single subscription view
|
||||
|
||||
@@ -30,6 +30,8 @@ func NewAdminSettingsHandler(db *gorm.DB) *AdminSettingsHandler {
|
||||
type SettingsResponse struct {
|
||||
EnableLimitations bool `json:"enable_limitations"`
|
||||
EnableMonitoring bool `json:"enable_monitoring"`
|
||||
TrialEnabled bool `json:"trial_enabled"`
|
||||
TrialDurationDays int `json:"trial_duration_days"`
|
||||
}
|
||||
|
||||
// GetSettings handles GET /api/admin/settings
|
||||
@@ -38,7 +40,13 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
|
||||
if err := h.db.First(&settings, 1).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
// Create default settings
|
||||
settings = models.SubscriptionSettings{ID: 1, EnableLimitations: false, EnableMonitoring: true}
|
||||
settings = models.SubscriptionSettings{
|
||||
ID: 1,
|
||||
EnableLimitations: false,
|
||||
EnableMonitoring: true,
|
||||
TrialEnabled: true,
|
||||
TrialDurationDays: 14,
|
||||
}
|
||||
h.db.Create(&settings)
|
||||
} else {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
|
||||
@@ -48,6 +56,8 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, SettingsResponse{
|
||||
EnableLimitations: settings.EnableLimitations,
|
||||
EnableMonitoring: settings.EnableMonitoring,
|
||||
TrialEnabled: settings.TrialEnabled,
|
||||
TrialDurationDays: settings.TrialDurationDays,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -55,6 +65,8 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
|
||||
type UpdateSettingsRequest struct {
|
||||
EnableLimitations *bool `json:"enable_limitations"`
|
||||
EnableMonitoring *bool `json:"enable_monitoring"`
|
||||
TrialEnabled *bool `json:"trial_enabled"`
|
||||
TrialDurationDays *int `json:"trial_duration_days"`
|
||||
}
|
||||
|
||||
// UpdateSettings handles PUT /api/admin/settings
|
||||
@@ -67,7 +79,12 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
|
||||
var settings models.SubscriptionSettings
|
||||
if err := h.db.First(&settings, 1).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
settings = models.SubscriptionSettings{ID: 1, EnableMonitoring: true}
|
||||
settings = models.SubscriptionSettings{
|
||||
ID: 1,
|
||||
EnableMonitoring: true,
|
||||
TrialEnabled: true,
|
||||
TrialDurationDays: 14,
|
||||
}
|
||||
} else {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
|
||||
}
|
||||
@@ -81,6 +98,14 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
|
||||
settings.EnableMonitoring = *req.EnableMonitoring
|
||||
}
|
||||
|
||||
if req.TrialEnabled != nil {
|
||||
settings.TrialEnabled = *req.TrialEnabled
|
||||
}
|
||||
|
||||
if req.TrialDurationDays != nil {
|
||||
settings.TrialDurationDays = *req.TrialDurationDays
|
||||
}
|
||||
|
||||
if err := h.db.Save(&settings).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update settings"})
|
||||
}
|
||||
@@ -88,6 +113,8 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, SettingsResponse{
|
||||
EnableLimitations: settings.EnableLimitations,
|
||||
EnableMonitoring: settings.EnableMonitoring,
|
||||
TrialEnabled: settings.TrialEnabled,
|
||||
TrialDurationDays: settings.TrialDurationDays,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package handlers
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"gorm.io/gorm"
|
||||
@@ -61,6 +62,20 @@ func (h *AdminSubscriptionHandler) List(c echo.Context) error {
|
||||
query = query.Where("expires_at IS NOT NULL AND expires_at <= NOW()")
|
||||
}
|
||||
}
|
||||
if filters.HasStripe != nil {
|
||||
if *filters.HasStripe {
|
||||
query = query.Where("stripe_subscription_id IS NOT NULL")
|
||||
} else {
|
||||
query = query.Where("stripe_subscription_id IS NULL")
|
||||
}
|
||||
}
|
||||
if filters.TrialActive != nil {
|
||||
if *filters.TrialActive {
|
||||
query = query.Where("trial_end IS NOT NULL AND trial_end > NOW()")
|
||||
} else {
|
||||
query = query.Where("trial_end IS NULL OR trial_end <= NOW()")
|
||||
}
|
||||
}
|
||||
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
@@ -137,6 +152,32 @@ func (h *AdminSubscriptionHandler) Update(c echo.Context) error {
|
||||
if req.IsFree != nil {
|
||||
subscription.IsFree = *req.IsFree
|
||||
}
|
||||
if req.StripeCustomerID != nil {
|
||||
subscription.StripeCustomerID = req.StripeCustomerID
|
||||
}
|
||||
if req.StripeSubscriptionID != nil {
|
||||
subscription.StripeSubscriptionID = req.StripeSubscriptionID
|
||||
}
|
||||
if req.StripePriceID != nil {
|
||||
subscription.StripePriceID = req.StripePriceID
|
||||
}
|
||||
if req.TrialStart != nil {
|
||||
parsed, err := time.Parse(time.RFC3339, *req.TrialStart)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid trial_start format, expected RFC3339"})
|
||||
}
|
||||
subscription.TrialStart = &parsed
|
||||
}
|
||||
if req.TrialEnd != nil {
|
||||
parsed, err := time.Parse(time.RFC3339, *req.TrialEnd)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid trial_end format, expected RFC3339"})
|
||||
}
|
||||
subscription.TrialEnd = &parsed
|
||||
}
|
||||
if req.TrialUsed != nil {
|
||||
subscription.TrialUsed = *req.TrialUsed
|
||||
}
|
||||
|
||||
if err := h.db.Save(&subscription).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update subscription"})
|
||||
@@ -184,29 +225,39 @@ func (h *AdminSubscriptionHandler) GetByUser(c echo.Context) error {
|
||||
// GetStats handles GET /api/admin/subscriptions/stats
|
||||
func (h *AdminSubscriptionHandler) GetStats(c echo.Context) error {
|
||||
var total, free, premium, pro int64
|
||||
var stripeSubscribers, activeTrials int64
|
||||
|
||||
h.db.Model(&models.UserSubscription{}).Count(&total)
|
||||
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "free").Count(&free)
|
||||
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "premium").Count(&premium)
|
||||
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "pro").Count(&pro)
|
||||
h.db.Model(&models.UserSubscription{}).Where("stripe_subscription_id IS NOT NULL AND tier = ?", "pro").Count(&stripeSubscribers)
|
||||
h.db.Model(&models.UserSubscription{}).Where("trial_end IS NOT NULL AND trial_end > NOW()").Count(&activeTrials)
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"total": total,
|
||||
"free": free,
|
||||
"premium": premium,
|
||||
"pro": pro,
|
||||
"total": total,
|
||||
"free": free,
|
||||
"premium": premium,
|
||||
"pro": pro,
|
||||
"stripe_subscribers": stripeSubscribers,
|
||||
"active_trials": activeTrials,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AdminSubscriptionHandler) toSubscriptionResponse(sub *models.UserSubscription) dto.SubscriptionResponse {
|
||||
response := dto.SubscriptionResponse{
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
Tier: string(sub.Tier),
|
||||
Platform: sub.Platform,
|
||||
AutoRenew: sub.AutoRenew,
|
||||
IsFree: sub.IsFree,
|
||||
CreatedAt: sub.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
Tier: string(sub.Tier),
|
||||
Platform: sub.Platform,
|
||||
AutoRenew: sub.AutoRenew,
|
||||
IsFree: sub.IsFree,
|
||||
StripeCustomerID: sub.StripeCustomerID,
|
||||
StripeSubscriptionID: sub.StripeSubscriptionID,
|
||||
StripePriceID: sub.StripePriceID,
|
||||
TrialUsed: sub.TrialUsed,
|
||||
TrialActive: sub.IsTrialActive(),
|
||||
CreatedAt: sub.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
|
||||
if sub.User.ID != 0 {
|
||||
@@ -225,6 +276,14 @@ func (h *AdminSubscriptionHandler) toSubscriptionResponse(sub *models.UserSubscr
|
||||
cancelledAt := sub.CancelledAt.Format("2006-01-02T15:04:05Z")
|
||||
response.CancelledAt = &cancelledAt
|
||||
}
|
||||
if sub.TrialStart != nil {
|
||||
trialStart := sub.TrialStart.Format(time.RFC3339)
|
||||
response.TrialStart = &trialStart
|
||||
}
|
||||
if sub.TrialEnd != nil {
|
||||
trialEnd := sub.TrialEnd.Format(time.RFC3339)
|
||||
response.TrialEnd = &trialEnd
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ type Config struct {
|
||||
GoogleAuth GoogleAuthConfig
|
||||
AppleIAP AppleIAPConfig
|
||||
GoogleIAP GoogleIAPConfig
|
||||
Stripe StripeConfig
|
||||
Features FeatureFlags
|
||||
}
|
||||
|
||||
@@ -104,6 +105,14 @@ type GoogleIAPConfig struct {
|
||||
PackageName string // Android package name (e.g., com.tt.casera)
|
||||
}
|
||||
|
||||
// StripeConfig holds Stripe payment configuration
|
||||
type StripeConfig struct {
|
||||
SecretKey string // Stripe secret API key
|
||||
WebhookSecret string // Stripe webhook endpoint signing secret
|
||||
PriceMonthly string // Stripe Price ID for monthly Pro subscription
|
||||
PriceYearly string // Stripe Price ID for yearly Pro subscription
|
||||
}
|
||||
|
||||
type WorkerConfig struct {
|
||||
// Scheduled job times (UTC)
|
||||
TaskReminderHour int
|
||||
@@ -248,6 +257,12 @@ func Load() (*Config, error) {
|
||||
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
|
||||
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
|
||||
},
|
||||
Stripe: StripeConfig{
|
||||
SecretKey: viper.GetString("STRIPE_SECRET_KEY"),
|
||||
WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"),
|
||||
PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"),
|
||||
PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"),
|
||||
},
|
||||
Features: FeatureFlags{
|
||||
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
|
||||
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),
|
||||
|
||||
@@ -240,7 +240,7 @@ func TestSubscriptionHandler_NoAuth_Returns401(t *testing.T) {
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
|
||||
handler := NewSubscriptionHandler(subscriptionService)
|
||||
handler := NewSubscriptionHandler(subscriptionService, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
|
||||
@@ -13,11 +13,15 @@ import (
|
||||
// SubscriptionHandler handles subscription-related HTTP requests
|
||||
type SubscriptionHandler struct {
|
||||
subscriptionService *services.SubscriptionService
|
||||
stripeService *services.StripeService
|
||||
}
|
||||
|
||||
// NewSubscriptionHandler creates a new subscription handler
|
||||
func NewSubscriptionHandler(subscriptionService *services.SubscriptionService) *SubscriptionHandler {
|
||||
return &SubscriptionHandler{subscriptionService: subscriptionService}
|
||||
func NewSubscriptionHandler(subscriptionService *services.SubscriptionService, stripeService *services.StripeService) *SubscriptionHandler {
|
||||
return &SubscriptionHandler{
|
||||
subscriptionService: subscriptionService,
|
||||
stripeService: stripeService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSubscription handles GET /api/subscription/
|
||||
@@ -194,3 +198,82 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
|
||||
"subscription": subscription,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateCheckoutSession handles POST /api/subscription/checkout/
|
||||
// Creates a Stripe Checkout Session for web subscription purchases
|
||||
func (h *SubscriptionHandler) CreateCheckoutSession(c echo.Context) error {
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.stripeService == nil {
|
||||
return apperrors.BadRequest("error.stripe_not_configured")
|
||||
}
|
||||
|
||||
// Check if already Pro from another platform
|
||||
alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(user.ID, "stripe")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if alreadyPro {
|
||||
return c.JSON(http.StatusConflict, map[string]interface{}{
|
||||
"error": "error.already_subscribed_other_platform",
|
||||
"existing_platform": existingPlatform,
|
||||
"message": "You already have an active Pro subscription via " + existingPlatform + ". Manage it there to avoid double billing.",
|
||||
})
|
||||
}
|
||||
|
||||
var req struct {
|
||||
PriceID string `json:"price_id" validate:"required"`
|
||||
SuccessURL string `json:"success_url" validate:"required,url"`
|
||||
CancelURL string `json:"cancel_url" validate:"required,url"`
|
||||
}
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessionURL, err := h.stripeService.CreateCheckoutSession(user.ID, req.PriceID, req.SuccessURL, req.CancelURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"checkout_url": sessionURL,
|
||||
})
|
||||
}
|
||||
|
||||
// CreatePortalSession handles POST /api/subscription/portal/
|
||||
// Creates a Stripe Customer Portal session for managing web subscriptions
|
||||
func (h *SubscriptionHandler) CreatePortalSession(c echo.Context) error {
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.stripeService == nil {
|
||||
return apperrors.BadRequest("error.stripe_not_configured")
|
||||
}
|
||||
|
||||
var req struct {
|
||||
ReturnURL string `json:"return_url" validate:"required,url"`
|
||||
}
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
portalURL, err := h.stripeService.CreatePortalSession(user.ID, req.ReturnURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"portal_url": portalURL,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
// SubscriptionWebhookHandler handles subscription webhook callbacks
|
||||
@@ -28,6 +29,7 @@ type SubscriptionWebhookHandler struct {
|
||||
userRepo *repositories.UserRepository
|
||||
webhookEventRepo *repositories.WebhookEventRepository
|
||||
appleRootCerts []*x509.Certificate
|
||||
stripeService *services.StripeService
|
||||
enabled bool
|
||||
}
|
||||
|
||||
@@ -46,6 +48,11 @@ func NewSubscriptionWebhookHandler(
|
||||
}
|
||||
}
|
||||
|
||||
// SetStripeService sets the Stripe service for webhook handling
|
||||
func (h *SubscriptionWebhookHandler) SetStripeService(stripeService *services.StripeService) {
|
||||
h.stripeService = stripeService
|
||||
}
|
||||
|
||||
// ====================
|
||||
// Apple App Store Server Notifications v2
|
||||
// ====================
|
||||
@@ -377,38 +384,30 @@ func (h *SubscriptionWebhookHandler) handleAppleFailedToRenew(userID uint, tx *A
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleAppleExpired(userID uint, tx *AppleTransactionInfo) error {
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Apple expired"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription expired, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleAppleRefund(userID uint, tx *AppleTransactionInfo) error {
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Apple refund"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User got refund, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleAppleRevoke(userID uint, tx *AppleTransactionInfo) error {
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Apple revoke"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription revoked, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleAppleGracePeriodExpired(userID uint, tx *AppleTransactionInfo) error {
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Apple grace period expired"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User grace period expired, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -705,22 +704,16 @@ func (h *SubscriptionWebhookHandler) handleGoogleRestarted(userID uint, notifica
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleGoogleRevoked(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// Subscription revoked - immediate downgrade
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Google revoke"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription revoked")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleGoogleExpired(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// Subscription expired
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Google expired"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription expired")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -730,6 +723,88 @@ func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notificatio
|
||||
return nil
|
||||
}
|
||||
|
||||
// ====================
|
||||
// Multi-Source Downgrade Safety
|
||||
// ====================
|
||||
|
||||
// safeDowngradeToFree checks if the user has active subscriptions from other sources
|
||||
// before downgrading to free. If another source is still active, skip the downgrade.
|
||||
func (h *SubscriptionWebhookHandler) safeDowngradeToFree(userID uint, reason string) error {
|
||||
sub, err := h.subscriptionRepo.FindByUserID(userID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Could not find subscription for multi-source check, proceeding with downgrade")
|
||||
return h.subscriptionRepo.DowngradeToFree(userID)
|
||||
}
|
||||
|
||||
// Check if Stripe subscription is still active
|
||||
if sub.HasStripeSubscription() && sub.Platform != models.PlatformStripe {
|
||||
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Stripe subscription")
|
||||
return nil
|
||||
}
|
||||
// Check if Apple subscription is still active (for Google/Stripe webhooks)
|
||||
if sub.HasAppleSubscription() && sub.Platform != models.PlatformIOS {
|
||||
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Apple subscription")
|
||||
return nil
|
||||
}
|
||||
// Check if Google subscription is still active (for Apple/Stripe webhooks)
|
||||
if sub.HasGoogleSubscription() && sub.Platform != models.PlatformAndroid {
|
||||
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Google subscription")
|
||||
return nil
|
||||
}
|
||||
// Check if trial is still active
|
||||
if sub.IsTrialActive() {
|
||||
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active trial")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: User downgraded to free (no other active sources)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ====================
|
||||
// Stripe Webhooks
|
||||
// ====================
|
||||
|
||||
// HandleStripeWebhook handles POST /api/subscription/webhook/stripe/
|
||||
func (h *SubscriptionWebhookHandler) HandleStripeWebhook(c echo.Context) error {
|
||||
if !h.enabled {
|
||||
log.Info().Msg("Stripe Webhook: webhooks disabled by feature flag")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
|
||||
}
|
||||
|
||||
if h.stripeService == nil {
|
||||
log.Warn().Msg("Stripe Webhook: Stripe service not configured")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "not_configured"})
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Stripe Webhook: Failed to read body")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
||||
}
|
||||
|
||||
signature := c.Request().Header.Get("Stripe-Signature")
|
||||
if signature == "" {
|
||||
log.Warn().Msg("Stripe Webhook: Missing Stripe-Signature header")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "missing signature"})
|
||||
}
|
||||
|
||||
if err := h.stripeService.HandleWebhookEvent(body, signature); err != nil {
|
||||
log.Error().Err(err).Msg("Stripe Webhook: Failed to process webhook")
|
||||
// Still return 200 to prevent Stripe from retrying on business logic errors
|
||||
// Only return error for signature verification failures
|
||||
if strings.Contains(err.Error(), "signature") {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid signature"})
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "received"})
|
||||
}
|
||||
|
||||
// ====================
|
||||
// Signature Verification (Optional but Recommended)
|
||||
// ====================
|
||||
|
||||
@@ -98,6 +98,10 @@ var specEndpointsKMPSkips = map[routeKey]bool{
|
||||
{Method: "POST", Path: "/notifications/devices/"}: true, // KMP uses /notifications/devices/register/
|
||||
{Method: "POST", Path: "/notifications/devices/unregister/"}: true, // KMP uses DELETE on device ID
|
||||
{Method: "PATCH", Path: "/notifications/preferences/"}: true, // KMP uses PUT
|
||||
// Stripe web-only and server-to-server endpoints — not implemented in mobile KMP
|
||||
{Method: "POST", Path: "/subscription/checkout/"}: true, // Web-only (Stripe Checkout)
|
||||
{Method: "POST", Path: "/subscription/portal/"}: true, // Web-only (Stripe Customer Portal)
|
||||
{Method: "POST", Path: "/subscription/webhook/stripe/"}: true, // Server-to-server (Stripe webhook)
|
||||
}
|
||||
|
||||
// kmpRouteAliases maps KMP paths to their canonical spec paths.
|
||||
|
||||
@@ -76,7 +76,7 @@ func setupSecurityTest(t *testing.T) *SecurityTestApp {
|
||||
taskHandler := handlers.NewTaskHandler(taskService, nil)
|
||||
contractorHandler := handlers.NewContractorHandler(services.NewContractorService(contractorRepo, residenceRepo))
|
||||
notificationHandler := handlers.NewNotificationHandler(notificationService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
|
||||
|
||||
// Create router with real middleware
|
||||
e := echo.New()
|
||||
|
||||
@@ -64,7 +64,7 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
|
||||
// Create handlers
|
||||
authHandler := handlers.NewAuthHandler(authService, nil, nil)
|
||||
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
|
||||
|
||||
// Create router
|
||||
e := echo.New()
|
||||
|
||||
@@ -12,11 +12,20 @@ const (
|
||||
TierPro SubscriptionTier = "pro"
|
||||
)
|
||||
|
||||
// SubscriptionPlatform constants
|
||||
const (
|
||||
PlatformIOS = "ios"
|
||||
PlatformAndroid = "android"
|
||||
PlatformStripe = "stripe"
|
||||
)
|
||||
|
||||
// SubscriptionSettings represents the subscription_subscriptionsettings table (singleton)
|
||||
type SubscriptionSettings struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
EnableLimitations bool `gorm:"column:enable_limitations;default:false" json:"enable_limitations"`
|
||||
EnableMonitoring bool `gorm:"column:enable_monitoring;default:true" json:"enable_monitoring"`
|
||||
TrialEnabled bool `gorm:"column:trial_enabled;default:true" json:"trial_enabled"`
|
||||
TrialDurationDays int `gorm:"column:trial_duration_days;default:14" json:"trial_duration_days"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
@@ -31,18 +40,28 @@ type UserSubscription struct {
|
||||
User User `gorm:"foreignKey:UserID" json:"-"`
|
||||
Tier SubscriptionTier `gorm:"column:tier;size:10;default:'free'" json:"tier"`
|
||||
|
||||
// In-App Purchase data
|
||||
AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"`
|
||||
// In-App Purchase data (Apple / Google)
|
||||
AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"`
|
||||
GooglePurchaseToken *string `gorm:"column:google_purchase_token;type:text" json:"-"`
|
||||
|
||||
// Stripe data (web subscriptions)
|
||||
StripeCustomerID *string `gorm:"column:stripe_customer_id;size:255" json:"-"`
|
||||
StripeSubscriptionID *string `gorm:"column:stripe_subscription_id;size:255" json:"-"`
|
||||
StripePriceID *string `gorm:"column:stripe_price_id;size:255" json:"-"`
|
||||
|
||||
// Subscription dates
|
||||
SubscribedAt *time.Time `gorm:"column:subscribed_at" json:"subscribed_at"`
|
||||
ExpiresAt *time.Time `gorm:"column:expires_at" json:"expires_at"`
|
||||
AutoRenew bool `gorm:"column:auto_renew;default:true" json:"auto_renew"`
|
||||
|
||||
// Trial
|
||||
TrialStart *time.Time `gorm:"column:trial_start" json:"trial_start"`
|
||||
TrialEnd *time.Time `gorm:"column:trial_end" json:"trial_end"`
|
||||
TrialUsed bool `gorm:"column:trial_used;default:false" json:"trial_used"`
|
||||
|
||||
// Tracking
|
||||
CancelledAt *time.Time `gorm:"column:cancelled_at" json:"cancelled_at"`
|
||||
Platform string `gorm:"column:platform;size:10" json:"platform"` // ios, android
|
||||
Platform string `gorm:"column:platform;size:10" json:"platform"` // ios, android, stripe
|
||||
|
||||
// Admin override - bypasses all limitations regardless of global settings
|
||||
IsFree bool `gorm:"column:is_free;default:false" json:"is_free"`
|
||||
@@ -53,8 +72,11 @@ func (UserSubscription) TableName() string {
|
||||
return "subscription_usersubscription"
|
||||
}
|
||||
|
||||
// IsActive returns true if the subscription is active (pro tier and not expired)
|
||||
// IsActive returns true if the subscription is active (pro tier and not expired, or trial active)
|
||||
func (s *UserSubscription) IsActive() bool {
|
||||
if s.IsTrialActive() {
|
||||
return true
|
||||
}
|
||||
if s.Tier != TierPro {
|
||||
return false
|
||||
}
|
||||
@@ -64,9 +86,37 @@ func (s *UserSubscription) IsActive() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsPro returns true if the user has a pro subscription
|
||||
// IsPro returns true if the user has a pro subscription or active trial
|
||||
func (s *UserSubscription) IsPro() bool {
|
||||
return s.Tier == TierPro && s.IsActive()
|
||||
return s.IsActive()
|
||||
}
|
||||
|
||||
// IsTrialActive returns true if the user has an active, unexpired trial
|
||||
func (s *UserSubscription) IsTrialActive() bool {
|
||||
if s.TrialEnd == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().UTC().Before(*s.TrialEnd)
|
||||
}
|
||||
|
||||
// HasStripeSubscription returns true if the user has Stripe subscription data
|
||||
func (s *UserSubscription) HasStripeSubscription() bool {
|
||||
return s.StripeSubscriptionID != nil && *s.StripeSubscriptionID != ""
|
||||
}
|
||||
|
||||
// HasAppleSubscription returns true if the user has Apple receipt data
|
||||
func (s *UserSubscription) HasAppleSubscription() bool {
|
||||
return s.AppleReceiptData != nil && *s.AppleReceiptData != ""
|
||||
}
|
||||
|
||||
// HasGoogleSubscription returns true if the user has Google purchase token
|
||||
func (s *UserSubscription) HasGoogleSubscription() bool {
|
||||
return s.GooglePurchaseToken != nil && *s.GooglePurchaseToken != ""
|
||||
}
|
||||
|
||||
// SubscriptionSource returns the platform that the active subscription came from
|
||||
func (s *UserSubscription) SubscriptionSource() string {
|
||||
return s.Platform
|
||||
}
|
||||
|
||||
// UpgradeTrigger represents the subscription_upgradetrigger table
|
||||
|
||||
187
internal/models/subscription_test.go
Normal file
187
internal/models/subscription_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsTrialActive(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
future := now.Add(24 * time.Hour)
|
||||
past := now.Add(-24 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sub *UserSubscription
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "trial_end in future returns true",
|
||||
sub: &UserSubscription{TrialEnd: &future},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "trial_end in past returns false",
|
||||
sub: &UserSubscription{TrialEnd: &past},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "trial_end nil returns false",
|
||||
sub: &UserSubscription{TrialEnd: nil},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.sub.IsTrialActive()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPro(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
future := now.Add(24 * time.Hour)
|
||||
past := now.Add(-24 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sub *UserSubscription
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "tier=pro, expires_at in future returns true",
|
||||
sub: &UserSubscription{
|
||||
Tier: TierPro,
|
||||
ExpiresAt: &future,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "tier=pro, expires_at in past returns false",
|
||||
sub: &UserSubscription{
|
||||
Tier: TierPro,
|
||||
ExpiresAt: &past,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "tier=free, trial active returns true",
|
||||
sub: &UserSubscription{
|
||||
Tier: TierFree,
|
||||
TrialEnd: &future,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "tier=free, trial expired returns false",
|
||||
sub: &UserSubscription{
|
||||
Tier: TierFree,
|
||||
TrialEnd: &past,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "tier=free, no trial returns false",
|
||||
sub: &UserSubscription{
|
||||
Tier: TierFree,
|
||||
TrialEnd: nil,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.sub.IsPro()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSubscriptionHelpers(t *testing.T) {
|
||||
empty := ""
|
||||
validStripeID := "sub_1234567890"
|
||||
validReceipt := "MIIT..."
|
||||
validToken := "google-purchase-token-123"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sub *UserSubscription
|
||||
method string
|
||||
expected bool
|
||||
}{
|
||||
// HasStripeSubscription
|
||||
{
|
||||
name: "HasStripeSubscription with nil returns false",
|
||||
sub: &UserSubscription{StripeSubscriptionID: nil},
|
||||
method: "stripe",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasStripeSubscription with empty string returns false",
|
||||
sub: &UserSubscription{StripeSubscriptionID: &empty},
|
||||
method: "stripe",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasStripeSubscription with valid ID returns true",
|
||||
sub: &UserSubscription{StripeSubscriptionID: &validStripeID},
|
||||
method: "stripe",
|
||||
expected: true,
|
||||
},
|
||||
// HasAppleSubscription
|
||||
{
|
||||
name: "HasAppleSubscription with nil returns false",
|
||||
sub: &UserSubscription{AppleReceiptData: nil},
|
||||
method: "apple",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasAppleSubscription with empty string returns false",
|
||||
sub: &UserSubscription{AppleReceiptData: &empty},
|
||||
method: "apple",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasAppleSubscription with valid receipt returns true",
|
||||
sub: &UserSubscription{AppleReceiptData: &validReceipt},
|
||||
method: "apple",
|
||||
expected: true,
|
||||
},
|
||||
// HasGoogleSubscription
|
||||
{
|
||||
name: "HasGoogleSubscription with nil returns false",
|
||||
sub: &UserSubscription{GooglePurchaseToken: nil},
|
||||
method: "google",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasGoogleSubscription with empty string returns false",
|
||||
sub: &UserSubscription{GooglePurchaseToken: &empty},
|
||||
method: "google",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasGoogleSubscription with valid token returns true",
|
||||
sub: &UserSubscription{GooglePurchaseToken: &validToken},
|
||||
method: "google",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var result bool
|
||||
switch tt.method {
|
||||
case "stripe":
|
||||
result = tt.sub.HasStripeSubscription()
|
||||
case "apple":
|
||||
result = tt.sub.HasAppleSubscription()
|
||||
case "google":
|
||||
result = tt.sub.HasGoogleSubscription()
|
||||
}
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -262,3 +262,59 @@ func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.P
|
||||
}
|
||||
return &promotion, nil
|
||||
}
|
||||
|
||||
// === Stripe Lookups ===
|
||||
|
||||
// FindByStripeCustomerID finds a subscription by Stripe customer ID
|
||||
func (r *SubscriptionRepository) FindByStripeCustomerID(customerID string) (*models.UserSubscription, error) {
|
||||
var sub models.UserSubscription
|
||||
err := r.db.Where("stripe_customer_id = ?", customerID).First(&sub).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// FindByStripeSubscriptionID finds a subscription by Stripe subscription ID
|
||||
func (r *SubscriptionRepository) FindByStripeSubscriptionID(subscriptionID string) (*models.UserSubscription, error) {
|
||||
var sub models.UserSubscription
|
||||
err := r.db.Where("stripe_subscription_id = ?", subscriptionID).First(&sub).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// UpdateStripeData updates all three Stripe fields (customer, subscription, price) in one call
|
||||
func (r *SubscriptionRepository) UpdateStripeData(userID uint, customerID, subscriptionID, priceID string) error {
|
||||
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
||||
"stripe_customer_id": customerID,
|
||||
"stripe_subscription_id": subscriptionID,
|
||||
"stripe_price_id": priceID,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ClearStripeData clears the Stripe subscription and price IDs (customer ID stays for portal access)
|
||||
func (r *SubscriptionRepository) ClearStripeData(userID uint) error {
|
||||
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
||||
"stripe_subscription_id": nil,
|
||||
"stripe_price_id": nil,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// === Trial Management ===
|
||||
|
||||
// SetTrialDates sets the trial start, end, and marks trial as used
|
||||
func (r *SubscriptionRepository) SetTrialDates(userID uint, trialStart, trialEnd time.Time) error {
|
||||
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
||||
"trial_start": trialStart,
|
||||
"trial_end": trialEnd,
|
||||
"trial_used": true,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateExpiresAt updates the expires_at field for a user's subscription
|
||||
func (r *SubscriptionRepository) UpdateExpiresAt(userID uint, expiresAt time.Time) error {
|
||||
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).
|
||||
Update("expires_at", expiresAt).Error
|
||||
}
|
||||
|
||||
@@ -2,9 +2,11 @@ package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
@@ -77,3 +79,150 @@ func TestGetOrCreate_Idempotent(t *testing.T) {
|
||||
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should have exactly one subscription record after two calls")
|
||||
}
|
||||
|
||||
func TestFindByStripeCustomerID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
customerID string
|
||||
seedID string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "finds existing subscription by stripe customer ID",
|
||||
customerID: "cus_test123",
|
||||
seedID: "cus_test123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "returns error for unknown stripe customer ID",
|
||||
customerID: "cus_unknown999",
|
||||
seedID: "cus_test456",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierFree,
|
||||
StripeCustomerID: &tt.seedID,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByStripeCustomerID(tt.customerID)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
assert.Equal(t, tt.seedID, *found.StripeCustomerID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateStripeData(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierFree,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update all three Stripe fields
|
||||
customerID := "cus_abc123"
|
||||
subscriptionID := "sub_xyz789"
|
||||
priceID := "price_monthly"
|
||||
|
||||
err = repo.UpdateStripeData(user.ID, customerID, subscriptionID, priceID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all three fields are set
|
||||
var updated models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&updated).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated.StripeCustomerID)
|
||||
require.NotNil(t, updated.StripeSubscriptionID)
|
||||
require.NotNil(t, updated.StripePriceID)
|
||||
assert.Equal(t, customerID, *updated.StripeCustomerID)
|
||||
assert.Equal(t, subscriptionID, *updated.StripeSubscriptionID)
|
||||
assert.Equal(t, priceID, *updated.StripePriceID)
|
||||
|
||||
// Now call ClearStripeData
|
||||
err = repo.ClearStripeData(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify subscription_id and price_id are cleared, customer_id preserved
|
||||
var cleared models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&cleared).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cleared.StripeCustomerID, "customer_id should be preserved after ClearStripeData")
|
||||
assert.Equal(t, customerID, *cleared.StripeCustomerID)
|
||||
assert.Nil(t, cleared.StripeSubscriptionID, "subscription_id should be nil after ClearStripeData")
|
||||
assert.Nil(t, cleared.StripePriceID, "price_id should be nil after ClearStripeData")
|
||||
}
|
||||
|
||||
func TestSetTrialDates(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierFree,
|
||||
TrialUsed: false,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
trialStart := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
|
||||
trialEnd := time.Date(2026, 3, 15, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
err = repo.SetTrialDates(user.ID, trialStart, trialEnd)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&updated).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, updated.TrialStart)
|
||||
require.NotNil(t, updated.TrialEnd)
|
||||
assert.True(t, updated.TrialUsed, "trial_used should be set to true")
|
||||
assert.WithinDuration(t, trialStart, *updated.TrialStart, time.Second, "trial_start should match")
|
||||
assert.WithinDuration(t, trialEnd, *updated.TrialEnd, time.Second, "trial_end should match")
|
||||
}
|
||||
|
||||
func TestUpdateExpiresAt(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierPro,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
newExpiry := time.Date(2027, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
err = repo.UpdateExpiresAt(user.ID, newExpiry)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&updated).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, updated.ExpiresAt)
|
||||
assert.WithinDuration(t, newExpiry, *updated.ExpiresAt, time.Second, "expires_at should be updated")
|
||||
}
|
||||
|
||||
@@ -58,7 +58,13 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
e.Use(custommiddleware.RequestIDMiddleware())
|
||||
e.Use(utils.EchoRecovery())
|
||||
e.Use(custommiddleware.StructuredLogger())
|
||||
e.Use(middleware.BodyLimit("1M")) // 1MB default for JSON payloads
|
||||
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
|
||||
Limit: "1M", // 1MB default for JSON payloads
|
||||
Skipper: func(c echo.Context) bool {
|
||||
// Allow larger payloads for webhook endpoints (Apple/Google/Stripe notifications)
|
||||
return strings.HasPrefix(c.Request().URL.Path, "/api/subscription/webhook")
|
||||
},
|
||||
}))
|
||||
e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{
|
||||
Timeout: 30 * time.Second,
|
||||
Skipper: func(c echo.Context) bool {
|
||||
@@ -143,11 +149,15 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
residenceService.SetSubscriptionService(subscriptionService) // Wire up subscription service for tier limit enforcement
|
||||
taskTemplateService := services.NewTaskTemplateService(taskTemplateRepo)
|
||||
|
||||
// Initialize Stripe service
|
||||
stripeService := services.NewStripeService(subscriptionRepo, userRepo)
|
||||
|
||||
// Initialize webhook event repo for deduplication
|
||||
webhookEventRepo := repositories.NewWebhookEventRepository(deps.DB)
|
||||
|
||||
// Initialize webhook handler for Apple/Google subscription notifications
|
||||
// Initialize webhook handler for Apple/Google/Stripe subscription notifications
|
||||
subscriptionWebhookHandler := handlers.NewSubscriptionWebhookHandler(subscriptionRepo, userRepo, webhookEventRepo, cfg.Features.WebhooksEnabled)
|
||||
subscriptionWebhookHandler.SetStripeService(stripeService)
|
||||
|
||||
// Initialize middleware
|
||||
authMiddleware := custommiddleware.NewAuthMiddleware(deps.DB, deps.Cache)
|
||||
@@ -166,7 +176,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
contractorHandler := handlers.NewContractorHandler(contractorService)
|
||||
documentHandler := handlers.NewDocumentHandler(documentService, deps.StorageService)
|
||||
notificationHandler := handlers.NewNotificationHandler(notificationService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, stripeService)
|
||||
staticDataHandler := handlers.NewStaticDataHandler(residenceService, taskService, contractorService, taskTemplateService, deps.Cache)
|
||||
taskTemplateHandler := handlers.NewTaskTemplateHandler(taskTemplateService)
|
||||
|
||||
@@ -458,6 +468,8 @@ func setupSubscriptionRoutes(api *echo.Group, subscriptionHandler *handlers.Subs
|
||||
subscription.POST("/purchase/", subscriptionHandler.ProcessPurchase)
|
||||
subscription.POST("/cancel/", subscriptionHandler.CancelSubscription)
|
||||
subscription.POST("/restore/", subscriptionHandler.RestoreSubscription)
|
||||
subscription.POST("/checkout/", subscriptionHandler.CreateCheckoutSession)
|
||||
subscription.POST("/portal/", subscriptionHandler.CreatePortalSession)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -499,6 +511,7 @@ func setupWebhookRoutes(api *echo.Group, webhookHandler *handlers.SubscriptionWe
|
||||
{
|
||||
webhooks.POST("/apple/", webhookHandler.HandleAppleWebhook)
|
||||
webhooks.POST("/google/", webhookHandler.HandleGoogleWebhook)
|
||||
webhooks.POST("/stripe/", webhookHandler.HandleStripeWebhook)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
456
internal/services/stripe_service.go
Normal file
456
internal/services/stripe_service.go
Normal file
@@ -0,0 +1,456 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
portalsession "github.com/stripe/stripe-go/v81/billingportal/session"
|
||||
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
|
||||
"github.com/stripe/stripe-go/v81/customer"
|
||||
"github.com/stripe/stripe-go/v81/webhook"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
)
|
||||
|
||||
// StripeService handles Stripe checkout, portal, and webhook processing
|
||||
// for web-based subscription purchases.
|
||||
type StripeService struct {
|
||||
subscriptionRepo *repositories.SubscriptionRepository
|
||||
userRepo *repositories.UserRepository
|
||||
webhookSecret string
|
||||
}
|
||||
|
||||
// NewStripeService creates a new Stripe service. It initializes the global
|
||||
// Stripe API key from the STRIPE_SECRET_KEY environment variable. If the key
|
||||
// is not set, a warning is logged but the service is still returned (matching
|
||||
// the pattern used by the Apple/Google IAP clients).
|
||||
func NewStripeService(
|
||||
subscriptionRepo *repositories.SubscriptionRepository,
|
||||
userRepo *repositories.UserRepository,
|
||||
) *StripeService {
|
||||
key := os.Getenv("STRIPE_SECRET_KEY")
|
||||
if key == "" {
|
||||
log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work")
|
||||
} else {
|
||||
stripe.Key = key
|
||||
log.Info().Msg("Stripe API key configured")
|
||||
}
|
||||
|
||||
webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")
|
||||
if webhookSecret == "" {
|
||||
log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail")
|
||||
}
|
||||
|
||||
return &StripeService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
userRepo: userRepo,
|
||||
webhookSecret: webhookSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCheckoutSession creates a Stripe Checkout Session for a web subscription purchase.
|
||||
// It ensures the user has a Stripe customer record and configures the session with a trial
|
||||
// period if the user has not used their trial yet.
|
||||
func (s *StripeService) CreateCheckoutSession(userID uint, priceID string, successURL string, cancelURL string) (string, error) {
|
||||
// Get or create the user's subscription record
|
||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
||||
if err != nil {
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Get the user's email for the Stripe customer
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
if err != nil {
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Get or create a Stripe customer
|
||||
stripeCustomerID, err := s.getOrCreateStripeCustomer(sub, user)
|
||||
if err != nil {
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Build the checkout session parameters
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
Customer: stripe.String(stripeCustomerID),
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||
SuccessURL: stripe.String(successURL),
|
||||
CancelURL: stripe.String(cancelURL),
|
||||
ClientReferenceID: stripe.String(fmt.Sprintf("%d", userID)),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(priceID),
|
||||
Quantity: stripe.Int64(1),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Offer a trial period if the user has not used their trial yet
|
||||
if !sub.TrialUsed {
|
||||
trialDays, err := s.getTrialDays()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to get trial duration from settings, skipping trial")
|
||||
} else if trialDays > 0 {
|
||||
params.SubscriptionData = &stripe.CheckoutSessionSubscriptionDataParams{
|
||||
TrialPeriodDays: stripe.Int64(int64(trialDays)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session, err := checkoutsession.New(params)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to create Stripe checkout session")
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Uint("user_id", userID).
|
||||
Str("session_id", session.ID).
|
||||
Str("price_id", priceID).
|
||||
Msg("Stripe checkout session created")
|
||||
|
||||
return session.URL, nil
|
||||
}
|
||||
|
||||
// CreatePortalSession creates a Stripe Customer Portal session so the user
|
||||
// can manage their subscription (cancel, change plan, update payment method).
|
||||
func (s *StripeService) CreatePortalSession(userID uint, returnURL string) (string, error) {
|
||||
sub, err := s.subscriptionRepo.FindByUserID(userID)
|
||||
if err != nil {
|
||||
return "", apperrors.NotFound("error.subscription_not_found")
|
||||
}
|
||||
|
||||
if sub.StripeCustomerID == nil || *sub.StripeCustomerID == "" {
|
||||
return "", apperrors.BadRequest("error.no_stripe_customer")
|
||||
}
|
||||
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: sub.StripeCustomerID,
|
||||
ReturnURL: stripe.String(returnURL),
|
||||
}
|
||||
|
||||
session, err := portalsession.New(params)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to create Stripe portal session")
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return session.URL, nil
|
||||
}
|
||||
|
||||
// HandleWebhookEvent verifies and processes a Stripe webhook event.
|
||||
// It handles checkout completion, subscription lifecycle changes, and invoice events.
|
||||
func (s *StripeService) HandleWebhookEvent(payload []byte, signature string) error {
|
||||
event, err := webhook.ConstructEvent(payload, signature, s.webhookSecret)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Stripe webhook signature verification failed")
|
||||
return apperrors.BadRequest("error.invalid_webhook_signature")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("event_type", string(event.Type)).
|
||||
Str("event_id", event.ID).
|
||||
Msg("Processing Stripe webhook event")
|
||||
|
||||
switch event.Type {
|
||||
case "checkout.session.completed":
|
||||
return s.handleCheckoutCompleted(event)
|
||||
case "customer.subscription.updated":
|
||||
return s.handleSubscriptionUpdated(event)
|
||||
case "customer.subscription.deleted":
|
||||
return s.handleSubscriptionDeleted(event)
|
||||
case "invoice.paid":
|
||||
return s.handleInvoicePaid(event)
|
||||
case "invoice.payment_failed":
|
||||
return s.handleInvoicePaymentFailed(event)
|
||||
default:
|
||||
log.Debug().Str("event_type", string(event.Type)).Msg("Unhandled Stripe webhook event type")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleCheckoutCompleted processes a successful checkout session. It links the Stripe
|
||||
// customer and subscription to the user's record and upgrades them to Pro.
|
||||
func (s *StripeService) handleCheckoutCompleted(event stripe.Event) error {
|
||||
var session stripe.CheckoutSession
|
||||
if err := json.Unmarshal(event.Data.Raw, &session); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal checkout session from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Extract the user ID from client_reference_id
|
||||
var userID uint
|
||||
if _, err := fmt.Sscanf(session.ClientReferenceID, "%d", &userID); err != nil {
|
||||
log.Error().Str("client_reference_id", session.ClientReferenceID).Msg("Invalid client_reference_id in checkout session")
|
||||
return apperrors.BadRequest("error.invalid_client_reference_id")
|
||||
}
|
||||
|
||||
// Get or create the subscription record
|
||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
||||
if err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Save Stripe customer and subscription IDs
|
||||
if session.Customer != nil {
|
||||
sub.StripeCustomerID = &session.Customer.ID
|
||||
}
|
||||
if session.Subscription != nil {
|
||||
sub.StripeSubscriptionID = &session.Subscription.ID
|
||||
}
|
||||
|
||||
if err := s.subscriptionRepo.Update(sub); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Upgrade to Pro. Use a far-future expiry for now; the invoice.paid event
|
||||
// will set the real period_end once the first invoice is finalized.
|
||||
expiresAt := time.Now().UTC().AddDate(1, 0, 0)
|
||||
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, models.PlatformStripe); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
customerID := ""
|
||||
if session.Customer != nil {
|
||||
customerID = session.Customer.ID
|
||||
}
|
||||
subscriptionID := ""
|
||||
if session.Subscription != nil {
|
||||
subscriptionID = session.Subscription.ID
|
||||
}
|
||||
log.Info().
|
||||
Uint("user_id", userID).
|
||||
Str("stripe_customer_id", customerID).
|
||||
Str("stripe_subscription_id", subscriptionID).
|
||||
Msg("Checkout completed, user upgraded to Pro")
|
||||
|
||||
// TODO: Send push notification to user's devices when subscription activates
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionUpdated processes subscription status changes. It upgrades or
|
||||
// downgrades the user depending on the subscription's current status.
|
||||
func (s *StripeService) handleSubscriptionUpdated(event stripe.Event) error {
|
||||
var subscription stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal subscription from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(subscription.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch subscription.Status {
|
||||
case stripe.SubscriptionStatusActive, stripe.SubscriptionStatusTrialing:
|
||||
// Subscription is healthy, ensure user is Pro
|
||||
expiresAt := time.Unix(subscription.CurrentPeriodEnd, 0).UTC()
|
||||
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("Stripe subscription active")
|
||||
|
||||
case stripe.SubscriptionStatusPastDue:
|
||||
log.Warn().Uint("user_id", sub.UserID).Msg("Stripe subscription past due, waiting for retry")
|
||||
// Don't downgrade yet; Stripe will retry the payment automatically.
|
||||
|
||||
case stripe.SubscriptionStatusCanceled, stripe.SubscriptionStatusUnpaid:
|
||||
// Check if the user has active subscriptions from other sources before downgrading
|
||||
if s.isActiveFromOtherSources(sub) {
|
||||
log.Info().
|
||||
Uint("user_id", sub.UserID).
|
||||
Str("status", string(subscription.Status)).
|
||||
Msg("Stripe subscription ended but user has other active sources, keeping Pro")
|
||||
return nil
|
||||
}
|
||||
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("User downgraded to Free after Stripe subscription ended")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionDeleted processes a subscription that has been fully cancelled
|
||||
// and is no longer active. It downgrades the user unless they have active subscriptions
|
||||
// from other sources (Apple, Google).
|
||||
func (s *StripeService) handleSubscriptionDeleted(event stripe.Event) error {
|
||||
var subscription stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal subscription from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(subscription.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check multi-source before downgrading
|
||||
if s.isActiveFromOtherSources(sub) {
|
||||
log.Info().
|
||||
Uint("user_id", sub.UserID).
|
||||
Msg("Stripe subscription deleted but user has other active sources, keeping Pro")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", sub.UserID).Msg("User downgraded to Free after Stripe subscription deleted")
|
||||
|
||||
// TODO: Send push notification to user's devices about subscription ending
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoicePaid processes a successful invoice payment. It updates the subscription
|
||||
// expiry to the current billing period's end date and ensures the user is on Pro.
|
||||
func (s *StripeService) handleInvoicePaid(event stripe.Event) error {
|
||||
var invoice stripe.Invoice
|
||||
if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal invoice from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Only process subscription invoices
|
||||
if invoice.Subscription == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(invoice.Subscription.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update expiry from the invoice's period end
|
||||
expiresAt := time.Unix(invoice.PeriodEnd, 0).UTC()
|
||||
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Uint("user_id", sub.UserID).
|
||||
Time("expires_at", expiresAt).
|
||||
Msg("Invoice paid, subscription renewed")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoicePaymentFailed logs a warning when a payment fails. We do not downgrade
|
||||
// the user here because Stripe will automatically retry the payment according to its
|
||||
// Smart Retries schedule.
|
||||
func (s *StripeService) handleInvoicePaymentFailed(event stripe.Event) error {
|
||||
var invoice stripe.Invoice
|
||||
if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal invoice from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
if invoice.Subscription == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(invoice.Subscription.ID)
|
||||
if err != nil {
|
||||
// If we can't find the subscription, just log and return
|
||||
log.Warn().Str("stripe_subscription_id", invoice.Subscription.ID).Msg("Invoice payment failed for unknown subscription")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Warn().
|
||||
Uint("user_id", sub.UserID).
|
||||
Str("invoice_id", invoice.ID).
|
||||
Msg("Stripe invoice payment failed, Stripe will retry automatically")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isActiveFromOtherSources checks if the user has active subscriptions from Apple or Google
|
||||
// that should prevent a downgrade when the Stripe subscription ends.
|
||||
func (s *StripeService) isActiveFromOtherSources(sub *models.UserSubscription) bool {
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Check Apple subscription
|
||||
if sub.HasAppleSubscription() && sub.Tier == models.TierPro && sub.ExpiresAt != nil && now.Before(*sub.ExpiresAt) && sub.Platform != models.PlatformStripe {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check Google subscription
|
||||
if sub.HasGoogleSubscription() && sub.Tier == models.TierPro && sub.ExpiresAt != nil && now.Before(*sub.ExpiresAt) && sub.Platform != models.PlatformStripe {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check active trial
|
||||
if sub.IsTrialActive() {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getOrCreateStripeCustomer returns the existing Stripe customer ID from the subscription
|
||||
// record, or creates a new Stripe customer and persists the ID.
|
||||
func (s *StripeService) getOrCreateStripeCustomer(sub *models.UserSubscription, user *models.User) (string, error) {
|
||||
// If we already have a Stripe customer, return it
|
||||
if sub.StripeCustomerID != nil && *sub.StripeCustomerID != "" {
|
||||
return *sub.StripeCustomerID, nil
|
||||
}
|
||||
|
||||
// Create a new Stripe customer
|
||||
params := &stripe.CustomerParams{
|
||||
Email: stripe.String(user.Email),
|
||||
Name: stripe.String(user.GetFullName()),
|
||||
}
|
||||
params.AddMetadata("casera_user_id", fmt.Sprintf("%d", user.ID))
|
||||
|
||||
c, err := customer.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create Stripe customer: %w", err)
|
||||
}
|
||||
|
||||
// Save the customer ID to the subscription record
|
||||
sub.StripeCustomerID = &c.ID
|
||||
if err := s.subscriptionRepo.Update(sub); err != nil {
|
||||
return "", fmt.Errorf("failed to save Stripe customer ID: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Uint("user_id", user.ID).
|
||||
Str("stripe_customer_id", c.ID).
|
||||
Msg("Created new Stripe customer")
|
||||
|
||||
return c.ID, nil
|
||||
}
|
||||
|
||||
// findSubscriptionByStripeID looks up a UserSubscription by its Stripe subscription ID.
|
||||
func (s *StripeService) findSubscriptionByStripeID(stripeSubID string) (*models.UserSubscription, error) {
|
||||
sub, err := s.subscriptionRepo.FindByStripeSubscriptionID(stripeSubID)
|
||||
if err != nil {
|
||||
log.Warn().Str("stripe_subscription_id", stripeSubID).Err(err).Msg("Subscription not found for Stripe ID")
|
||||
return nil, apperrors.NotFound("error.subscription_not_found")
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// getTrialDays reads the trial duration from SubscriptionSettings.
|
||||
func (s *StripeService) getTrialDays() (int, error) {
|
||||
settings, err := s.subscriptionRepo.GetSettings()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !settings.TrialEnabled {
|
||||
return 0, nil
|
||||
}
|
||||
return settings.TrialDurationDays, nil
|
||||
}
|
||||
@@ -118,6 +118,20 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Auto-start trial for new users who have never had a trial
|
||||
if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled {
|
||||
now := time.Now().UTC()
|
||||
trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour)
|
||||
if err := s.subscriptionRepo.SetTrialDates(userID, now, trialEnd); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
// Re-fetch after starting trial so response reflects the new state
|
||||
sub, err = s.subscriptionRepo.FindByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all tier limits and build a map
|
||||
allLimits, err := s.subscriptionRepo.GetAllTierLimits()
|
||||
if err != nil {
|
||||
@@ -154,6 +168,8 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
|
||||
// Build flattened response (KMM expects subscription fields at top level)
|
||||
resp := &SubscriptionStatusResponse{
|
||||
Tier: string(sub.Tier),
|
||||
IsActive: sub.IsActive(),
|
||||
AutoRenew: sub.AutoRenew,
|
||||
Limits: limitsMap,
|
||||
Usage: usage,
|
||||
@@ -170,6 +186,18 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
resp.ExpiresAt = &t
|
||||
}
|
||||
|
||||
// Populate trial fields
|
||||
if sub.TrialStart != nil {
|
||||
t := sub.TrialStart.Format("2006-01-02T15:04:05Z")
|
||||
resp.TrialStart = &t
|
||||
}
|
||||
if sub.TrialEnd != nil {
|
||||
t := sub.TrialEnd.Format("2006-01-02T15:04:05Z")
|
||||
resp.TrialEnd = &t
|
||||
}
|
||||
resp.TrialActive = sub.IsTrialActive()
|
||||
resp.SubscriptionSource = sub.SubscriptionSource()
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -449,28 +477,48 @@ func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResp
|
||||
return s.GetSubscription(userID)
|
||||
}
|
||||
|
||||
// IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription
|
||||
// from a different platform than the one being requested. Returns (conflict, existingPlatform, error).
|
||||
func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(userID uint, requestedPlatform string) (bool, string, error) {
|
||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
||||
if err != nil {
|
||||
return false, "", apperrors.Internal(err)
|
||||
}
|
||||
if !sub.IsPro() {
|
||||
return false, "", nil
|
||||
}
|
||||
if sub.Platform == requestedPlatform {
|
||||
return false, "", nil
|
||||
}
|
||||
return true, sub.Platform, nil
|
||||
}
|
||||
|
||||
// === Response Types ===
|
||||
|
||||
// SubscriptionResponse represents a subscription in API response
|
||||
type SubscriptionResponse struct {
|
||||
Tier string `json:"tier"`
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
CancelledAt *string `json:"cancelled_at"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"is_active"`
|
||||
IsPro bool `json:"is_pro"`
|
||||
Tier string `json:"tier"`
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
CancelledAt *string `json:"cancelled_at"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"is_active"`
|
||||
IsPro bool `json:"is_pro"`
|
||||
TrialActive bool `json:"trial_active"`
|
||||
SubscriptionSource string `json:"subscription_source"`
|
||||
}
|
||||
|
||||
// NewSubscriptionResponse creates a SubscriptionResponse from a model
|
||||
func NewSubscriptionResponse(s *models.UserSubscription) *SubscriptionResponse {
|
||||
resp := &SubscriptionResponse{
|
||||
Tier: string(s.Tier),
|
||||
AutoRenew: s.AutoRenew,
|
||||
Platform: s.Platform,
|
||||
IsActive: s.IsActive(),
|
||||
IsPro: s.IsPro(),
|
||||
Tier: string(s.Tier),
|
||||
AutoRenew: s.AutoRenew,
|
||||
Platform: s.Platform,
|
||||
IsActive: s.IsActive(),
|
||||
IsPro: s.IsPro(),
|
||||
TrialActive: s.IsTrialActive(),
|
||||
SubscriptionSource: s.SubscriptionSource(),
|
||||
}
|
||||
if s.SubscribedAt != nil {
|
||||
t := s.SubscribedAt.Format("2006-01-02T15:04:05Z")
|
||||
@@ -536,11 +584,23 @@ func NewTierLimitsClientResponse(l *models.TierLimits) *TierLimitsClientResponse
|
||||
// SubscriptionStatusResponse represents full subscription status
|
||||
// Fields are flattened to match KMM client expectations
|
||||
type SubscriptionStatusResponse struct {
|
||||
// Tier and active status
|
||||
Tier string `json:"tier"`
|
||||
IsActive bool `json:"is_active"`
|
||||
|
||||
// Flattened subscription fields (KMM expects these at top level)
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
|
||||
// Trial fields
|
||||
TrialStart *string `json:"trial_start,omitempty"`
|
||||
TrialEnd *string `json:"trial_end,omitempty"`
|
||||
TrialActive bool `json:"trial_active"`
|
||||
|
||||
// Subscription source
|
||||
SubscriptionSource string `json:"subscription_source"`
|
||||
|
||||
// Other fields
|
||||
Usage *UsageResponse `json:"usage"`
|
||||
Limits map[string]*TierLimitsClientResponse `json:"limits"`
|
||||
@@ -638,5 +698,5 @@ type ProcessPurchaseRequest struct {
|
||||
TransactionID string `json:"transaction_id"` // iOS StoreKit 2 transaction ID
|
||||
PurchaseToken string `json:"purchase_token"` // Android
|
||||
ProductID string `json:"product_id"` // Android (optional, helps identify subscription)
|
||||
Platform string `json:"platform" validate:"required,oneof=ios android"`
|
||||
Platform string `json:"platform" validate:"required,oneof=ios android stripe"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -179,3 +180,94 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
|
||||
}
|
||||
|
||||
func TestIsAlreadyProFromOtherPlatform(t *testing.T) {
|
||||
future := time.Now().UTC().Add(30 * 24 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tier models.SubscriptionTier
|
||||
platform string
|
||||
expiresAt *time.Time
|
||||
trialEnd *time.Time
|
||||
requestedPlatform string
|
||||
wantConflict bool
|
||||
wantPlatform string
|
||||
}{
|
||||
{
|
||||
name: "free user returns no conflict",
|
||||
tier: models.TierFree,
|
||||
platform: "",
|
||||
expiresAt: nil,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "stripe",
|
||||
wantConflict: false,
|
||||
wantPlatform: "",
|
||||
},
|
||||
{
|
||||
name: "pro from ios, requesting ios returns no conflict (same platform)",
|
||||
tier: models.TierPro,
|
||||
platform: "ios",
|
||||
expiresAt: &future,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "ios",
|
||||
wantConflict: false,
|
||||
wantPlatform: "",
|
||||
},
|
||||
{
|
||||
name: "pro from ios, requesting stripe returns conflict",
|
||||
tier: models.TierPro,
|
||||
platform: "ios",
|
||||
expiresAt: &future,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "stripe",
|
||||
wantConflict: true,
|
||||
wantPlatform: "ios",
|
||||
},
|
||||
{
|
||||
name: "pro from stripe, requesting android returns conflict",
|
||||
tier: models.TierPro,
|
||||
platform: "stripe",
|
||||
expiresAt: &future,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "android",
|
||||
wantConflict: true,
|
||||
wantPlatform: "stripe",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
}
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: tt.tier,
|
||||
Platform: tt.platform,
|
||||
ExpiresAt: tt.expiresAt,
|
||||
TrialEnd: tt.trialEnd,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(user.ID, tt.requestedPlatform)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantConflict, conflict)
|
||||
assert.Equal(t, tt.wantPlatform, existingPlatform)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user