Add Stripe billing, free trials, and cross-platform subscription guards

- Stripe integration: add StripeService with checkout sessions, customer
  portal, and webhook handling for subscription lifecycle events.
- Free trials: auto-start configurable trial on first subscription check,
  with admin-controllable duration and enable/disable toggle.
- Cross-platform guard: prevent duplicate subscriptions across iOS, Android,
  and Stripe by checking existing platform before allowing purchase.
- Subscription model: add Stripe fields (customer_id, subscription_id,
  price_id), trial fields (trial_start, trial_end, trial_used), and
  SubscriptionSource/IsTrialActive helpers.
- API: add trial and source fields to status response, update OpenAPI spec.
- Clean up stale migration and audit docs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-03-05 11:36:14 -06:00
parent d5bb123cd0
commit 72db9050f8
35 changed files with 1555 additions and 1120 deletions

View File

@@ -234,11 +234,13 @@ type UpdateNotificationRequest struct {
// SubscriptionFilters holds subscription-specific filter parameters
type SubscriptionFilters struct {
PaginationParams
UserID *uint `form:"user_id"`
Tier *string `form:"tier"`
Platform *string `form:"platform"`
AutoRenew *bool `form:"auto_renew"`
Active *bool `form:"active"`
UserID *uint `form:"user_id"`
Tier *string `form:"tier"`
Platform *string `form:"platform"`
AutoRenew *bool `form:"auto_renew"`
Active *bool `form:"active"`
HasStripe *bool `form:"has_stripe"`
TrialActive *bool `form:"trial_active"`
}
// UpdateSubscriptionRequest for updating a subscription
@@ -250,6 +252,14 @@ type UpdateSubscriptionRequest struct {
SubscribedAt *string `json:"subscribed_at"`
ExpiresAt *string `json:"expires_at"`
CancelledAt *string `json:"cancelled_at"`
// Stripe fields
StripeCustomerID *string `json:"stripe_customer_id"`
StripeSubscriptionID *string `json:"stripe_subscription_id"`
StripePriceID *string `json:"stripe_price_id"`
// Trial fields
TrialStart *string `json:"trial_start"`
TrialEnd *string `json:"trial_end"`
TrialUsed *bool `json:"trial_used"`
}
// CreateResidenceRequest for creating a new residence

View File

@@ -264,7 +264,16 @@ type SubscriptionResponse struct {
SubscribedAt *string `json:"subscribed_at,omitempty"`
ExpiresAt *string `json:"expires_at,omitempty"`
CancelledAt *string `json:"cancelled_at,omitempty"`
CreatedAt string `json:"created_at"`
// Stripe fields
StripeCustomerID *string `json:"stripe_customer_id,omitempty"`
StripeSubscriptionID *string `json:"stripe_subscription_id,omitempty"`
StripePriceID *string `json:"stripe_price_id,omitempty"`
// Trial fields
TrialStart *string `json:"trial_start,omitempty"`
TrialEnd *string `json:"trial_end,omitempty"`
TrialUsed bool `json:"trial_used"`
TrialActive bool `json:"trial_active"`
CreatedAt string `json:"created_at"`
}
// SubscriptionDetailResponse includes more details for single subscription view

View File

@@ -30,6 +30,8 @@ func NewAdminSettingsHandler(db *gorm.DB) *AdminSettingsHandler {
type SettingsResponse struct {
EnableLimitations bool `json:"enable_limitations"`
EnableMonitoring bool `json:"enable_monitoring"`
TrialEnabled bool `json:"trial_enabled"`
TrialDurationDays int `json:"trial_duration_days"`
}
// GetSettings handles GET /api/admin/settings
@@ -38,7 +40,13 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
if err := h.db.First(&settings, 1).Error; err != nil {
if err == gorm.ErrRecordNotFound {
// Create default settings
settings = models.SubscriptionSettings{ID: 1, EnableLimitations: false, EnableMonitoring: true}
settings = models.SubscriptionSettings{
ID: 1,
EnableLimitations: false,
EnableMonitoring: true,
TrialEnabled: true,
TrialDurationDays: 14,
}
h.db.Create(&settings)
} else {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
@@ -48,6 +56,8 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
return c.JSON(http.StatusOK, SettingsResponse{
EnableLimitations: settings.EnableLimitations,
EnableMonitoring: settings.EnableMonitoring,
TrialEnabled: settings.TrialEnabled,
TrialDurationDays: settings.TrialDurationDays,
})
}
@@ -55,6 +65,8 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
type UpdateSettingsRequest struct {
EnableLimitations *bool `json:"enable_limitations"`
EnableMonitoring *bool `json:"enable_monitoring"`
TrialEnabled *bool `json:"trial_enabled"`
TrialDurationDays *int `json:"trial_duration_days"`
}
// UpdateSettings handles PUT /api/admin/settings
@@ -67,7 +79,12 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
var settings models.SubscriptionSettings
if err := h.db.First(&settings, 1).Error; err != nil {
if err == gorm.ErrRecordNotFound {
settings = models.SubscriptionSettings{ID: 1, EnableMonitoring: true}
settings = models.SubscriptionSettings{
ID: 1,
EnableMonitoring: true,
TrialEnabled: true,
TrialDurationDays: 14,
}
} else {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
}
@@ -81,6 +98,14 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
settings.EnableMonitoring = *req.EnableMonitoring
}
if req.TrialEnabled != nil {
settings.TrialEnabled = *req.TrialEnabled
}
if req.TrialDurationDays != nil {
settings.TrialDurationDays = *req.TrialDurationDays
}
if err := h.db.Save(&settings).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update settings"})
}
@@ -88,6 +113,8 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
return c.JSON(http.StatusOK, SettingsResponse{
EnableLimitations: settings.EnableLimitations,
EnableMonitoring: settings.EnableMonitoring,
TrialEnabled: settings.TrialEnabled,
TrialDurationDays: settings.TrialDurationDays,
})
}

View File

@@ -3,6 +3,7 @@ package handlers
import (
"net/http"
"strconv"
"time"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
@@ -61,6 +62,20 @@ func (h *AdminSubscriptionHandler) List(c echo.Context) error {
query = query.Where("expires_at IS NOT NULL AND expires_at <= NOW()")
}
}
if filters.HasStripe != nil {
if *filters.HasStripe {
query = query.Where("stripe_subscription_id IS NOT NULL")
} else {
query = query.Where("stripe_subscription_id IS NULL")
}
}
if filters.TrialActive != nil {
if *filters.TrialActive {
query = query.Where("trial_end IS NOT NULL AND trial_end > NOW()")
} else {
query = query.Where("trial_end IS NULL OR trial_end <= NOW()")
}
}
// Get total count
query.Count(&total)
@@ -137,6 +152,32 @@ func (h *AdminSubscriptionHandler) Update(c echo.Context) error {
if req.IsFree != nil {
subscription.IsFree = *req.IsFree
}
if req.StripeCustomerID != nil {
subscription.StripeCustomerID = req.StripeCustomerID
}
if req.StripeSubscriptionID != nil {
subscription.StripeSubscriptionID = req.StripeSubscriptionID
}
if req.StripePriceID != nil {
subscription.StripePriceID = req.StripePriceID
}
if req.TrialStart != nil {
parsed, err := time.Parse(time.RFC3339, *req.TrialStart)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid trial_start format, expected RFC3339"})
}
subscription.TrialStart = &parsed
}
if req.TrialEnd != nil {
parsed, err := time.Parse(time.RFC3339, *req.TrialEnd)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid trial_end format, expected RFC3339"})
}
subscription.TrialEnd = &parsed
}
if req.TrialUsed != nil {
subscription.TrialUsed = *req.TrialUsed
}
if err := h.db.Save(&subscription).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update subscription"})
@@ -184,29 +225,39 @@ func (h *AdminSubscriptionHandler) GetByUser(c echo.Context) error {
// GetStats handles GET /api/admin/subscriptions/stats
func (h *AdminSubscriptionHandler) GetStats(c echo.Context) error {
var total, free, premium, pro int64
var stripeSubscribers, activeTrials int64
h.db.Model(&models.UserSubscription{}).Count(&total)
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "free").Count(&free)
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "premium").Count(&premium)
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "pro").Count(&pro)
h.db.Model(&models.UserSubscription{}).Where("stripe_subscription_id IS NOT NULL AND tier = ?", "pro").Count(&stripeSubscribers)
h.db.Model(&models.UserSubscription{}).Where("trial_end IS NOT NULL AND trial_end > NOW()").Count(&activeTrials)
return c.JSON(http.StatusOK, map[string]interface{}{
"total": total,
"free": free,
"premium": premium,
"pro": pro,
"total": total,
"free": free,
"premium": premium,
"pro": pro,
"stripe_subscribers": stripeSubscribers,
"active_trials": activeTrials,
})
}
func (h *AdminSubscriptionHandler) toSubscriptionResponse(sub *models.UserSubscription) dto.SubscriptionResponse {
response := dto.SubscriptionResponse{
ID: sub.ID,
UserID: sub.UserID,
Tier: string(sub.Tier),
Platform: sub.Platform,
AutoRenew: sub.AutoRenew,
IsFree: sub.IsFree,
CreatedAt: sub.CreatedAt.Format("2006-01-02T15:04:05Z"),
ID: sub.ID,
UserID: sub.UserID,
Tier: string(sub.Tier),
Platform: sub.Platform,
AutoRenew: sub.AutoRenew,
IsFree: sub.IsFree,
StripeCustomerID: sub.StripeCustomerID,
StripeSubscriptionID: sub.StripeSubscriptionID,
StripePriceID: sub.StripePriceID,
TrialUsed: sub.TrialUsed,
TrialActive: sub.IsTrialActive(),
CreatedAt: sub.CreatedAt.Format("2006-01-02T15:04:05Z"),
}
if sub.User.ID != 0 {
@@ -225,6 +276,14 @@ func (h *AdminSubscriptionHandler) toSubscriptionResponse(sub *models.UserSubscr
cancelledAt := sub.CancelledAt.Format("2006-01-02T15:04:05Z")
response.CancelledAt = &cancelledAt
}
if sub.TrialStart != nil {
trialStart := sub.TrialStart.Format(time.RFC3339)
response.TrialStart = &trialStart
}
if sub.TrialEnd != nil {
trialEnd := sub.TrialEnd.Format(time.RFC3339)
response.TrialEnd = &trialEnd
}
return response
}

View File

@@ -25,6 +25,7 @@ type Config struct {
GoogleAuth GoogleAuthConfig
AppleIAP AppleIAPConfig
GoogleIAP GoogleIAPConfig
Stripe StripeConfig
Features FeatureFlags
}
@@ -104,6 +105,14 @@ type GoogleIAPConfig struct {
PackageName string // Android package name (e.g., com.tt.casera)
}
// StripeConfig holds Stripe payment configuration
type StripeConfig struct {
SecretKey string // Stripe secret API key
WebhookSecret string // Stripe webhook endpoint signing secret
PriceMonthly string // Stripe Price ID for monthly Pro subscription
PriceYearly string // Stripe Price ID for yearly Pro subscription
}
type WorkerConfig struct {
// Scheduled job times (UTC)
TaskReminderHour int
@@ -248,6 +257,12 @@ func Load() (*Config, error) {
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
},
Stripe: StripeConfig{
SecretKey: viper.GetString("STRIPE_SECRET_KEY"),
WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"),
PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"),
PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"),
},
Features: FeatureFlags{
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),

View File

@@ -240,7 +240,7 @@ func TestSubscriptionHandler_NoAuth_Returns401(t *testing.T) {
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
handler := NewSubscriptionHandler(subscriptionService)
handler := NewSubscriptionHandler(subscriptionService, nil)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware

View File

@@ -13,11 +13,15 @@ import (
// SubscriptionHandler handles subscription-related HTTP requests
type SubscriptionHandler struct {
subscriptionService *services.SubscriptionService
stripeService *services.StripeService
}
// NewSubscriptionHandler creates a new subscription handler
func NewSubscriptionHandler(subscriptionService *services.SubscriptionService) *SubscriptionHandler {
return &SubscriptionHandler{subscriptionService: subscriptionService}
func NewSubscriptionHandler(subscriptionService *services.SubscriptionService, stripeService *services.StripeService) *SubscriptionHandler {
return &SubscriptionHandler{
subscriptionService: subscriptionService,
stripeService: stripeService,
}
}
// GetSubscription handles GET /api/subscription/
@@ -194,3 +198,82 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
"subscription": subscription,
})
}
// CreateCheckoutSession handles POST /api/subscription/checkout/
// Creates a Stripe Checkout Session for web subscription purchases
func (h *SubscriptionHandler) CreateCheckoutSession(c echo.Context) error {
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
if h.stripeService == nil {
return apperrors.BadRequest("error.stripe_not_configured")
}
// Check if already Pro from another platform
alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(user.ID, "stripe")
if err != nil {
return err
}
if alreadyPro {
return c.JSON(http.StatusConflict, map[string]interface{}{
"error": "error.already_subscribed_other_platform",
"existing_platform": existingPlatform,
"message": "You already have an active Pro subscription via " + existingPlatform + ". Manage it there to avoid double billing.",
})
}
var req struct {
PriceID string `json:"price_id" validate:"required"`
SuccessURL string `json:"success_url" validate:"required,url"`
CancelURL string `json:"cancel_url" validate:"required,url"`
}
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
sessionURL, err := h.stripeService.CreateCheckoutSession(user.ID, req.PriceID, req.SuccessURL, req.CancelURL)
if err != nil {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{
"checkout_url": sessionURL,
})
}
// CreatePortalSession handles POST /api/subscription/portal/
// Creates a Stripe Customer Portal session for managing web subscriptions
func (h *SubscriptionHandler) CreatePortalSession(c echo.Context) error {
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
if h.stripeService == nil {
return apperrors.BadRequest("error.stripe_not_configured")
}
var req struct {
ReturnURL string `json:"return_url" validate:"required,url"`
}
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
portalURL, err := h.stripeService.CreatePortalSession(user.ID, req.ReturnURL)
if err != nil {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{
"portal_url": portalURL,
})
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
)
// SubscriptionWebhookHandler handles subscription webhook callbacks
@@ -28,6 +29,7 @@ type SubscriptionWebhookHandler struct {
userRepo *repositories.UserRepository
webhookEventRepo *repositories.WebhookEventRepository
appleRootCerts []*x509.Certificate
stripeService *services.StripeService
enabled bool
}
@@ -46,6 +48,11 @@ func NewSubscriptionWebhookHandler(
}
}
// SetStripeService sets the Stripe service for webhook handling
func (h *SubscriptionWebhookHandler) SetStripeService(stripeService *services.StripeService) {
h.stripeService = stripeService
}
// ====================
// Apple App Store Server Notifications v2
// ====================
@@ -377,38 +384,30 @@ func (h *SubscriptionWebhookHandler) handleAppleFailedToRenew(userID uint, tx *A
}
func (h *SubscriptionWebhookHandler) handleAppleExpired(userID uint, tx *AppleTransactionInfo) error {
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
if err := h.safeDowngradeToFree(userID, "Apple expired"); err != nil {
return err
}
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription expired, downgraded to free")
return nil
}
func (h *SubscriptionWebhookHandler) handleAppleRefund(userID uint, tx *AppleTransactionInfo) error {
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
if err := h.safeDowngradeToFree(userID, "Apple refund"); err != nil {
return err
}
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User got refund, downgraded to free")
return nil
}
func (h *SubscriptionWebhookHandler) handleAppleRevoke(userID uint, tx *AppleTransactionInfo) error {
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
if err := h.safeDowngradeToFree(userID, "Apple revoke"); err != nil {
return err
}
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription revoked, downgraded to free")
return nil
}
func (h *SubscriptionWebhookHandler) handleAppleGracePeriodExpired(userID uint, tx *AppleTransactionInfo) error {
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
if err := h.safeDowngradeToFree(userID, "Apple grace period expired"); err != nil {
return err
}
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User grace period expired, downgraded to free")
return nil
}
@@ -705,22 +704,16 @@ func (h *SubscriptionWebhookHandler) handleGoogleRestarted(userID uint, notifica
}
func (h *SubscriptionWebhookHandler) handleGoogleRevoked(userID uint, notification *GoogleSubscriptionNotification) error {
// Subscription revoked - immediate downgrade
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
if err := h.safeDowngradeToFree(userID, "Google revoke"); err != nil {
return err
}
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription revoked")
return nil
}
func (h *SubscriptionWebhookHandler) handleGoogleExpired(userID uint, notification *GoogleSubscriptionNotification) error {
// Subscription expired
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
if err := h.safeDowngradeToFree(userID, "Google expired"); err != nil {
return err
}
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription expired")
return nil
}
@@ -730,6 +723,88 @@ func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notificatio
return nil
}
// ====================
// Multi-Source Downgrade Safety
// ====================
// safeDowngradeToFree checks if the user has active subscriptions from other sources
// before downgrading to free. If another source is still active, skip the downgrade.
func (h *SubscriptionWebhookHandler) safeDowngradeToFree(userID uint, reason string) error {
sub, err := h.subscriptionRepo.FindByUserID(userID)
if err != nil {
log.Warn().Err(err).Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Could not find subscription for multi-source check, proceeding with downgrade")
return h.subscriptionRepo.DowngradeToFree(userID)
}
// Check if Stripe subscription is still active
if sub.HasStripeSubscription() && sub.Platform != models.PlatformStripe {
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Stripe subscription")
return nil
}
// Check if Apple subscription is still active (for Google/Stripe webhooks)
if sub.HasAppleSubscription() && sub.Platform != models.PlatformIOS {
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Apple subscription")
return nil
}
// Check if Google subscription is still active (for Apple/Stripe webhooks)
if sub.HasGoogleSubscription() && sub.Platform != models.PlatformAndroid {
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Google subscription")
return nil
}
// Check if trial is still active
if sub.IsTrialActive() {
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active trial")
return nil
}
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
return err
}
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: User downgraded to free (no other active sources)")
return nil
}
// ====================
// Stripe Webhooks
// ====================
// HandleStripeWebhook handles POST /api/subscription/webhook/stripe/
func (h *SubscriptionWebhookHandler) HandleStripeWebhook(c echo.Context) error {
if !h.enabled {
log.Info().Msg("Stripe Webhook: webhooks disabled by feature flag")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
}
if h.stripeService == nil {
log.Warn().Msg("Stripe Webhook: Stripe service not configured")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "not_configured"})
}
body, err := io.ReadAll(c.Request().Body)
if err != nil {
log.Error().Err(err).Msg("Stripe Webhook: Failed to read body")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
}
signature := c.Request().Header.Get("Stripe-Signature")
if signature == "" {
log.Warn().Msg("Stripe Webhook: Missing Stripe-Signature header")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "missing signature"})
}
if err := h.stripeService.HandleWebhookEvent(body, signature); err != nil {
log.Error().Err(err).Msg("Stripe Webhook: Failed to process webhook")
// Still return 200 to prevent Stripe from retrying on business logic errors
// Only return error for signature verification failures
if strings.Contains(err.Error(), "signature") {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid signature"})
}
}
return c.JSON(http.StatusOK, map[string]interface{}{"status": "received"})
}
// ====================
// Signature Verification (Optional but Recommended)
// ====================

