Add Stripe billing, free trials, and cross-platform subscription guards
- Stripe integration: add StripeService with checkout sessions, customer portal, and webhook handling for subscription lifecycle events. - Free trials: auto-start configurable trial on first subscription check, with admin-controllable duration and enable/disable toggle. - Cross-platform guard: prevent duplicate subscriptions across iOS, Android, and Stripe by checking existing platform before allowing purchase. - Subscription model: add Stripe fields (customer_id, subscription_id, price_id), trial fields (trial_start, trial_end, trial_used), and SubscriptionSource/IsTrialActive helpers. - API: add trial and source fields to status response, update OpenAPI spec. - Clean up stale migration and audit docs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -234,11 +234,13 @@ type UpdateNotificationRequest struct {
|
||||
// SubscriptionFilters holds subscription-specific filter parameters
|
||||
type SubscriptionFilters struct {
|
||||
PaginationParams
|
||||
UserID *uint `form:"user_id"`
|
||||
Tier *string `form:"tier"`
|
||||
Platform *string `form:"platform"`
|
||||
AutoRenew *bool `form:"auto_renew"`
|
||||
Active *bool `form:"active"`
|
||||
UserID *uint `form:"user_id"`
|
||||
Tier *string `form:"tier"`
|
||||
Platform *string `form:"platform"`
|
||||
AutoRenew *bool `form:"auto_renew"`
|
||||
Active *bool `form:"active"`
|
||||
HasStripe *bool `form:"has_stripe"`
|
||||
TrialActive *bool `form:"trial_active"`
|
||||
}
|
||||
|
||||
// UpdateSubscriptionRequest for updating a subscription
|
||||
@@ -250,6 +252,14 @@ type UpdateSubscriptionRequest struct {
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
CancelledAt *string `json:"cancelled_at"`
|
||||
// Stripe fields
|
||||
StripeCustomerID *string `json:"stripe_customer_id"`
|
||||
StripeSubscriptionID *string `json:"stripe_subscription_id"`
|
||||
StripePriceID *string `json:"stripe_price_id"`
|
||||
// Trial fields
|
||||
TrialStart *string `json:"trial_start"`
|
||||
TrialEnd *string `json:"trial_end"`
|
||||
TrialUsed *bool `json:"trial_used"`
|
||||
}
|
||||
|
||||
// CreateResidenceRequest for creating a new residence
|
||||
|
||||
@@ -264,7 +264,16 @@ type SubscriptionResponse struct {
|
||||
SubscribedAt *string `json:"subscribed_at,omitempty"`
|
||||
ExpiresAt *string `json:"expires_at,omitempty"`
|
||||
CancelledAt *string `json:"cancelled_at,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
// Stripe fields
|
||||
StripeCustomerID *string `json:"stripe_customer_id,omitempty"`
|
||||
StripeSubscriptionID *string `json:"stripe_subscription_id,omitempty"`
|
||||
StripePriceID *string `json:"stripe_price_id,omitempty"`
|
||||
// Trial fields
|
||||
TrialStart *string `json:"trial_start,omitempty"`
|
||||
TrialEnd *string `json:"trial_end,omitempty"`
|
||||
TrialUsed bool `json:"trial_used"`
|
||||
TrialActive bool `json:"trial_active"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// SubscriptionDetailResponse includes more details for single subscription view
|
||||
|
||||
@@ -30,6 +30,8 @@ func NewAdminSettingsHandler(db *gorm.DB) *AdminSettingsHandler {
|
||||
type SettingsResponse struct {
|
||||
EnableLimitations bool `json:"enable_limitations"`
|
||||
EnableMonitoring bool `json:"enable_monitoring"`
|
||||
TrialEnabled bool `json:"trial_enabled"`
|
||||
TrialDurationDays int `json:"trial_duration_days"`
|
||||
}
|
||||
|
||||
// GetSettings handles GET /api/admin/settings
|
||||
@@ -38,7 +40,13 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
|
||||
if err := h.db.First(&settings, 1).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
// Create default settings
|
||||
settings = models.SubscriptionSettings{ID: 1, EnableLimitations: false, EnableMonitoring: true}
|
||||
settings = models.SubscriptionSettings{
|
||||
ID: 1,
|
||||
EnableLimitations: false,
|
||||
EnableMonitoring: true,
|
||||
TrialEnabled: true,
|
||||
TrialDurationDays: 14,
|
||||
}
|
||||
h.db.Create(&settings)
|
||||
} else {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
|
||||
@@ -48,6 +56,8 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, SettingsResponse{
|
||||
EnableLimitations: settings.EnableLimitations,
|
||||
EnableMonitoring: settings.EnableMonitoring,
|
||||
TrialEnabled: settings.TrialEnabled,
|
||||
TrialDurationDays: settings.TrialDurationDays,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -55,6 +65,8 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
|
||||
type UpdateSettingsRequest struct {
|
||||
EnableLimitations *bool `json:"enable_limitations"`
|
||||
EnableMonitoring *bool `json:"enable_monitoring"`
|
||||
TrialEnabled *bool `json:"trial_enabled"`
|
||||
TrialDurationDays *int `json:"trial_duration_days"`
|
||||
}
|
||||
|
||||
// UpdateSettings handles PUT /api/admin/settings
|
||||
@@ -67,7 +79,12 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
|
||||
var settings models.SubscriptionSettings
|
||||
if err := h.db.First(&settings, 1).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
settings = models.SubscriptionSettings{ID: 1, EnableMonitoring: true}
|
||||
settings = models.SubscriptionSettings{
|
||||
ID: 1,
|
||||
EnableMonitoring: true,
|
||||
TrialEnabled: true,
|
||||
TrialDurationDays: 14,
|
||||
}
|
||||
} else {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
|
||||
}
|
||||
@@ -81,6 +98,14 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
|
||||
settings.EnableMonitoring = *req.EnableMonitoring
|
||||
}
|
||||
|
||||
if req.TrialEnabled != nil {
|
||||
settings.TrialEnabled = *req.TrialEnabled
|
||||
}
|
||||
|
||||
if req.TrialDurationDays != nil {
|
||||
settings.TrialDurationDays = *req.TrialDurationDays
|
||||
}
|
||||
|
||||
if err := h.db.Save(&settings).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update settings"})
|
||||
}
|
||||
@@ -88,6 +113,8 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, SettingsResponse{
|
||||
EnableLimitations: settings.EnableLimitations,
|
||||
EnableMonitoring: settings.EnableMonitoring,
|
||||
TrialEnabled: settings.TrialEnabled,
|
||||
TrialDurationDays: settings.TrialDurationDays,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package handlers
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"gorm.io/gorm"
|
||||
@@ -61,6 +62,20 @@ func (h *AdminSubscriptionHandler) List(c echo.Context) error {
|
||||
query = query.Where("expires_at IS NOT NULL AND expires_at <= NOW()")
|
||||
}
|
||||
}
|
||||
if filters.HasStripe != nil {
|
||||
if *filters.HasStripe {
|
||||
query = query.Where("stripe_subscription_id IS NOT NULL")
|
||||
} else {
|
||||
query = query.Where("stripe_subscription_id IS NULL")
|
||||
}
|
||||
}
|
||||
if filters.TrialActive != nil {
|
||||
if *filters.TrialActive {
|
||||
query = query.Where("trial_end IS NOT NULL AND trial_end > NOW()")
|
||||
} else {
|
||||
query = query.Where("trial_end IS NULL OR trial_end <= NOW()")
|
||||
}
|
||||
}
|
||||
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
@@ -137,6 +152,32 @@ func (h *AdminSubscriptionHandler) Update(c echo.Context) error {
|
||||
if req.IsFree != nil {
|
||||
subscription.IsFree = *req.IsFree
|
||||
}
|
||||
if req.StripeCustomerID != nil {
|
||||
subscription.StripeCustomerID = req.StripeCustomerID
|
||||
}
|
||||
if req.StripeSubscriptionID != nil {
|
||||
subscription.StripeSubscriptionID = req.StripeSubscriptionID
|
||||
}
|
||||
if req.StripePriceID != nil {
|
||||
subscription.StripePriceID = req.StripePriceID
|
||||
}
|
||||
if req.TrialStart != nil {
|
||||
parsed, err := time.Parse(time.RFC3339, *req.TrialStart)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid trial_start format, expected RFC3339"})
|
||||
}
|
||||
subscription.TrialStart = &parsed
|
||||
}
|
||||
if req.TrialEnd != nil {
|
||||
parsed, err := time.Parse(time.RFC3339, *req.TrialEnd)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid trial_end format, expected RFC3339"})
|
||||
}
|
||||
subscription.TrialEnd = &parsed
|
||||
}
|
||||
if req.TrialUsed != nil {
|
||||
subscription.TrialUsed = *req.TrialUsed
|
||||
}
|
||||
|
||||
if err := h.db.Save(&subscription).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update subscription"})
|
||||
@@ -184,29 +225,39 @@ func (h *AdminSubscriptionHandler) GetByUser(c echo.Context) error {
|
||||
// GetStats handles GET /api/admin/subscriptions/stats
|
||||
func (h *AdminSubscriptionHandler) GetStats(c echo.Context) error {
|
||||
var total, free, premium, pro int64
|
||||
var stripeSubscribers, activeTrials int64
|
||||
|
||||
h.db.Model(&models.UserSubscription{}).Count(&total)
|
||||
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "free").Count(&free)
|
||||
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "premium").Count(&premium)
|
||||
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "pro").Count(&pro)
|
||||
h.db.Model(&models.UserSubscription{}).Where("stripe_subscription_id IS NOT NULL AND tier = ?", "pro").Count(&stripeSubscribers)
|
||||
h.db.Model(&models.UserSubscription{}).Where("trial_end IS NOT NULL AND trial_end > NOW()").Count(&activeTrials)
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"total": total,
|
||||
"free": free,
|
||||
"premium": premium,
|
||||
"pro": pro,
|
||||
"total": total,
|
||||
"free": free,
|
||||
"premium": premium,
|
||||
"pro": pro,
|
||||
"stripe_subscribers": stripeSubscribers,
|
||||
"active_trials": activeTrials,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AdminSubscriptionHandler) toSubscriptionResponse(sub *models.UserSubscription) dto.SubscriptionResponse {
|
||||
response := dto.SubscriptionResponse{
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
Tier: string(sub.Tier),
|
||||
Platform: sub.Platform,
|
||||
AutoRenew: sub.AutoRenew,
|
||||
IsFree: sub.IsFree,
|
||||
CreatedAt: sub.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
Tier: string(sub.Tier),
|
||||
Platform: sub.Platform,
|
||||
AutoRenew: sub.AutoRenew,
|
||||
IsFree: sub.IsFree,
|
||||
StripeCustomerID: sub.StripeCustomerID,
|
||||
StripeSubscriptionID: sub.StripeSubscriptionID,
|
||||
StripePriceID: sub.StripePriceID,
|
||||
TrialUsed: sub.TrialUsed,
|
||||
TrialActive: sub.IsTrialActive(),
|
||||
CreatedAt: sub.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
|
||||
if sub.User.ID != 0 {
|
||||
@@ -225,6 +276,14 @@ func (h *AdminSubscriptionHandler) toSubscriptionResponse(sub *models.UserSubscr
|
||||
cancelledAt := sub.CancelledAt.Format("2006-01-02T15:04:05Z")
|
||||
response.CancelledAt = &cancelledAt
|
||||
}
|
||||
if sub.TrialStart != nil {
|
||||
trialStart := sub.TrialStart.Format(time.RFC3339)
|
||||
response.TrialStart = &trialStart
|
||||
}
|
||||
if sub.TrialEnd != nil {
|
||||
trialEnd := sub.TrialEnd.Format(time.RFC3339)
|
||||
response.TrialEnd = &trialEnd
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ type Config struct {
|
||||
GoogleAuth GoogleAuthConfig
|
||||
AppleIAP AppleIAPConfig
|
||||
GoogleIAP GoogleIAPConfig
|
||||
Stripe StripeConfig
|
||||
Features FeatureFlags
|
||||
}
|
||||
|
||||
@@ -104,6 +105,14 @@ type GoogleIAPConfig struct {
|
||||
PackageName string // Android package name (e.g., com.tt.casera)
|
||||
}
|
||||
|
||||
// StripeConfig holds Stripe payment configuration
|
||||
type StripeConfig struct {
|
||||
SecretKey string // Stripe secret API key
|
||||
WebhookSecret string // Stripe webhook endpoint signing secret
|
||||
PriceMonthly string // Stripe Price ID for monthly Pro subscription
|
||||
PriceYearly string // Stripe Price ID for yearly Pro subscription
|
||||
}
|
||||
|
||||
type WorkerConfig struct {
|
||||
// Scheduled job times (UTC)
|
||||
TaskReminderHour int
|
||||
@@ -248,6 +257,12 @@ func Load() (*Config, error) {
|
||||
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
|
||||
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
|
||||
},
|
||||
Stripe: StripeConfig{
|
||||
SecretKey: viper.GetString("STRIPE_SECRET_KEY"),
|
||||
WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"),
|
||||
PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"),
|
||||
PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"),
|
||||
},
|
||||
Features: FeatureFlags{
|
||||
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
|
||||
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),
|
||||
|
||||
@@ -240,7 +240,7 @@ func TestSubscriptionHandler_NoAuth_Returns401(t *testing.T) {
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
|
||||
handler := NewSubscriptionHandler(subscriptionService)
|
||||
handler := NewSubscriptionHandler(subscriptionService, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
|
||||
@@ -13,11 +13,15 @@ import (
|
||||
// SubscriptionHandler handles subscription-related HTTP requests
|
||||
type SubscriptionHandler struct {
|
||||
subscriptionService *services.SubscriptionService
|
||||
stripeService *services.StripeService
|
||||
}
|
||||
|
||||
// NewSubscriptionHandler creates a new subscription handler
|
||||
func NewSubscriptionHandler(subscriptionService *services.SubscriptionService) *SubscriptionHandler {
|
||||
return &SubscriptionHandler{subscriptionService: subscriptionService}
|
||||
func NewSubscriptionHandler(subscriptionService *services.SubscriptionService, stripeService *services.StripeService) *SubscriptionHandler {
|
||||
return &SubscriptionHandler{
|
||||
subscriptionService: subscriptionService,
|
||||
stripeService: stripeService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSubscription handles GET /api/subscription/
|
||||
@@ -194,3 +198,82 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
|
||||
"subscription": subscription,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateCheckoutSession handles POST /api/subscription/checkout/
|
||||
// Creates a Stripe Checkout Session for web subscription purchases
|
||||
func (h *SubscriptionHandler) CreateCheckoutSession(c echo.Context) error {
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.stripeService == nil {
|
||||
return apperrors.BadRequest("error.stripe_not_configured")
|
||||
}
|
||||
|
||||
// Check if already Pro from another platform
|
||||
alreadyPro, existingPlatform, err := h.subscriptionService.IsAlreadyProFromOtherPlatform(user.ID, "stripe")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if alreadyPro {
|
||||
return c.JSON(http.StatusConflict, map[string]interface{}{
|
||||
"error": "error.already_subscribed_other_platform",
|
||||
"existing_platform": existingPlatform,
|
||||
"message": "You already have an active Pro subscription via " + existingPlatform + ". Manage it there to avoid double billing.",
|
||||
})
|
||||
}
|
||||
|
||||
var req struct {
|
||||
PriceID string `json:"price_id" validate:"required"`
|
||||
SuccessURL string `json:"success_url" validate:"required,url"`
|
||||
CancelURL string `json:"cancel_url" validate:"required,url"`
|
||||
}
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessionURL, err := h.stripeService.CreateCheckoutSession(user.ID, req.PriceID, req.SuccessURL, req.CancelURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"checkout_url": sessionURL,
|
||||
})
|
||||
}
|
||||
|
||||
// CreatePortalSession handles POST /api/subscription/portal/
|
||||
// Creates a Stripe Customer Portal session for managing web subscriptions
|
||||
func (h *SubscriptionHandler) CreatePortalSession(c echo.Context) error {
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.stripeService == nil {
|
||||
return apperrors.BadRequest("error.stripe_not_configured")
|
||||
}
|
||||
|
||||
var req struct {
|
||||
ReturnURL string `json:"return_url" validate:"required,url"`
|
||||
}
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
portalURL, err := h.stripeService.CreatePortalSession(user.ID, req.ReturnURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"portal_url": portalURL,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
// SubscriptionWebhookHandler handles subscription webhook callbacks
|
||||
@@ -28,6 +29,7 @@ type SubscriptionWebhookHandler struct {
|
||||
userRepo *repositories.UserRepository
|
||||
webhookEventRepo *repositories.WebhookEventRepository
|
||||
appleRootCerts []*x509.Certificate
|
||||
stripeService *services.StripeService
|
||||
enabled bool
|
||||
}
|
||||
|
||||
@@ -46,6 +48,11 @@ func NewSubscriptionWebhookHandler(
|
||||
}
|
||||
}
|
||||
|
||||
// SetStripeService sets the Stripe service for webhook handling
|
||||
func (h *SubscriptionWebhookHandler) SetStripeService(stripeService *services.StripeService) {
|
||||
h.stripeService = stripeService
|
||||
}
|
||||
|
||||
// ====================
|
||||
// Apple App Store Server Notifications v2
|
||||
// ====================
|
||||
@@ -377,38 +384,30 @@ func (h *SubscriptionWebhookHandler) handleAppleFailedToRenew(userID uint, tx *A
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleAppleExpired(userID uint, tx *AppleTransactionInfo) error {
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Apple expired"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription expired, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleAppleRefund(userID uint, tx *AppleTransactionInfo) error {
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Apple refund"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User got refund, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleAppleRevoke(userID uint, tx *AppleTransactionInfo) error {
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Apple revoke"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription revoked, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleAppleGracePeriodExpired(userID uint, tx *AppleTransactionInfo) error {
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Apple grace period expired"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User grace period expired, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -705,22 +704,16 @@ func (h *SubscriptionWebhookHandler) handleGoogleRestarted(userID uint, notifica
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleGoogleRevoked(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// Subscription revoked - immediate downgrade
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Google revoke"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription revoked")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleGoogleExpired(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// Subscription expired
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
if err := h.safeDowngradeToFree(userID, "Google expired"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription expired")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -730,6 +723,88 @@ func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notificatio
|
||||
return nil
|
||||
}
|
||||
|
||||
// ====================
|
||||
// Multi-Source Downgrade Safety
|
||||
// ====================
|
||||
|
||||
// safeDowngradeToFree checks if the user has active subscriptions from other sources
|
||||
// before downgrading to free. If another source is still active, skip the downgrade.
|
||||
func (h *SubscriptionWebhookHandler) safeDowngradeToFree(userID uint, reason string) error {
|
||||
sub, err := h.subscriptionRepo.FindByUserID(userID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Could not find subscription for multi-source check, proceeding with downgrade")
|
||||
return h.subscriptionRepo.DowngradeToFree(userID)
|
||||
}
|
||||
|
||||
// Check if Stripe subscription is still active
|
||||
if sub.HasStripeSubscription() && sub.Platform != models.PlatformStripe {
|
||||
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Stripe subscription")
|
||||
return nil
|
||||
}
|
||||
// Check if Apple subscription is still active (for Google/Stripe webhooks)
|
||||
if sub.HasAppleSubscription() && sub.Platform != models.PlatformIOS {
|
||||
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Apple subscription")
|
||||
return nil
|
||||
}
|
||||
// Check if Google subscription is still active (for Apple/Stripe webhooks)
|
||||
if sub.HasGoogleSubscription() && sub.Platform != models.PlatformAndroid {
|
||||
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active Google subscription")
|
||||
return nil
|
||||
}
|
||||
// Check if trial is still active
|
||||
if sub.IsTrialActive() {
|
||||
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Skipping downgrade — user has active trial")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: User downgraded to free (no other active sources)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ====================
|
||||
// Stripe Webhooks
|
||||
// ====================
|
||||
|
||||
// HandleStripeWebhook handles POST /api/subscription/webhook/stripe/
|
||||
func (h *SubscriptionWebhookHandler) HandleStripeWebhook(c echo.Context) error {
|
||||
if !h.enabled {
|
||||
log.Info().Msg("Stripe Webhook: webhooks disabled by feature flag")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
|
||||
}
|
||||
|
||||
if h.stripeService == nil {
|
||||
log.Warn().Msg("Stripe Webhook: Stripe service not configured")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "not_configured"})
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Stripe Webhook: Failed to read body")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
||||
}
|
||||
|
||||
signature := c.Request().Header.Get("Stripe-Signature")
|
||||
if signature == "" {
|
||||
log.Warn().Msg("Stripe Webhook: Missing Stripe-Signature header")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "missing signature"})
|
||||
}
|
||||
|
||||
if err := h.stripeService.HandleWebhookEvent(body, signature); err != nil {
|
||||
log.Error().Err(err).Msg("Stripe Webhook: Failed to process webhook")
|
||||
// Still return 200 to prevent Stripe from retrying on business logic errors
|
||||
// Only return error for signature verification failures
|
||||
if strings.Contains(err.Error(), "signature") {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid signature"})
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "received"})
|
||||
}
|
||||
|
||||
// ====================
|
||||
// Signature Verification (Optional but Recommended)
|
||||
// ====================
|
||||
|
||||
@@ -98,6 +98,10 @@ var specEndpointsKMPSkips = map[routeKey]bool{
|
||||
{Method: "POST", Path: "/notifications/devices/"}: true, // KMP uses /notifications/devices/register/
|
||||
{Method: "POST", Path: "/notifications/devices/unregister/"}: true, // KMP uses DELETE on device ID
|
||||
{Method: "PATCH", Path: "/notifications/preferences/"}: true, // KMP uses PUT
|
||||
// Stripe web-only and server-to-server endpoints — not implemented in mobile KMP
|
||||
{Method: "POST", Path: "/subscription/checkout/"}: true, // Web-only (Stripe Checkout)
|
||||
{Method: "POST", Path: "/subscription/portal/"}: true, // Web-only (Stripe Customer Portal)
|
||||
{Method: "POST", Path: "/subscription/webhook/stripe/"}: true, // Server-to-server (Stripe webhook)
|
||||
}
|
||||
|
||||
// kmpRouteAliases maps KMP paths to their canonical spec paths.
|
||||
|
||||
@@ -76,7 +76,7 @@ func setupSecurityTest(t *testing.T) *SecurityTestApp {
|
||||
taskHandler := handlers.NewTaskHandler(taskService, nil)
|
||||
contractorHandler := handlers.NewContractorHandler(services.NewContractorService(contractorRepo, residenceRepo))
|
||||
notificationHandler := handlers.NewNotificationHandler(notificationService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
|
||||
|
||||
// Create router with real middleware
|
||||
e := echo.New()
|
||||
|
||||
@@ -64,7 +64,7 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
|
||||
// Create handlers
|
||||
authHandler := handlers.NewAuthHandler(authService, nil, nil)
|
||||
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
|
||||
|
||||
// Create router
|
||||
e := echo.New()
|
||||
|
||||
@@ -12,11 +12,20 @@ const (
|
||||
TierPro SubscriptionTier = "pro"
|
||||
)
|
||||
|
||||
// SubscriptionPlatform constants
|
||||
const (
|
||||
PlatformIOS = "ios"
|
||||
PlatformAndroid = "android"
|
||||
PlatformStripe = "stripe"
|
||||
)
|
||||
|
||||
// SubscriptionSettings represents the subscription_subscriptionsettings table (singleton)
|
||||
type SubscriptionSettings struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
EnableLimitations bool `gorm:"column:enable_limitations;default:false" json:"enable_limitations"`
|
||||
EnableMonitoring bool `gorm:"column:enable_monitoring;default:true" json:"enable_monitoring"`
|
||||
TrialEnabled bool `gorm:"column:trial_enabled;default:true" json:"trial_enabled"`
|
||||
TrialDurationDays int `gorm:"column:trial_duration_days;default:14" json:"trial_duration_days"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
@@ -31,18 +40,28 @@ type UserSubscription struct {
|
||||
User User `gorm:"foreignKey:UserID" json:"-"`
|
||||
Tier SubscriptionTier `gorm:"column:tier;size:10;default:'free'" json:"tier"`
|
||||
|
||||
// In-App Purchase data
|
||||
AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"`
|
||||
// In-App Purchase data (Apple / Google)
|
||||
AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"`
|
||||
GooglePurchaseToken *string `gorm:"column:google_purchase_token;type:text" json:"-"`
|
||||
|
||||
// Stripe data (web subscriptions)
|
||||
StripeCustomerID *string `gorm:"column:stripe_customer_id;size:255" json:"-"`
|
||||
StripeSubscriptionID *string `gorm:"column:stripe_subscription_id;size:255" json:"-"`
|
||||
StripePriceID *string `gorm:"column:stripe_price_id;size:255" json:"-"`
|
||||
|
||||
// Subscription dates
|
||||
SubscribedAt *time.Time `gorm:"column:subscribed_at" json:"subscribed_at"`
|
||||
ExpiresAt *time.Time `gorm:"column:expires_at" json:"expires_at"`
|
||||
AutoRenew bool `gorm:"column:auto_renew;default:true" json:"auto_renew"`
|
||||
|
||||
// Trial
|
||||
TrialStart *time.Time `gorm:"column:trial_start" json:"trial_start"`
|
||||
TrialEnd *time.Time `gorm:"column:trial_end" json:"trial_end"`
|
||||
TrialUsed bool `gorm:"column:trial_used;default:false" json:"trial_used"`
|
||||
|
||||
// Tracking
|
||||
CancelledAt *time.Time `gorm:"column:cancelled_at" json:"cancelled_at"`
|
||||
Platform string `gorm:"column:platform;size:10" json:"platform"` // ios, android
|
||||
Platform string `gorm:"column:platform;size:10" json:"platform"` // ios, android, stripe
|
||||
|
||||
// Admin override - bypasses all limitations regardless of global settings
|
||||
IsFree bool `gorm:"column:is_free;default:false" json:"is_free"`
|
||||
@@ -53,8 +72,11 @@ func (UserSubscription) TableName() string {
|
||||
return "subscription_usersubscription"
|
||||
}
|
||||
|
||||
// IsActive returns true if the subscription is active (pro tier and not expired)
|
||||
// IsActive returns true if the subscription is active (pro tier and not expired, or trial active)
|
||||
func (s *UserSubscription) IsActive() bool {
|
||||
if s.IsTrialActive() {
|
||||
return true
|
||||
}
|
||||
if s.Tier != TierPro {
|
||||
return false
|
||||
}
|
||||
@@ -64,9 +86,37 @@ func (s *UserSubscription) IsActive() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsPro returns true if the user has a pro subscription
|
||||
// IsPro returns true if the user has a pro subscription or active trial
|
||||
func (s *UserSubscription) IsPro() bool {
|
||||
return s.Tier == TierPro && s.IsActive()
|
||||
return s.IsActive()
|
||||
}
|
||||
|
||||
// IsTrialActive returns true if the user has an active, unexpired trial
|
||||
func (s *UserSubscription) IsTrialActive() bool {
|
||||
if s.TrialEnd == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().UTC().Before(*s.TrialEnd)
|
||||
}
|
||||
|
||||
// HasStripeSubscription returns true if the user has Stripe subscription data
|
||||
func (s *UserSubscription) HasStripeSubscription() bool {
|
||||
return s.StripeSubscriptionID != nil && *s.StripeSubscriptionID != ""
|
||||
}
|
||||
|
||||
// HasAppleSubscription returns true if the user has Apple receipt data
|
||||
func (s *UserSubscription) HasAppleSubscription() bool {
|
||||
return s.AppleReceiptData != nil && *s.AppleReceiptData != ""
|
||||
}
|
||||
|
||||
// HasGoogleSubscription returns true if the user has Google purchase token
|
||||
func (s *UserSubscription) HasGoogleSubscription() bool {
|
||||
return s.GooglePurchaseToken != nil && *s.GooglePurchaseToken != ""
|
||||
}
|
||||
|
||||
// SubscriptionSource returns the platform that the active subscription came from
|
||||
func (s *UserSubscription) SubscriptionSource() string {
|
||||
return s.Platform
|
||||
}
|
||||
|
||||
// UpgradeTrigger represents the subscription_upgradetrigger table
|
||||
|
||||
187
internal/models/subscription_test.go
Normal file
187
internal/models/subscription_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsTrialActive(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
future := now.Add(24 * time.Hour)
|
||||
past := now.Add(-24 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sub *UserSubscription
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "trial_end in future returns true",
|
||||
sub: &UserSubscription{TrialEnd: &future},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "trial_end in past returns false",
|
||||
sub: &UserSubscription{TrialEnd: &past},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "trial_end nil returns false",
|
||||
sub: &UserSubscription{TrialEnd: nil},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.sub.IsTrialActive()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPro(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
future := now.Add(24 * time.Hour)
|
||||
past := now.Add(-24 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sub *UserSubscription
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "tier=pro, expires_at in future returns true",
|
||||
sub: &UserSubscription{
|
||||
Tier: TierPro,
|
||||
ExpiresAt: &future,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "tier=pro, expires_at in past returns false",
|
||||
sub: &UserSubscription{
|
||||
Tier: TierPro,
|
||||
ExpiresAt: &past,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "tier=free, trial active returns true",
|
||||
sub: &UserSubscription{
|
||||
Tier: TierFree,
|
||||
TrialEnd: &future,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "tier=free, trial expired returns false",
|
||||
sub: &UserSubscription{
|
||||
Tier: TierFree,
|
||||
TrialEnd: &past,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "tier=free, no trial returns false",
|
||||
sub: &UserSubscription{
|
||||
Tier: TierFree,
|
||||
TrialEnd: nil,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.sub.IsPro()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSubscriptionHelpers(t *testing.T) {
|
||||
empty := ""
|
||||
validStripeID := "sub_1234567890"
|
||||
validReceipt := "MIIT..."
|
||||
validToken := "google-purchase-token-123"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sub *UserSubscription
|
||||
method string
|
||||
expected bool
|
||||
}{
|
||||
// HasStripeSubscription
|
||||
{
|
||||
name: "HasStripeSubscription with nil returns false",
|
||||
sub: &UserSubscription{StripeSubscriptionID: nil},
|
||||
method: "stripe",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasStripeSubscription with empty string returns false",
|
||||
sub: &UserSubscription{StripeSubscriptionID: &empty},
|
||||
method: "stripe",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasStripeSubscription with valid ID returns true",
|
||||
sub: &UserSubscription{StripeSubscriptionID: &validStripeID},
|
||||
method: "stripe",
|
||||
expected: true,
|
||||
},
|
||||
// HasAppleSubscription
|
||||
{
|
||||
name: "HasAppleSubscription with nil returns false",
|
||||
sub: &UserSubscription{AppleReceiptData: nil},
|
||||
method: "apple",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasAppleSubscription with empty string returns false",
|
||||
sub: &UserSubscription{AppleReceiptData: &empty},
|
||||
method: "apple",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasAppleSubscription with valid receipt returns true",
|
||||
sub: &UserSubscription{AppleReceiptData: &validReceipt},
|
||||
method: "apple",
|
||||
expected: true,
|
||||
},
|
||||
// HasGoogleSubscription
|
||||
{
|
||||
name: "HasGoogleSubscription with nil returns false",
|
||||
sub: &UserSubscription{GooglePurchaseToken: nil},
|
||||
method: "google",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasGoogleSubscription with empty string returns false",
|
||||
sub: &UserSubscription{GooglePurchaseToken: &empty},
|
||||
method: "google",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HasGoogleSubscription with valid token returns true",
|
||||
sub: &UserSubscription{GooglePurchaseToken: &validToken},
|
||||
method: "google",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var result bool
|
||||
switch tt.method {
|
||||
case "stripe":
|
||||
result = tt.sub.HasStripeSubscription()
|
||||
case "apple":
|
||||
result = tt.sub.HasAppleSubscription()
|
||||
case "google":
|
||||
result = tt.sub.HasGoogleSubscription()
|
||||
}
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -262,3 +262,59 @@ func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.P
|
||||
}
|
||||
return &promotion, nil
|
||||
}
|
||||
|
||||
// === Stripe Lookups ===
|
||||
|
||||
// FindByStripeCustomerID finds a subscription by Stripe customer ID
|
||||
func (r *SubscriptionRepository) FindByStripeCustomerID(customerID string) (*models.UserSubscription, error) {
|
||||
var sub models.UserSubscription
|
||||
err := r.db.Where("stripe_customer_id = ?", customerID).First(&sub).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// FindByStripeSubscriptionID finds a subscription by Stripe subscription ID
|
||||
func (r *SubscriptionRepository) FindByStripeSubscriptionID(subscriptionID string) (*models.UserSubscription, error) {
|
||||
var sub models.UserSubscription
|
||||
err := r.db.Where("stripe_subscription_id = ?", subscriptionID).First(&sub).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// UpdateStripeData updates all three Stripe fields (customer, subscription, price) in one call
|
||||
func (r *SubscriptionRepository) UpdateStripeData(userID uint, customerID, subscriptionID, priceID string) error {
|
||||
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
||||
"stripe_customer_id": customerID,
|
||||
"stripe_subscription_id": subscriptionID,
|
||||
"stripe_price_id": priceID,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ClearStripeData clears the Stripe subscription and price IDs (customer ID stays for portal access)
|
||||
func (r *SubscriptionRepository) ClearStripeData(userID uint) error {
|
||||
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
||||
"stripe_subscription_id": nil,
|
||||
"stripe_price_id": nil,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// === Trial Management ===
|
||||
|
||||
// SetTrialDates sets the trial start, end, and marks trial as used
|
||||
func (r *SubscriptionRepository) SetTrialDates(userID uint, trialStart, trialEnd time.Time) error {
|
||||
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
||||
"trial_start": trialStart,
|
||||
"trial_end": trialEnd,
|
||||
"trial_used": true,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateExpiresAt updates the expires_at field for a user's subscription
|
||||
func (r *SubscriptionRepository) UpdateExpiresAt(userID uint, expiresAt time.Time) error {
|
||||
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).
|
||||
Update("expires_at", expiresAt).Error
|
||||
}
|
||||
|
||||
@@ -2,9 +2,11 @@ package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
@@ -77,3 +79,150 @@ func TestGetOrCreate_Idempotent(t *testing.T) {
|
||||
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should have exactly one subscription record after two calls")
|
||||
}
|
||||
|
||||
func TestFindByStripeCustomerID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
customerID string
|
||||
seedID string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "finds existing subscription by stripe customer ID",
|
||||
customerID: "cus_test123",
|
||||
seedID: "cus_test123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "returns error for unknown stripe customer ID",
|
||||
customerID: "cus_unknown999",
|
||||
seedID: "cus_test456",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierFree,
|
||||
StripeCustomerID: &tt.seedID,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByStripeCustomerID(tt.customerID)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
assert.Equal(t, tt.seedID, *found.StripeCustomerID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateStripeData(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierFree,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update all three Stripe fields
|
||||
customerID := "cus_abc123"
|
||||
subscriptionID := "sub_xyz789"
|
||||
priceID := "price_monthly"
|
||||
|
||||
err = repo.UpdateStripeData(user.ID, customerID, subscriptionID, priceID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all three fields are set
|
||||
var updated models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&updated).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated.StripeCustomerID)
|
||||
require.NotNil(t, updated.StripeSubscriptionID)
|
||||
require.NotNil(t, updated.StripePriceID)
|
||||
assert.Equal(t, customerID, *updated.StripeCustomerID)
|
||||
assert.Equal(t, subscriptionID, *updated.StripeSubscriptionID)
|
||||
assert.Equal(t, priceID, *updated.StripePriceID)
|
||||
|
||||
// Now call ClearStripeData
|
||||
err = repo.ClearStripeData(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify subscription_id and price_id are cleared, customer_id preserved
|
||||
var cleared models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&cleared).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cleared.StripeCustomerID, "customer_id should be preserved after ClearStripeData")
|
||||
assert.Equal(t, customerID, *cleared.StripeCustomerID)
|
||||
assert.Nil(t, cleared.StripeSubscriptionID, "subscription_id should be nil after ClearStripeData")
|
||||
assert.Nil(t, cleared.StripePriceID, "price_id should be nil after ClearStripeData")
|
||||
}
|
||||
|
||||
func TestSetTrialDates(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierFree,
|
||||
TrialUsed: false,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
trialStart := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
|
||||
trialEnd := time.Date(2026, 3, 15, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
err = repo.SetTrialDates(user.ID, trialStart, trialEnd)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&updated).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, updated.TrialStart)
|
||||
require.NotNil(t, updated.TrialEnd)
|
||||
assert.True(t, updated.TrialUsed, "trial_used should be set to true")
|
||||
assert.WithinDuration(t, trialStart, *updated.TrialStart, time.Second, "trial_start should match")
|
||||
assert.WithinDuration(t, trialEnd, *updated.TrialEnd, time.Second, "trial_end should match")
|
||||
}
|
||||
|
||||
func TestUpdateExpiresAt(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierPro,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
newExpiry := time.Date(2027, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
err = repo.UpdateExpiresAt(user.ID, newExpiry)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&updated).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, updated.ExpiresAt)
|
||||
assert.WithinDuration(t, newExpiry, *updated.ExpiresAt, time.Second, "expires_at should be updated")
|
||||
}
|
||||
|
||||
@@ -58,7 +58,13 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
e.Use(custommiddleware.RequestIDMiddleware())
|
||||
e.Use(utils.EchoRecovery())
|
||||
e.Use(custommiddleware.StructuredLogger())
|
||||
e.Use(middleware.BodyLimit("1M")) // 1MB default for JSON payloads
|
||||
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
|
||||
Limit: "1M", // 1MB default for JSON payloads
|
||||
Skipper: func(c echo.Context) bool {
|
||||
// Allow larger payloads for webhook endpoints (Apple/Google/Stripe notifications)
|
||||
return strings.HasPrefix(c.Request().URL.Path, "/api/subscription/webhook")
|
||||
},
|
||||
}))
|
||||
e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{
|
||||
Timeout: 30 * time.Second,
|
||||
Skipper: func(c echo.Context) bool {
|
||||
@@ -143,11 +149,15 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
residenceService.SetSubscriptionService(subscriptionService) // Wire up subscription service for tier limit enforcement
|
||||
taskTemplateService := services.NewTaskTemplateService(taskTemplateRepo)
|
||||
|
||||
// Initialize Stripe service
|
||||
stripeService := services.NewStripeService(subscriptionRepo, userRepo)
|
||||
|
||||
// Initialize webhook event repo for deduplication
|
||||
webhookEventRepo := repositories.NewWebhookEventRepository(deps.DB)
|
||||
|
||||
// Initialize webhook handler for Apple/Google subscription notifications
|
||||
// Initialize webhook handler for Apple/Google/Stripe subscription notifications
|
||||
subscriptionWebhookHandler := handlers.NewSubscriptionWebhookHandler(subscriptionRepo, userRepo, webhookEventRepo, cfg.Features.WebhooksEnabled)
|
||||
subscriptionWebhookHandler.SetStripeService(stripeService)
|
||||
|
||||
// Initialize middleware
|
||||
authMiddleware := custommiddleware.NewAuthMiddleware(deps.DB, deps.Cache)
|
||||
@@ -166,7 +176,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
contractorHandler := handlers.NewContractorHandler(contractorService)
|
||||
documentHandler := handlers.NewDocumentHandler(documentService, deps.StorageService)
|
||||
notificationHandler := handlers.NewNotificationHandler(notificationService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, stripeService)
|
||||
staticDataHandler := handlers.NewStaticDataHandler(residenceService, taskService, contractorService, taskTemplateService, deps.Cache)
|
||||
taskTemplateHandler := handlers.NewTaskTemplateHandler(taskTemplateService)
|
||||
|
||||
@@ -458,6 +468,8 @@ func setupSubscriptionRoutes(api *echo.Group, subscriptionHandler *handlers.Subs
|
||||
subscription.POST("/purchase/", subscriptionHandler.ProcessPurchase)
|
||||
subscription.POST("/cancel/", subscriptionHandler.CancelSubscription)
|
||||
subscription.POST("/restore/", subscriptionHandler.RestoreSubscription)
|
||||
subscription.POST("/checkout/", subscriptionHandler.CreateCheckoutSession)
|
||||
subscription.POST("/portal/", subscriptionHandler.CreatePortalSession)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -499,6 +511,7 @@ func setupWebhookRoutes(api *echo.Group, webhookHandler *handlers.SubscriptionWe
|
||||
{
|
||||
webhooks.POST("/apple/", webhookHandler.HandleAppleWebhook)
|
||||
webhooks.POST("/google/", webhookHandler.HandleGoogleWebhook)
|
||||
webhooks.POST("/stripe/", webhookHandler.HandleStripeWebhook)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
456
internal/services/stripe_service.go
Normal file
456
internal/services/stripe_service.go
Normal file
@@ -0,0 +1,456 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
portalsession "github.com/stripe/stripe-go/v81/billingportal/session"
|
||||
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
|
||||
"github.com/stripe/stripe-go/v81/customer"
|
||||
"github.com/stripe/stripe-go/v81/webhook"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
)
|
||||
|
||||
// StripeService handles Stripe checkout, portal, and webhook processing
|
||||
// for web-based subscription purchases.
|
||||
type StripeService struct {
|
||||
subscriptionRepo *repositories.SubscriptionRepository
|
||||
userRepo *repositories.UserRepository
|
||||
webhookSecret string
|
||||
}
|
||||
|
||||
// NewStripeService creates a new Stripe service. It initializes the global
|
||||
// Stripe API key from the STRIPE_SECRET_KEY environment variable. If the key
|
||||
// is not set, a warning is logged but the service is still returned (matching
|
||||
// the pattern used by the Apple/Google IAP clients).
|
||||
func NewStripeService(
|
||||
subscriptionRepo *repositories.SubscriptionRepository,
|
||||
userRepo *repositories.UserRepository,
|
||||
) *StripeService {
|
||||
key := os.Getenv("STRIPE_SECRET_KEY")
|
||||
if key == "" {
|
||||
log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work")
|
||||
} else {
|
||||
stripe.Key = key
|
||||
log.Info().Msg("Stripe API key configured")
|
||||
}
|
||||
|
||||
webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")
|
||||
if webhookSecret == "" {
|
||||
log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail")
|
||||
}
|
||||
|
||||
return &StripeService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
userRepo: userRepo,
|
||||
webhookSecret: webhookSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCheckoutSession creates a Stripe Checkout Session for a web subscription purchase.
|
||||
// It ensures the user has a Stripe customer record and configures the session with a trial
|
||||
// period if the user has not used their trial yet.
|
||||
func (s *StripeService) CreateCheckoutSession(userID uint, priceID string, successURL string, cancelURL string) (string, error) {
|
||||
// Get or create the user's subscription record
|
||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
||||
if err != nil {
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Get the user's email for the Stripe customer
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
if err != nil {
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Get or create a Stripe customer
|
||||
stripeCustomerID, err := s.getOrCreateStripeCustomer(sub, user)
|
||||
if err != nil {
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Build the checkout session parameters
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
Customer: stripe.String(stripeCustomerID),
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||
SuccessURL: stripe.String(successURL),
|
||||
CancelURL: stripe.String(cancelURL),
|
||||
ClientReferenceID: stripe.String(fmt.Sprintf("%d", userID)),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(priceID),
|
||||
Quantity: stripe.Int64(1),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Offer a trial period if the user has not used their trial yet
|
||||
if !sub.TrialUsed {
|
||||
trialDays, err := s.getTrialDays()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to get trial duration from settings, skipping trial")
|
||||
} else if trialDays > 0 {
|
||||
params.SubscriptionData = &stripe.CheckoutSessionSubscriptionDataParams{
|
||||
TrialPeriodDays: stripe.Int64(int64(trialDays)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session, err := checkoutsession.New(params)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to create Stripe checkout session")
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Uint("user_id", userID).
|
||||
Str("session_id", session.ID).
|
||||
Str("price_id", priceID).
|
||||
Msg("Stripe checkout session created")
|
||||
|
||||
return session.URL, nil
|
||||
}
|
||||
|
||||
// CreatePortalSession creates a Stripe Customer Portal session so the user
|
||||
// can manage their subscription (cancel, change plan, update payment method).
|
||||
func (s *StripeService) CreatePortalSession(userID uint, returnURL string) (string, error) {
|
||||
sub, err := s.subscriptionRepo.FindByUserID(userID)
|
||||
if err != nil {
|
||||
return "", apperrors.NotFound("error.subscription_not_found")
|
||||
}
|
||||
|
||||
if sub.StripeCustomerID == nil || *sub.StripeCustomerID == "" {
|
||||
return "", apperrors.BadRequest("error.no_stripe_customer")
|
||||
}
|
||||
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: sub.StripeCustomerID,
|
||||
ReturnURL: stripe.String(returnURL),
|
||||
}
|
||||
|
||||
session, err := portalsession.New(params)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to create Stripe portal session")
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return session.URL, nil
|
||||
}
|
||||
|
||||
// HandleWebhookEvent verifies and processes a Stripe webhook event.
|
||||
// It handles checkout completion, subscription lifecycle changes, and invoice events.
|
||||
func (s *StripeService) HandleWebhookEvent(payload []byte, signature string) error {
|
||||
event, err := webhook.ConstructEvent(payload, signature, s.webhookSecret)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Stripe webhook signature verification failed")
|
||||
return apperrors.BadRequest("error.invalid_webhook_signature")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("event_type", string(event.Type)).
|
||||
Str("event_id", event.ID).
|
||||
Msg("Processing Stripe webhook event")
|
||||
|
||||
switch event.Type {
|
||||
case "checkout.session.completed":
|
||||
return s.handleCheckoutCompleted(event)
|
||||
case "customer.subscription.updated":
|
||||
return s.handleSubscriptionUpdated(event)
|
||||
case "customer.subscription.deleted":
|
||||
return s.handleSubscriptionDeleted(event)
|
||||
case "invoice.paid":
|
||||
return s.handleInvoicePaid(event)
|
||||
case "invoice.payment_failed":
|
||||
return s.handleInvoicePaymentFailed(event)
|
||||
default:
|
||||
log.Debug().Str("event_type", string(event.Type)).Msg("Unhandled Stripe webhook event type")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleCheckoutCompleted processes a successful checkout session. It links the Stripe
|
||||
// customer and subscription to the user's record and upgrades them to Pro.
|
||||
func (s *StripeService) handleCheckoutCompleted(event stripe.Event) error {
|
||||
var session stripe.CheckoutSession
|
||||
if err := json.Unmarshal(event.Data.Raw, &session); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal checkout session from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Extract the user ID from client_reference_id
|
||||
var userID uint
|
||||
if _, err := fmt.Sscanf(session.ClientReferenceID, "%d", &userID); err != nil {
|
||||
log.Error().Str("client_reference_id", session.ClientReferenceID).Msg("Invalid client_reference_id in checkout session")
|
||||
return apperrors.BadRequest("error.invalid_client_reference_id")
|
||||
}
|
||||
|
||||
// Get or create the subscription record
|
||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
||||
if err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Save Stripe customer and subscription IDs
|
||||
if session.Customer != nil {
|
||||
sub.StripeCustomerID = &session.Customer.ID
|
||||
}
|
||||
if session.Subscription != nil {
|
||||
sub.StripeSubscriptionID = &session.Subscription.ID
|
||||
}
|
||||
|
||||
if err := s.subscriptionRepo.Update(sub); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Upgrade to Pro. Use a far-future expiry for now; the invoice.paid event
|
||||
// will set the real period_end once the first invoice is finalized.
|
||||
expiresAt := time.Now().UTC().AddDate(1, 0, 0)
|
||||
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, models.PlatformStripe); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
customerID := ""
|
||||
if session.Customer != nil {
|
||||
customerID = session.Customer.ID
|
||||
}
|
||||
subscriptionID := ""
|
||||
if session.Subscription != nil {
|
||||
subscriptionID = session.Subscription.ID
|
||||
}
|
||||
log.Info().
|
||||
Uint("user_id", userID).
|
||||
Str("stripe_customer_id", customerID).
|
||||
Str("stripe_subscription_id", subscriptionID).
|
||||
Msg("Checkout completed, user upgraded to Pro")
|
||||
|
||||
// TODO: Send push notification to user's devices when subscription activates
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionUpdated processes subscription status changes. It upgrades or
|
||||
// downgrades the user depending on the subscription's current status.
|
||||
func (s *StripeService) handleSubscriptionUpdated(event stripe.Event) error {
|
||||
var subscription stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal subscription from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(subscription.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch subscription.Status {
|
||||
case stripe.SubscriptionStatusActive, stripe.SubscriptionStatusTrialing:
|
||||
// Subscription is healthy, ensure user is Pro
|
||||
expiresAt := time.Unix(subscription.CurrentPeriodEnd, 0).UTC()
|
||||
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("Stripe subscription active")
|
||||
|
||||
case stripe.SubscriptionStatusPastDue:
|
||||
log.Warn().Uint("user_id", sub.UserID).Msg("Stripe subscription past due, waiting for retry")
|
||||
// Don't downgrade yet; Stripe will retry the payment automatically.
|
||||
|
||||
case stripe.SubscriptionStatusCanceled, stripe.SubscriptionStatusUnpaid:
|
||||
// Check if the user has active subscriptions from other sources before downgrading
|
||||
if s.isActiveFromOtherSources(sub) {
|
||||
log.Info().
|
||||
Uint("user_id", sub.UserID).
|
||||
Str("status", string(subscription.Status)).
|
||||
Msg("Stripe subscription ended but user has other active sources, keeping Pro")
|
||||
return nil
|
||||
}
|
||||
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("User downgraded to Free after Stripe subscription ended")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionDeleted processes a subscription that has been fully cancelled
|
||||
// and is no longer active. It downgrades the user unless they have active subscriptions
|
||||
// from other sources (Apple, Google).
|
||||
func (s *StripeService) handleSubscriptionDeleted(event stripe.Event) error {
|
||||
var subscription stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal subscription from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(subscription.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check multi-source before downgrading
|
||||
if s.isActiveFromOtherSources(sub) {
|
||||
log.Info().
|
||||
Uint("user_id", sub.UserID).
|
||||
Msg("Stripe subscription deleted but user has other active sources, keeping Pro")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", sub.UserID).Msg("User downgraded to Free after Stripe subscription deleted")
|
||||
|
||||
// TODO: Send push notification to user's devices about subscription ending
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoicePaid processes a successful invoice payment. It updates the subscription
|
||||
// expiry to the current billing period's end date and ensures the user is on Pro.
|
||||
func (s *StripeService) handleInvoicePaid(event stripe.Event) error {
|
||||
var invoice stripe.Invoice
|
||||
if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal invoice from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Only process subscription invoices
|
||||
if invoice.Subscription == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(invoice.Subscription.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update expiry from the invoice's period end
|
||||
expiresAt := time.Unix(invoice.PeriodEnd, 0).UTC()
|
||||
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Uint("user_id", sub.UserID).
|
||||
Time("expires_at", expiresAt).
|
||||
Msg("Invoice paid, subscription renewed")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoicePaymentFailed logs a warning when a payment fails. We do not downgrade
|
||||
// the user here because Stripe will automatically retry the payment according to its
|
||||
// Smart Retries schedule.
|
||||
func (s *StripeService) handleInvoicePaymentFailed(event stripe.Event) error {
|
||||
var invoice stripe.Invoice
|
||||
if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal invoice from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
if invoice.Subscription == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(invoice.Subscription.ID)
|
||||
if err != nil {
|
||||
// If we can't find the subscription, just log and return
|
||||
log.Warn().Str("stripe_subscription_id", invoice.Subscription.ID).Msg("Invoice payment failed for unknown subscription")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Warn().
|
||||
Uint("user_id", sub.UserID).
|
||||
Str("invoice_id", invoice.ID).
|
||||
Msg("Stripe invoice payment failed, Stripe will retry automatically")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isActiveFromOtherSources checks if the user has active subscriptions from Apple or Google
|
||||
// that should prevent a downgrade when the Stripe subscription ends.
|
||||
func (s *StripeService) isActiveFromOtherSources(sub *models.UserSubscription) bool {
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Check Apple subscription
|
||||
if sub.HasAppleSubscription() && sub.Tier == models.TierPro && sub.ExpiresAt != nil && now.Before(*sub.ExpiresAt) && sub.Platform != models.PlatformStripe {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check Google subscription
|
||||
if sub.HasGoogleSubscription() && sub.Tier == models.TierPro && sub.ExpiresAt != nil && now.Before(*sub.ExpiresAt) && sub.Platform != models.PlatformStripe {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check active trial
|
||||
if sub.IsTrialActive() {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getOrCreateStripeCustomer returns the existing Stripe customer ID from the subscription
|
||||
// record, or creates a new Stripe customer and persists the ID.
|
||||
func (s *StripeService) getOrCreateStripeCustomer(sub *models.UserSubscription, user *models.User) (string, error) {
|
||||
// If we already have a Stripe customer, return it
|
||||
if sub.StripeCustomerID != nil && *sub.StripeCustomerID != "" {
|
||||
return *sub.StripeCustomerID, nil
|
||||
}
|
||||
|
||||
// Create a new Stripe customer
|
||||
params := &stripe.CustomerParams{
|
||||
Email: stripe.String(user.Email),
|
||||
Name: stripe.String(user.GetFullName()),
|
||||
}
|
||||
params.AddMetadata("casera_user_id", fmt.Sprintf("%d", user.ID))
|
||||
|
||||
c, err := customer.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create Stripe customer: %w", err)
|
||||
}
|
||||
|
||||
// Save the customer ID to the subscription record
|
||||
sub.StripeCustomerID = &c.ID
|
||||
if err := s.subscriptionRepo.Update(sub); err != nil {
|
||||
return "", fmt.Errorf("failed to save Stripe customer ID: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Uint("user_id", user.ID).
|
||||
Str("stripe_customer_id", c.ID).
|
||||
Msg("Created new Stripe customer")
|
||||
|
||||
return c.ID, nil
|
||||
}
|
||||
|
||||
// findSubscriptionByStripeID looks up a UserSubscription by its Stripe subscription ID.
|
||||
func (s *StripeService) findSubscriptionByStripeID(stripeSubID string) (*models.UserSubscription, error) {
|
||||
sub, err := s.subscriptionRepo.FindByStripeSubscriptionID(stripeSubID)
|
||||
if err != nil {
|
||||
log.Warn().Str("stripe_subscription_id", stripeSubID).Err(err).Msg("Subscription not found for Stripe ID")
|
||||
return nil, apperrors.NotFound("error.subscription_not_found")
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// getTrialDays reads the trial duration from SubscriptionSettings.
|
||||
func (s *StripeService) getTrialDays() (int, error) {
|
||||
settings, err := s.subscriptionRepo.GetSettings()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !settings.TrialEnabled {
|
||||
return 0, nil
|
||||
}
|
||||
return settings.TrialDurationDays, nil
|
||||
}
|
||||
@@ -118,6 +118,20 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Auto-start trial for new users who have never had a trial
|
||||
if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled {
|
||||
now := time.Now().UTC()
|
||||
trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour)
|
||||
if err := s.subscriptionRepo.SetTrialDates(userID, now, trialEnd); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
// Re-fetch after starting trial so response reflects the new state
|
||||
sub, err = s.subscriptionRepo.FindByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all tier limits and build a map
|
||||
allLimits, err := s.subscriptionRepo.GetAllTierLimits()
|
||||
if err != nil {
|
||||
@@ -154,6 +168,8 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
|
||||
// Build flattened response (KMM expects subscription fields at top level)
|
||||
resp := &SubscriptionStatusResponse{
|
||||
Tier: string(sub.Tier),
|
||||
IsActive: sub.IsActive(),
|
||||
AutoRenew: sub.AutoRenew,
|
||||
Limits: limitsMap,
|
||||
Usage: usage,
|
||||
@@ -170,6 +186,18 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
resp.ExpiresAt = &t
|
||||
}
|
||||
|
||||
// Populate trial fields
|
||||
if sub.TrialStart != nil {
|
||||
t := sub.TrialStart.Format("2006-01-02T15:04:05Z")
|
||||
resp.TrialStart = &t
|
||||
}
|
||||
if sub.TrialEnd != nil {
|
||||
t := sub.TrialEnd.Format("2006-01-02T15:04:05Z")
|
||||
resp.TrialEnd = &t
|
||||
}
|
||||
resp.TrialActive = sub.IsTrialActive()
|
||||
resp.SubscriptionSource = sub.SubscriptionSource()
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -449,28 +477,48 @@ func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResp
|
||||
return s.GetSubscription(userID)
|
||||
}
|
||||
|
||||
// IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription
|
||||
// from a different platform than the one being requested. Returns (conflict, existingPlatform, error).
|
||||
func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(userID uint, requestedPlatform string) (bool, string, error) {
|
||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
||||
if err != nil {
|
||||
return false, "", apperrors.Internal(err)
|
||||
}
|
||||
if !sub.IsPro() {
|
||||
return false, "", nil
|
||||
}
|
||||
if sub.Platform == requestedPlatform {
|
||||
return false, "", nil
|
||||
}
|
||||
return true, sub.Platform, nil
|
||||
}
|
||||
|
||||
// === Response Types ===
|
||||
|
||||
// SubscriptionResponse represents a subscription in API response
|
||||
type SubscriptionResponse struct {
|
||||
Tier string `json:"tier"`
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
CancelledAt *string `json:"cancelled_at"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"is_active"`
|
||||
IsPro bool `json:"is_pro"`
|
||||
Tier string `json:"tier"`
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
CancelledAt *string `json:"cancelled_at"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"is_active"`
|
||||
IsPro bool `json:"is_pro"`
|
||||
TrialActive bool `json:"trial_active"`
|
||||
SubscriptionSource string `json:"subscription_source"`
|
||||
}
|
||||
|
||||
// NewSubscriptionResponse creates a SubscriptionResponse from a model
|
||||
func NewSubscriptionResponse(s *models.UserSubscription) *SubscriptionResponse {
|
||||
resp := &SubscriptionResponse{
|
||||
Tier: string(s.Tier),
|
||||
AutoRenew: s.AutoRenew,
|
||||
Platform: s.Platform,
|
||||
IsActive: s.IsActive(),
|
||||
IsPro: s.IsPro(),
|
||||
Tier: string(s.Tier),
|
||||
AutoRenew: s.AutoRenew,
|
||||
Platform: s.Platform,
|
||||
IsActive: s.IsActive(),
|
||||
IsPro: s.IsPro(),
|
||||
TrialActive: s.IsTrialActive(),
|
||||
SubscriptionSource: s.SubscriptionSource(),
|
||||
}
|
||||
if s.SubscribedAt != nil {
|
||||
t := s.SubscribedAt.Format("2006-01-02T15:04:05Z")
|
||||
@@ -536,11 +584,23 @@ func NewTierLimitsClientResponse(l *models.TierLimits) *TierLimitsClientResponse
|
||||
// SubscriptionStatusResponse represents full subscription status
|
||||
// Fields are flattened to match KMM client expectations
|
||||
type SubscriptionStatusResponse struct {
|
||||
// Tier and active status
|
||||
Tier string `json:"tier"`
|
||||
IsActive bool `json:"is_active"`
|
||||
|
||||
// Flattened subscription fields (KMM expects these at top level)
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
|
||||
// Trial fields
|
||||
TrialStart *string `json:"trial_start,omitempty"`
|
||||
TrialEnd *string `json:"trial_end,omitempty"`
|
||||
TrialActive bool `json:"trial_active"`
|
||||
|
||||
// Subscription source
|
||||
SubscriptionSource string `json:"subscription_source"`
|
||||
|
||||
// Other fields
|
||||
Usage *UsageResponse `json:"usage"`
|
||||
Limits map[string]*TierLimitsClientResponse `json:"limits"`
|
||||
@@ -638,5 +698,5 @@ type ProcessPurchaseRequest struct {
|
||||
TransactionID string `json:"transaction_id"` // iOS StoreKit 2 transaction ID
|
||||
PurchaseToken string `json:"purchase_token"` // Android
|
||||
ProductID string `json:"product_id"` // Android (optional, helps identify subscription)
|
||||
Platform string `json:"platform" validate:"required,oneof=ios android"`
|
||||
Platform string `json:"platform" validate:"required,oneof=ios android stripe"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -179,3 +180,94 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
|
||||
}
|
||||
|
||||
func TestIsAlreadyProFromOtherPlatform(t *testing.T) {
|
||||
future := time.Now().UTC().Add(30 * 24 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tier models.SubscriptionTier
|
||||
platform string
|
||||
expiresAt *time.Time
|
||||
trialEnd *time.Time
|
||||
requestedPlatform string
|
||||
wantConflict bool
|
||||
wantPlatform string
|
||||
}{
|
||||
{
|
||||
name: "free user returns no conflict",
|
||||
tier: models.TierFree,
|
||||
platform: "",
|
||||
expiresAt: nil,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "stripe",
|
||||
wantConflict: false,
|
||||
wantPlatform: "",
|
||||
},
|
||||
{
|
||||
name: "pro from ios, requesting ios returns no conflict (same platform)",
|
||||
tier: models.TierPro,
|
||||
platform: "ios",
|
||||
expiresAt: &future,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "ios",
|
||||
wantConflict: false,
|
||||
wantPlatform: "",
|
||||
},
|
||||
{
|
||||
name: "pro from ios, requesting stripe returns conflict",
|
||||
tier: models.TierPro,
|
||||
platform: "ios",
|
||||
expiresAt: &future,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "stripe",
|
||||
wantConflict: true,
|
||||
wantPlatform: "ios",
|
||||
},
|
||||
{
|
||||
name: "pro from stripe, requesting android returns conflict",
|
||||
tier: models.TierPro,
|
||||
platform: "stripe",
|
||||
expiresAt: &future,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "android",
|
||||
wantConflict: true,
|
||||
wantPlatform: "stripe",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
}
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: tt.tier,
|
||||
Platform: tt.platform,
|
||||
ExpiresAt: tt.expiresAt,
|
||||
TrialEnd: tt.trialEnd,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(user.ID, tt.requestedPlatform)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantConflict, conflict)
|
||||
assert.Equal(t, tt.wantPlatform, existingPlatform)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user