View File

@@ -98,6 +98,10 @@ var specEndpointsKMPSkips = map[routeKey]bool{
{Method: "POST", Path: "/notifications/devices/"}: true, // KMP uses /notifications/devices/register/
{Method: "POST", Path: "/notifications/devices/unregister/"}: true, // KMP uses DELETE on device ID
{Method: "PATCH", Path: "/notifications/preferences/"}: true, // KMP uses PUT
// Stripe web-only and server-to-server endpoints — not implemented in mobile KMP
{Method: "POST", Path: "/subscription/checkout/"}: true, // Web-only (Stripe Checkout)
{Method: "POST", Path: "/subscription/portal/"}: true, // Web-only (Stripe Customer Portal)
{Method: "POST", Path: "/subscription/webhook/stripe/"}: true, // Server-to-server (Stripe webhook)
}
// kmpRouteAliases maps KMP paths to their canonical spec paths.

View File

@@ -76,7 +76,7 @@ func setupSecurityTest(t *testing.T) *SecurityTestApp {
taskHandler := handlers.NewTaskHandler(taskService, nil)
contractorHandler := handlers.NewContractorHandler(services.NewContractorService(contractorRepo, residenceRepo))
notificationHandler := handlers.NewNotificationHandler(notificationService)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
// Create router with real middleware
e := echo.New()

View File

@@ -64,7 +64,7 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
// Create handlers
authHandler := handlers.NewAuthHandler(authService, nil, nil)
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
// Create router
e := echo.New()

View File

@@ -12,11 +12,20 @@ const (
TierPro SubscriptionTier = "pro"
)
// SubscriptionPlatform constants
const (
PlatformIOS = "ios"
PlatformAndroid = "android"
PlatformStripe = "stripe"
)
// SubscriptionSettings represents the subscription_subscriptionsettings table (singleton)
type SubscriptionSettings struct {
ID uint `gorm:"primaryKey" json:"id"`
EnableLimitations bool `gorm:"column:enable_limitations;default:false" json:"enable_limitations"`
EnableMonitoring bool `gorm:"column:enable_monitoring;default:true" json:"enable_monitoring"`
TrialEnabled bool `gorm:"column:trial_enabled;default:true" json:"trial_enabled"`
TrialDurationDays int `gorm:"column:trial_duration_days;default:14" json:"trial_duration_days"`
}
// TableName returns the table name for GORM
@@ -31,18 +40,28 @@ type UserSubscription struct {
User User `gorm:"foreignKey:UserID" json:"-"`
Tier SubscriptionTier `gorm:"column:tier;size:10;default:'free'" json:"tier"`
// In-App Purchase data
AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"`
// In-App Purchase data (Apple / Google)
AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"`
GooglePurchaseToken *string `gorm:"column:google_purchase_token;type:text" json:"-"`
// Stripe data (web subscriptions)
StripeCustomerID *string `gorm:"column:stripe_customer_id;size:255" json:"-"`
StripeSubscriptionID *string `gorm:"column:stripe_subscription_id;size:255" json:"-"`
StripePriceID *string `gorm:"column:stripe_price_id;size:255" json:"-"`
// Subscription dates
SubscribedAt *time.Time `gorm:"column:subscribed_at" json:"subscribed_at"`
ExpiresAt *time.Time `gorm:"column:expires_at" json:"expires_at"`
AutoRenew bool `gorm:"column:auto_renew;default:true" json:"auto_renew"`
// Trial
TrialStart *time.Time `gorm:"column:trial_start" json:"trial_start"`
TrialEnd *time.Time `gorm:"column:trial_end" json:"trial_end"`
TrialUsed bool `gorm:"column:trial_used;default:false" json:"trial_used"`
// Tracking
CancelledAt *time.Time `gorm:"column:cancelled_at" json:"cancelled_at"`
Platform string `gorm:"column:platform;size:10" json:"platform"` // ios, android
Platform string `gorm:"column:platform;size:10" json:"platform"` // ios, android, stripe
// Admin override - bypasses all limitations regardless of global settings
IsFree bool `gorm:"column:is_free;default:false" json:"is_free"`
@@ -53,8 +72,11 @@ func (UserSubscription) TableName() string {
return "subscription_usersubscription"
}
// IsActive returns true if the subscription is active (pro tier and not expired)
// IsActive returns true if the subscription is active (pro tier and not expired, or trial active)
func (s *UserSubscription) IsActive() bool {
if s.IsTrialActive() {
return true
}
if s.Tier != TierPro {
return false
}
@@ -64,9 +86,37 @@ func (s *UserSubscription) IsActive() bool {
return true
}
// IsPro returns true if the user has a pro subscription
// IsPro returns true if the user has a pro subscription or active trial
func (s *UserSubscription) IsPro() bool {
return s.Tier == TierPro && s.IsActive()
return s.IsActive()
}
// IsTrialActive returns true if the user has an active, unexpired trial
func (s *UserSubscription) IsTrialActive() bool {
if s.TrialEnd == nil {
return false
}
return time.Now().UTC().Before(*s.TrialEnd)
}
// HasStripeSubscription returns true if the user has Stripe subscription data
func (s *UserSubscription) HasStripeSubscription() bool {
return s.StripeSubscriptionID != nil && *s.StripeSubscriptionID != ""
}
// HasAppleSubscription returns true if the user has Apple receipt data
func (s *UserSubscription) HasAppleSubscription() bool {
return s.AppleReceiptData != nil && *s.AppleReceiptData != ""
}
// HasGoogleSubscription returns true if the user has Google purchase token
func (s *UserSubscription) HasGoogleSubscription() bool {
return s.GooglePurchaseToken != nil && *s.GooglePurchaseToken != ""
}
// SubscriptionSource returns the platform that the active subscription came from
func (s *UserSubscription) SubscriptionSource() string {
return s.Platform
}
// UpgradeTrigger represents the subscription_upgradetrigger table

View File

@@ -0,0 +1,187 @@
package models
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestIsTrialActive(t *testing.T) {
now := time.Now().UTC()
future := now.Add(24 * time.Hour)
past := now.Add(-24 * time.Hour)
tests := []struct {
name string
sub *UserSubscription
expected bool
}{
{
name: "trial_end in future returns true",
sub: &UserSubscription{TrialEnd: &future},
expected: true,
},
{
name: "trial_end in past returns false",
sub: &UserSubscription{TrialEnd: &past},
expected: false,
},
{
name: "trial_end nil returns false",
sub: &UserSubscription{TrialEnd: nil},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.sub.IsTrialActive()
assert.Equal(t, tt.expected, result)
})
}
}
func TestIsPro(t *testing.T) {
now := time.Now().UTC()
future := now.Add(24 * time.Hour)
past := now.Add(-24 * time.Hour)
tests := []struct {
name string
sub *UserSubscription
expected bool
}{
{
name: "tier=pro, expires_at in future returns true",
sub: &UserSubscription{
Tier: TierPro,
ExpiresAt: &future,
},
expected: true,
},
{
name: "tier=pro, expires_at in past returns false",
sub: &UserSubscription{
Tier: TierPro,
ExpiresAt: &past,
},
expected: false,
},
{
name: "tier=free, trial active returns true",
sub: &UserSubscription{
Tier: TierFree,
TrialEnd: &future,
},
expected: true,
},
{
name: "tier=free, trial expired returns false",
sub: &UserSubscription{
Tier: TierFree,
TrialEnd: &past,
},
expected: false,
},
{
name: "tier=free, no trial returns false",
sub: &UserSubscription{
Tier: TierFree,
TrialEnd: nil,
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.sub.IsPro()
assert.Equal(t, tt.expected, result)
})
}
}
func TestHasSubscriptionHelpers(t *testing.T) {
empty := ""
validStripeID := "sub_1234567890"
validReceipt := "MIIT..."
validToken := "google-purchase-token-123"
tests := []struct {
name string
sub *UserSubscription
method string
expected bool
}{
// HasStripeSubscription
{
name: "HasStripeSubscription with nil returns false",
sub: &UserSubscription{StripeSubscriptionID: nil},
method: "stripe",
expected: false,
},
{
name: "HasStripeSubscription with empty string returns false",
sub: &UserSubscription{StripeSubscriptionID: &empty},
method: "stripe",
expected: false,
},
{
name: "HasStripeSubscription with valid ID returns true",
sub: &UserSubscription{StripeSubscriptionID: &validStripeID},
method: "stripe",
expected: true,
},
// HasAppleSubscription
{
name: "HasAppleSubscription with nil returns false",
sub: &UserSubscription{AppleReceiptData: nil},
method: "apple",
expected: false,
},
{
name: "HasAppleSubscription with empty string returns false",
sub: &UserSubscription{AppleReceiptData: &empty},
method: "apple",
expected: false,
},
{
name: "HasAppleSubscription with valid receipt returns true",
sub: &UserSubscription{AppleReceiptData: &validReceipt},
method: "apple",
expected: true,
},
// HasGoogleSubscription
{
name: "HasGoogleSubscription with nil returns false",
sub: &UserSubscription{GooglePurchaseToken: nil},
method: "google",
expected: false,
},
{
name: "HasGoogleSubscription with empty string returns false",
sub: &UserSubscription{GooglePurchaseToken: &empty},
method: "google",
expected: false,
},
{
name: "HasGoogleSubscription with valid token returns true",
sub: &UserSubscription{GooglePurchaseToken: &validToken},
method: "google",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result bool
switch tt.method {
case "stripe":
result = tt.sub.HasStripeSubscription()
case "apple":
result = tt.sub.HasAppleSubscription()
case "google":
result = tt.sub.HasGoogleSubscription()
}
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -262,3 +262,59 @@ func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.P
}
return &promotion, nil
}
// === Stripe Lookups ===
// FindByStripeCustomerID finds a subscription by Stripe customer ID
func (r *SubscriptionRepository) FindByStripeCustomerID(customerID string) (*models.UserSubscription, error) {
var sub models.UserSubscription
err := r.db.Where("stripe_customer_id = ?", customerID).First(&sub).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// FindByStripeSubscriptionID finds a subscription by Stripe subscription ID
func (r *SubscriptionRepository) FindByStripeSubscriptionID(subscriptionID string) (*models.UserSubscription, error) {
var sub models.UserSubscription
err := r.db.Where("stripe_subscription_id = ?", subscriptionID).First(&sub).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// UpdateStripeData updates all three Stripe fields (customer, subscription, price) in one call
func (r *SubscriptionRepository) UpdateStripeData(userID uint, customerID, subscriptionID, priceID string) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"stripe_customer_id": customerID,
"stripe_subscription_id": subscriptionID,
"stripe_price_id": priceID,
}).Error
}
// ClearStripeData clears the Stripe subscription and price IDs (customer ID stays for portal access)
func (r *SubscriptionRepository) ClearStripeData(userID uint) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"stripe_subscription_id": nil,
"stripe_price_id": nil,
}).Error
}
// === Trial Management ===
// SetTrialDates sets the trial start, end, and marks trial as used
func (r *SubscriptionRepository) SetTrialDates(userID uint, trialStart, trialEnd time.Time) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"trial_start": trialStart,
"trial_end": trialEnd,
"trial_used": true,
}).Error
}
// UpdateExpiresAt updates the expires_at field for a user's subscription
func (r *SubscriptionRepository) UpdateExpiresAt(userID uint, expiresAt time.Time) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).
Update("expires_at", expiresAt).Error
}

View File

@@ -2,9 +2,11 @@ package repositories
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/testutil"
@@ -77,3 +79,150 @@ func TestGetOrCreate_Idempotent(t *testing.T) {
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should have exactly one subscription record after two calls")
}
func TestFindByStripeCustomerID(t *testing.T) {
tests := []struct {
name string
customerID string
seedID string
wantErr bool
}{
{
name: "finds existing subscription by stripe customer ID",
customerID: "cus_test123",
seedID: "cus_test123",
wantErr: false,
},
{
name: "returns error for unknown stripe customer ID",
customerID: "cus_unknown999",
seedID: "cus_test456",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierFree,
StripeCustomerID: &tt.seedID,
}
err := db.Create(sub).Error
require.NoError(t, err)
found, err := repo.FindByStripeCustomerID(tt.customerID)
if tt.wantErr {
assert.Error(t, err)
assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
} else {
require.NoError(t, err)
require.NotNil(t, found)
assert.Equal(t, user.ID, found.UserID)
assert.Equal(t, tt.seedID, *found.StripeCustomerID)
}
})
}
}
func TestUpdateStripeData(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierFree,
}
err := db.Create(sub).Error
require.NoError(t, err)
// Update all three Stripe fields
customerID := "cus_abc123"
subscriptionID := "sub_xyz789"
priceID := "price_monthly"
err = repo.UpdateStripeData(user.ID, customerID, subscriptionID, priceID)
require.NoError(t, err)
// Verify all three fields are set
var updated models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&updated).Error
require.NoError(t, err)
require.NotNil(t, updated.StripeCustomerID)
require.NotNil(t, updated.StripeSubscriptionID)
require.NotNil(t, updated.StripePriceID)
assert.Equal(t, customerID, *updated.StripeCustomerID)
assert.Equal(t, subscriptionID, *updated.StripeSubscriptionID)
assert.Equal(t, priceID, *updated.StripePriceID)
// Now call ClearStripeData
err = repo.ClearStripeData(user.ID)
require.NoError(t, err)
// Verify subscription_id and price_id are cleared, customer_id preserved
var cleared models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&cleared).Error
require.NoError(t, err)
require.NotNil(t, cleared.StripeCustomerID, "customer_id should be preserved after ClearStripeData")
assert.Equal(t, customerID, *cleared.StripeCustomerID)
assert.Nil(t, cleared.StripeSubscriptionID, "subscription_id should be nil after ClearStripeData")
assert.Nil(t, cleared.StripePriceID, "price_id should be nil after ClearStripeData")
}
func TestSetTrialDates(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierFree,
TrialUsed: false,
}
err := db.Create(sub).Error
require.NoError(t, err)
trialStart := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
trialEnd := time.Date(2026, 3, 15, 0, 0, 0, 0, time.UTC)
err = repo.SetTrialDates(user.ID, trialStart, trialEnd)
require.NoError(t, err)
var updated models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&updated).Error
require.NoError(t, err)
require.NotNil(t, updated.TrialStart)
require.NotNil(t, updated.TrialEnd)
assert.True(t, updated.TrialUsed, "trial_used should be set to true")
assert.WithinDuration(t, trialStart, *updated.TrialStart, time.Second, "trial_start should match")
assert.WithinDuration(t, trialEnd, *updated.TrialEnd, time.Second, "trial_end should match")
}
func TestUpdateExpiresAt(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierPro,
}
err := db.Create(sub).Error
require.NoError(t, err)
newExpiry := time.Date(2027, 6, 15, 12, 0, 0, 0, time.UTC)
err = repo.UpdateExpiresAt(user.ID, newExpiry)
require.NoError(t, err)
var updated models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&updated).Error
require.NoError(t, err)
require.NotNil(t, updated.ExpiresAt)
assert.WithinDuration(t, newExpiry, *updated.ExpiresAt, time.Second, "expires_at should be updated")
}

View File

@@ -58,7 +58,13 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
e.Use(custommiddleware.RequestIDMiddleware())
e.Use(utils.EchoRecovery())
e.Use(custommiddleware.StructuredLogger())
e.Use(middleware.BodyLimit("1M")) // 1MB default for JSON payloads
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
Limit: "1M", // 1MB default for JSON payloads
Skipper: func(c echo.Context) bool {
// Allow larger payloads for webhook endpoints (Apple/Google/Stripe notifications)
return strings.HasPrefix(c.Request().URL.Path, "/api/subscription/webhook")
},
}))
e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{
Timeout: 30 * time.Second,
Skipper: func(c echo.Context) bool {
@@ -143,11 +149,15 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
residenceService.SetSubscriptionService(subscriptionService) // Wire up subscription service for tier limit enforcement
taskTemplateService := services.NewTaskTemplateService(taskTemplateRepo)
// Initialize Stripe service
stripeService := services.NewStripeService(subscriptionRepo, userRepo)
// Initialize webhook event repo for deduplication
webhookEventRepo := repositories.NewWebhookEventRepository(deps.DB)
// Initialize webhook handler for Apple/Google subscription notifications
// Initialize webhook handler for Apple/Google/Stripe subscription notifications
subscriptionWebhookHandler := handlers.NewSubscriptionWebhookHandler(subscriptionRepo, userRepo, webhookEventRepo, cfg.Features.WebhooksEnabled)
subscriptionWebhookHandler.SetStripeService(stripeService)
// Initialize middleware
authMiddleware := custommiddleware.NewAuthMiddleware(deps.DB, deps.Cache)
@@ -166,7 +176,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
contractorHandler := handlers.NewContractorHandler(contractorService)
documentHandler := handlers.NewDocumentHandler(documentService, deps.StorageService)
notificationHandler := handlers.NewNotificationHandler(notificationService)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, stripeService)
staticDataHandler := handlers.NewStaticDataHandler(residenceService, taskService, contractorService, taskTemplateService, deps.Cache)
taskTemplateHandler := handlers.NewTaskTemplateHandler(taskTemplateService)
@@ -458,6 +468,8 @@ func setupSubscriptionRoutes(api *echo.Group, subscriptionHandler *handlers.Subs
subscription.POST("/purchase/", subscriptionHandler.ProcessPurchase)
subscription.POST("/cancel/", subscriptionHandler.CancelSubscription)
subscription.POST("/restore/", subscriptionHandler.RestoreSubscription)
subscription.POST("/checkout/", subscriptionHandler.CreateCheckoutSession)
subscription.POST("/portal/", subscriptionHandler.CreatePortalSession)
}
}
@@ -499,6 +511,7 @@ func setupWebhookRoutes(api *echo.Group, webhookHandler *handlers.SubscriptionWe
{
webhooks.POST("/apple/", webhookHandler.HandleAppleWebhook)
webhooks.POST("/google/", webhookHandler.HandleGoogleWebhook)
webhooks.POST("/stripe/", webhookHandler.HandleStripeWebhook)
}
}

View File

@@ -0,0 +1,456 @@
package services
import (
"encoding/json"
"fmt"
"os"
"time"
"github.com/rs/zerolog/log"
"github.com/stripe/stripe-go/v81"
portalsession "github.com/stripe/stripe-go/v81/billingportal/session"
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
"github.com/stripe/stripe-go/v81/customer"
"github.com/stripe/stripe-go/v81/webhook"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
)
// StripeService handles Stripe checkout, portal, and webhook processing
// for web-based subscription purchases.
type StripeService struct {
subscriptionRepo *repositories.SubscriptionRepository
userRepo *repositories.UserRepository
webhookSecret string
}
// NewStripeService creates a new Stripe service. It initializes the global
// Stripe API key from the STRIPE_SECRET_KEY environment variable. If the key
// is not set, a warning is logged but the service is still returned (matching
// the pattern used by the Apple/Google IAP clients).
func NewStripeService(
subscriptionRepo *repositories.SubscriptionRepository,
userRepo *repositories.UserRepository,
) *StripeService {
key := os.Getenv("STRIPE_SECRET_KEY")
if key == "" {
log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work")
} else {
stripe.Key = key
log.Info().Msg("Stripe API key configured")
}
webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")
if webhookSecret == "" {
log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail")
}
return &StripeService{
subscriptionRepo: subscriptionRepo,
userRepo: userRepo,
webhookSecret: webhookSecret,
}
}
// CreateCheckoutSession creates a Stripe Checkout Session for a web subscription purchase.
// It ensures the user has a Stripe customer record and configures the session with a trial
// period if the user has not used their trial yet.
func (s *StripeService) CreateCheckoutSession(userID uint, priceID string, successURL string, cancelURL string) (string, error) {
// Get or create the user's subscription record
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return "", apperrors.Internal(err)
}
// Get the user's email for the Stripe customer
user, err := s.userRepo.FindByID(userID)
if err != nil {
return "", apperrors.Internal(err)
}
// Get or create a Stripe customer
stripeCustomerID, err := s.getOrCreateStripeCustomer(sub, user)
if err != nil {
return "", apperrors.Internal(err)
}
// Build the checkout session parameters
params := &stripe.CheckoutSessionParams{
Customer: stripe.String(stripeCustomerID),
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
SuccessURL: stripe.String(successURL),
CancelURL: stripe.String(cancelURL),
ClientReferenceID: stripe.String(fmt.Sprintf("%d", userID)),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(priceID),
Quantity: stripe.Int64(1),
},
},
}
// Offer a trial period if the user has not used their trial yet
if !sub.TrialUsed {
trialDays, err := s.getTrialDays()
if err != nil {
log.Warn().Err(err).Msg("Failed to get trial duration from settings, skipping trial")
} else if trialDays > 0 {
params.SubscriptionData = &stripe.CheckoutSessionSubscriptionDataParams{
TrialPeriodDays: stripe.Int64(int64(trialDays)),
}
}
}
session, err := checkoutsession.New(params)
if err != nil {
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to create Stripe checkout session")
return "", apperrors.Internal(err)
}
log.Info().
Uint("user_id", userID).
Str("session_id", session.ID).
Str("price_id", priceID).
Msg("Stripe checkout session created")
return session.URL, nil
}
// CreatePortalSession creates a Stripe Customer Portal session so the user
// can manage their subscription (cancel, change plan, update payment method).
func (s *StripeService) CreatePortalSession(userID uint, returnURL string) (string, error) {
sub, err := s.subscriptionRepo.FindByUserID(userID)
if err != nil {
return "", apperrors.NotFound("error.subscription_not_found")
}
if sub.StripeCustomerID == nil || *sub.StripeCustomerID == "" {
return "", apperrors.BadRequest("error.no_stripe_customer")
}
params := &stripe.BillingPortalSessionParams{
Customer: sub.StripeCustomerID,
ReturnURL: stripe.String(returnURL),
}
session, err := portalsession.New(params)
if err != nil {
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to create Stripe portal session")
return "", apperrors.Internal(err)
}
return session.URL, nil
}
// HandleWebhookEvent verifies and processes a Stripe webhook event.
// It handles checkout completion, subscription lifecycle changes, and invoice events.
func (s *StripeService) HandleWebhookEvent(payload []byte, signature string) error {
event, err := webhook.ConstructEvent(payload, signature, s.webhookSecret)
if err != nil {
log.Warn().Err(err).Msg("Stripe webhook signature verification failed")
return apperrors.BadRequest("error.invalid_webhook_signature")
}
log.Info().
Str("event_type", string(event.Type)).
Str("event_id", event.ID).
Msg("Processing Stripe webhook event")
switch event.Type {
case "checkout.session.completed":
return s.handleCheckoutCompleted(event)
case "customer.subscription.updated":
return s.handleSubscriptionUpdated(event)
case "customer.subscription.deleted":
return s.handleSubscriptionDeleted(event)
case "invoice.paid":
return s.handleInvoicePaid(event)
case "invoice.payment_failed":
return s.handleInvoicePaymentFailed(event)
default:
log.Debug().Str("event_type", string(event.Type)).Msg("Unhandled Stripe webhook event type")
return nil
}
}
// handleCheckoutCompleted processes a successful checkout session. It links the Stripe
// customer and subscription to the user's record and upgrades them to Pro.
func (s *StripeService) handleCheckoutCompleted(event stripe.Event) error {
var session stripe.CheckoutSession
if err := json.Unmarshal(event.Data.Raw, &session); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal checkout session from webhook")
return apperrors.Internal(err)
}
// Extract the user ID from client_reference_id
var userID uint
if _, err := fmt.Sscanf(session.ClientReferenceID, "%d", &userID); err != nil {
log.Error().Str("client_reference_id", session.ClientReferenceID).Msg("Invalid client_reference_id in checkout session")
return apperrors.BadRequest("error.invalid_client_reference_id")
}
// Get or create the subscription record
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return apperrors.Internal(err)
}
// Save Stripe customer and subscription IDs
if session.Customer != nil {
sub.StripeCustomerID = &session.Customer.ID
}
if session.Subscription != nil {
sub.StripeSubscriptionID = &session.Subscription.ID
}
if err := s.subscriptionRepo.Update(sub); err != nil {
return apperrors.Internal(err)
}
// Upgrade to Pro. Use a far-future expiry for now; the invoice.paid event
// will set the real period_end once the first invoice is finalized.
expiresAt := time.Now().UTC().AddDate(1, 0, 0)
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, models.PlatformStripe); err != nil {
return apperrors.Internal(err)
}
customerID := ""
if session.Customer != nil {
customerID = session.Customer.ID
}
subscriptionID := ""
if session.Subscription != nil {
subscriptionID = session.Subscription.ID
}
log.Info().
Uint("user_id", userID).
Str("stripe_customer_id", customerID).
Str("stripe_subscription_id", subscriptionID).
Msg("Checkout completed, user upgraded to Pro")
// TODO: Send push notification to user's devices when subscription activates
return nil
}
// handleSubscriptionUpdated processes subscription status changes. It upgrades or
// downgrades the user depending on the subscription's current status.
func (s *StripeService) handleSubscriptionUpdated(event stripe.Event) error {
var subscription stripe.Subscription
if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal subscription from webhook")
return apperrors.Internal(err)
}
sub, err := s.findSubscriptionByStripeID(subscription.ID)
if err != nil {
return err
}
switch subscription.Status {
case stripe.SubscriptionStatusActive, stripe.SubscriptionStatusTrialing:
// Subscription is healthy, ensure user is Pro
expiresAt := time.Unix(subscription.CurrentPeriodEnd, 0).UTC()
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
return apperrors.Internal(err)
}
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("Stripe subscription active")
case stripe.SubscriptionStatusPastDue:
log.Warn().Uint("user_id", sub.UserID).Msg("Stripe subscription past due, waiting for retry")
// Don't downgrade yet; Stripe will retry the payment automatically.
case stripe.SubscriptionStatusCanceled, stripe.SubscriptionStatusUnpaid:
// Check if the user has active subscriptions from other sources before downgrading
if s.isActiveFromOtherSources(sub) {
log.Info().
Uint("user_id", sub.UserID).
Str("status", string(subscription.Status)).
Msg("Stripe subscription ended but user has other active sources, keeping Pro")
return nil
}
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
return apperrors.Internal(err)
}
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("User downgraded to Free after Stripe subscription ended")
}
return nil
}
// handleSubscriptionDeleted processes a subscription that has been fully cancelled
// and is no longer active. It downgrades the user unless they have active subscriptions
// from other sources (Apple, Google).
func (s *StripeService) handleSubscriptionDeleted(event stripe.Event) error {
var subscription stripe.Subscription
if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal subscription from webhook")
return apperrors.Internal(err)
}
sub, err := s.findSubscriptionByStripeID(subscription.ID)
if err != nil {
return err
}
// Check multi-source before downgrading
if s.isActiveFromOtherSources(sub) {
log.Info().
Uint("user_id", sub.UserID).
Msg("Stripe subscription deleted but user has other active sources, keeping Pro")
return nil
}
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
return apperrors.Internal(err)
}
log.Info().Uint("user_id", sub.UserID).Msg("User downgraded to Free after Stripe subscription deleted")
// TODO: Send push notification to user's devices about subscription ending
return nil
}
// handleInvoicePaid processes a successful invoice payment. It updates the subscription
// expiry to the current billing period's end date and ensures the user is on Pro.
func (s *StripeService) handleInvoicePaid(event stripe.Event) error {
var invoice stripe.Invoice
if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal invoice from webhook")
return apperrors.Internal(err)
}
// Only process subscription invoices
if invoice.Subscription == nil {
return nil
}
sub, err := s.findSubscriptionByStripeID(invoice.Subscription.ID)
if err != nil {
return err
}
// Update expiry from the invoice's period end
expiresAt := time.Unix(invoice.PeriodEnd, 0).UTC()
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
return apperrors.Internal(err)
}
log.Info().
Uint("user_id", sub.UserID).
Time("expires_at", expiresAt).
Msg("Invoice paid, subscription renewed")
return nil
}
// handleInvoicePaymentFailed logs a warning when a payment fails. We do not downgrade
// the user here because Stripe will automatically retry the payment according to its
// Smart Retries schedule.
func (s *StripeService) handleInvoicePaymentFailed(event stripe.Event) error {
var invoice stripe.Invoice
if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal invoice from webhook")
return apperrors.Internal(err)
}
if invoice.Subscription == nil {
return nil
}
sub, err := s.findSubscriptionByStripeID(invoice.Subscription.ID)
if err != nil {
// If we can't find the subscription, just log and return
log.Warn().Str("stripe_subscription_id", invoice.Subscription.ID).Msg("Invoice payment failed for unknown subscription")
return nil
}
log.Warn().
Uint("user_id", sub.UserID).
Str("invoice_id", invoice.ID).
Msg("Stripe invoice payment failed, Stripe will retry automatically")
return nil
}
// isActiveFromOtherSources checks if the user has active subscriptions from Apple or Google
// that should prevent a downgrade when the Stripe subscription ends.
func (s *StripeService) isActiveFromOtherSources(sub *models.UserSubscription) bool {
now := time.Now().UTC()
// Check Apple subscription
if sub.HasAppleSubscription() && sub.Tier == models.TierPro && sub.ExpiresAt != nil && now.Before(*sub.ExpiresAt) && sub.Platform != models.PlatformStripe {
return true
}
// Check Google subscription
if sub.HasGoogleSubscription() && sub.Tier == models.TierPro && sub.ExpiresAt != nil && now.Before(*sub.ExpiresAt) && sub.Platform != models.PlatformStripe {
return true
}
// Check active trial
if sub.IsTrialActive() {
return true
}
return false
}
// getOrCreateStripeCustomer returns the existing Stripe customer ID from the subscription
// record, or creates a new Stripe customer and persists the ID.
func (s *StripeService) getOrCreateStripeCustomer(sub *models.UserSubscription, user *models.User) (string, error) {
// If we already have a Stripe customer, return it
if sub.StripeCustomerID != nil && *sub.StripeCustomerID != "" {
return *sub.StripeCustomerID, nil
}
// Create a new Stripe customer
params := &stripe.CustomerParams{
Email: stripe.String(user.Email),
Name: stripe.String(user.GetFullName()),
}
params.AddMetadata("casera_user_id", fmt.Sprintf("%d", user.ID))
c, err := customer.New(params)
if err != nil {
return "", fmt.Errorf("failed to create Stripe customer: %w", err)
}
// Save the customer ID to the subscription record
sub.StripeCustomerID = &c.ID
if err := s.subscriptionRepo.Update(sub); err != nil {
return "", fmt.Errorf("failed to save Stripe customer ID: %w", err)
}
log.Info().
Uint("user_id", user.ID).
Str("stripe_customer_id", c.ID).
Msg("Created new Stripe customer")
return c.ID, nil
}
// findSubscriptionByStripeID looks up a UserSubscription by its Stripe subscription ID.
func (s *StripeService) findSubscriptionByStripeID(stripeSubID string) (*models.UserSubscription, error) {
sub, err := s.subscriptionRepo.FindByStripeSubscriptionID(stripeSubID)
if err != nil {
log.Warn().Str("stripe_subscription_id", stripeSubID).Err(err).Msg("Subscription not found for Stripe ID")
return nil, apperrors.NotFound("error.subscription_not_found")
}
return sub, nil
}
// getTrialDays reads the trial duration from SubscriptionSettings.
func (s *StripeService) getTrialDays() (int, error) {
settings, err := s.subscriptionRepo.GetSettings()
if err != nil {
return 0, err
}
if !settings.TrialEnabled {
return 0, nil
}
return settings.TrialDurationDays, nil
}

View File

@@ -118,6 +118,20 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
return nil, apperrors.Internal(err)
}
// Auto-start trial for new users who have never had a trial
if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled {
now := time.Now().UTC()
trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour)
if err := s.subscriptionRepo.SetTrialDates(userID, now, trialEnd); err != nil {
return nil, apperrors.Internal(err)
}
// Re-fetch after starting trial so response reflects the new state
sub, err = s.subscriptionRepo.FindByUserID(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
}
// Get all tier limits and build a map
allLimits, err := s.subscriptionRepo.GetAllTierLimits()
if err != nil {
@@ -154,6 +168,8 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
// Build flattened response (KMM expects subscription fields at top level)
resp := &SubscriptionStatusResponse{
Tier: string(sub.Tier),
IsActive: sub.IsActive(),
AutoRenew: sub.AutoRenew,
Limits: limitsMap,
Usage: usage,
@@ -170,6 +186,18 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
resp.ExpiresAt = &t
}
// Populate trial fields
if sub.TrialStart != nil {
t := sub.TrialStart.Format("2006-01-02T15:04:05Z")
resp.TrialStart = &t
}
if sub.TrialEnd != nil {
t := sub.TrialEnd.Format("2006-01-02T15:04:05Z")
resp.TrialEnd = &t
}
resp.TrialActive = sub.IsTrialActive()
resp.SubscriptionSource = sub.SubscriptionSource()
return resp, nil
}
@@ -449,28 +477,48 @@ func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResp
return s.GetSubscription(userID)
}
// IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription
// from a different platform than the one being requested. Returns (conflict, existingPlatform, error).
func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(userID uint, requestedPlatform string) (bool, string, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return false, "", apperrors.Internal(err)
}
if !sub.IsPro() {
return false, "", nil
}
if sub.Platform == requestedPlatform {
return false, "", nil
}
return true, sub.Platform, nil
}
// === Response Types ===
// SubscriptionResponse represents a subscription in API response
type SubscriptionResponse struct {
Tier string `json:"tier"`
SubscribedAt *string `json:"subscribed_at"`
ExpiresAt *string `json:"expires_at"`
AutoRenew bool `json:"auto_renew"`
CancelledAt *string `json:"cancelled_at"`
Platform string `json:"platform"`
IsActive bool `json:"is_active"`
IsPro bool `json:"is_pro"`
Tier string `json:"tier"`
SubscribedAt *string `json:"subscribed_at"`
ExpiresAt *string `json:"expires_at"`
AutoRenew bool `json:"auto_renew"`
CancelledAt *string `json:"cancelled_at"`
Platform string `json:"platform"`
IsActive bool `json:"is_active"`
IsPro bool `json:"is_pro"`
TrialActive bool `json:"trial_active"`
SubscriptionSource string `json:"subscription_source"`
}
// NewSubscriptionResponse creates a SubscriptionResponse from a model
func NewSubscriptionResponse(s *models.UserSubscription) *SubscriptionResponse {
resp := &SubscriptionResponse{
Tier: string(s.Tier),
AutoRenew: s.AutoRenew,
Platform: s.Platform,
IsActive: s.IsActive(),
IsPro: s.IsPro(),
Tier: string(s.Tier),
AutoRenew: s.AutoRenew,
Platform: s.Platform,
IsActive: s.IsActive(),
IsPro: s.IsPro(),
TrialActive: s.IsTrialActive(),
SubscriptionSource: s.SubscriptionSource(),
}
if s.SubscribedAt != nil {
t := s.SubscribedAt.Format("2006-01-02T15:04:05Z")
@@ -536,11 +584,23 @@ func NewTierLimitsClientResponse(l *models.TierLimits) *TierLimitsClientResponse
// SubscriptionStatusResponse represents full subscription status
// Fields are flattened to match KMM client expectations
type SubscriptionStatusResponse struct {
// Tier and active status
Tier string `json:"tier"`
IsActive bool `json:"is_active"`
// Flattened subscription fields (KMM expects these at top level)
SubscribedAt *string `json:"subscribed_at"`
ExpiresAt *string `json:"expires_at"`
AutoRenew bool `json:"auto_renew"`
// Trial fields
TrialStart *string `json:"trial_start,omitempty"`
TrialEnd *string `json:"trial_end,omitempty"`
TrialActive bool `json:"trial_active"`
// Subscription source
SubscriptionSource string `json:"subscription_source"`
// Other fields
Usage *UsageResponse `json:"usage"`
Limits map[string]*TierLimitsClientResponse `json:"limits"`
@@ -638,5 +698,5 @@ type ProcessPurchaseRequest struct {
TransactionID string `json:"transaction_id"` // iOS StoreKit 2 transaction ID
PurchaseToken string `json:"purchase_token"` // Android
ProductID string `json:"product_id"` // Android (optional, helps identify subscription)
Platform string `json:"platform" validate:"required,oneof=ios android"`
Platform string `json:"platform" validate:"required,oneof=ios android stripe"`
}

View File

@@ -2,6 +2,7 @@ package services
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -179,3 +180,94 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
}
func TestIsAlreadyProFromOtherPlatform(t *testing.T) {
future := time.Now().UTC().Add(30 * 24 * time.Hour)
tests := []struct {
name string
tier models.SubscriptionTier
platform string
expiresAt *time.Time
trialEnd *time.Time
requestedPlatform string
wantConflict bool
wantPlatform string
}{
{
name: "free user returns no conflict",
tier: models.TierFree,
platform: "",
expiresAt: nil,
trialEnd: nil,
requestedPlatform: "stripe",
wantConflict: false,
wantPlatform: "",
},
{
name: "pro from ios, requesting ios returns no conflict (same platform)",
tier: models.TierPro,
platform: "ios",
expiresAt: &future,
trialEnd: nil,
requestedPlatform: "ios",
wantConflict: false,
wantPlatform: "",
},
{
name: "pro from ios, requesting stripe returns conflict",
tier: models.TierPro,
platform: "ios",
expiresAt: &future,
trialEnd: nil,
requestedPlatform: "stripe",
wantConflict: true,
wantPlatform: "ios",
},
{
name: "pro from stripe, requesting android returns conflict",
tier: models.TierPro,
platform: "stripe",
expiresAt: &future,
trialEnd: nil,
requestedPlatform: "android",
wantConflict: true,
wantPlatform: "stripe",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
svc := &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
}
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: tt.tier,
Platform: tt.platform,
ExpiresAt: tt.expiresAt,
TrialEnd: tt.trialEnd,
}
err := db.Create(sub).Error
require.NoError(t, err)
conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(user.ID, tt.requestedPlatform)
require.NoError(t, err)
assert.Equal(t, tt.wantConflict, conflict)
assert.Equal(t, tt.wantPlatform, existingPlatform)
})
}
}