Add Stripe billing, free trials, and cross-platform subscription guards
- Stripe integration: add StripeService with checkout sessions, customer portal, and webhook handling for subscription lifecycle events. - Free trials: auto-start configurable trial on first subscription check, with admin-controllable duration and enable/disable toggle. - Cross-platform guard: prevent duplicate subscriptions across iOS, Android, and Stripe by checking existing platform before allowing purchase. - Subscription model: add Stripe fields (customer_id, subscription_id, price_id), trial fields (trial_start, trial_end, trial_used), and SubscriptionSource/IsTrialActive helpers. - API: add trial and source fields to status response, update OpenAPI spec. - Clean up stale migration and audit docs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
456
internal/services/stripe_service.go
Normal file
456
internal/services/stripe_service.go
Normal file
@@ -0,0 +1,456 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
portalsession "github.com/stripe/stripe-go/v81/billingportal/session"
|
||||
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
|
||||
"github.com/stripe/stripe-go/v81/customer"
|
||||
"github.com/stripe/stripe-go/v81/webhook"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
)
|
||||
|
||||
// StripeService handles Stripe checkout, portal, and webhook processing
|
||||
// for web-based subscription purchases.
|
||||
type StripeService struct {
|
||||
subscriptionRepo *repositories.SubscriptionRepository
|
||||
userRepo *repositories.UserRepository
|
||||
webhookSecret string
|
||||
}
|
||||
|
||||
// NewStripeService creates a new Stripe service. It initializes the global
|
||||
// Stripe API key from the STRIPE_SECRET_KEY environment variable. If the key
|
||||
// is not set, a warning is logged but the service is still returned (matching
|
||||
// the pattern used by the Apple/Google IAP clients).
|
||||
func NewStripeService(
|
||||
subscriptionRepo *repositories.SubscriptionRepository,
|
||||
userRepo *repositories.UserRepository,
|
||||
) *StripeService {
|
||||
key := os.Getenv("STRIPE_SECRET_KEY")
|
||||
if key == "" {
|
||||
log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work")
|
||||
} else {
|
||||
stripe.Key = key
|
||||
log.Info().Msg("Stripe API key configured")
|
||||
}
|
||||
|
||||
webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")
|
||||
if webhookSecret == "" {
|
||||
log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail")
|
||||
}
|
||||
|
||||
return &StripeService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
userRepo: userRepo,
|
||||
webhookSecret: webhookSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCheckoutSession creates a Stripe Checkout Session for a web subscription purchase.
|
||||
// It ensures the user has a Stripe customer record and configures the session with a trial
|
||||
// period if the user has not used their trial yet.
|
||||
func (s *StripeService) CreateCheckoutSession(userID uint, priceID string, successURL string, cancelURL string) (string, error) {
|
||||
// Get or create the user's subscription record
|
||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
||||
if err != nil {
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Get the user's email for the Stripe customer
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
if err != nil {
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Get or create a Stripe customer
|
||||
stripeCustomerID, err := s.getOrCreateStripeCustomer(sub, user)
|
||||
if err != nil {
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Build the checkout session parameters
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
Customer: stripe.String(stripeCustomerID),
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||
SuccessURL: stripe.String(successURL),
|
||||
CancelURL: stripe.String(cancelURL),
|
||||
ClientReferenceID: stripe.String(fmt.Sprintf("%d", userID)),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(priceID),
|
||||
Quantity: stripe.Int64(1),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Offer a trial period if the user has not used their trial yet
|
||||
if !sub.TrialUsed {
|
||||
trialDays, err := s.getTrialDays()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to get trial duration from settings, skipping trial")
|
||||
} else if trialDays > 0 {
|
||||
params.SubscriptionData = &stripe.CheckoutSessionSubscriptionDataParams{
|
||||
TrialPeriodDays: stripe.Int64(int64(trialDays)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session, err := checkoutsession.New(params)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to create Stripe checkout session")
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Uint("user_id", userID).
|
||||
Str("session_id", session.ID).
|
||||
Str("price_id", priceID).
|
||||
Msg("Stripe checkout session created")
|
||||
|
||||
return session.URL, nil
|
||||
}
|
||||
|
||||
// CreatePortalSession creates a Stripe Customer Portal session so the user
|
||||
// can manage their subscription (cancel, change plan, update payment method).
|
||||
func (s *StripeService) CreatePortalSession(userID uint, returnURL string) (string, error) {
|
||||
sub, err := s.subscriptionRepo.FindByUserID(userID)
|
||||
if err != nil {
|
||||
return "", apperrors.NotFound("error.subscription_not_found")
|
||||
}
|
||||
|
||||
if sub.StripeCustomerID == nil || *sub.StripeCustomerID == "" {
|
||||
return "", apperrors.BadRequest("error.no_stripe_customer")
|
||||
}
|
||||
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: sub.StripeCustomerID,
|
||||
ReturnURL: stripe.String(returnURL),
|
||||
}
|
||||
|
||||
session, err := portalsession.New(params)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to create Stripe portal session")
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return session.URL, nil
|
||||
}
|
||||
|
||||
// HandleWebhookEvent verifies and processes a Stripe webhook event.
|
||||
// It handles checkout completion, subscription lifecycle changes, and invoice events.
|
||||
func (s *StripeService) HandleWebhookEvent(payload []byte, signature string) error {
|
||||
event, err := webhook.ConstructEvent(payload, signature, s.webhookSecret)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Stripe webhook signature verification failed")
|
||||
return apperrors.BadRequest("error.invalid_webhook_signature")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("event_type", string(event.Type)).
|
||||
Str("event_id", event.ID).
|
||||
Msg("Processing Stripe webhook event")
|
||||
|
||||
switch event.Type {
|
||||
case "checkout.session.completed":
|
||||
return s.handleCheckoutCompleted(event)
|
||||
case "customer.subscription.updated":
|
||||
return s.handleSubscriptionUpdated(event)
|
||||
case "customer.subscription.deleted":
|
||||
return s.handleSubscriptionDeleted(event)
|
||||
case "invoice.paid":
|
||||
return s.handleInvoicePaid(event)
|
||||
case "invoice.payment_failed":
|
||||
return s.handleInvoicePaymentFailed(event)
|
||||
default:
|
||||
log.Debug().Str("event_type", string(event.Type)).Msg("Unhandled Stripe webhook event type")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleCheckoutCompleted processes a successful checkout session. It links the Stripe
|
||||
// customer and subscription to the user's record and upgrades them to Pro.
|
||||
func (s *StripeService) handleCheckoutCompleted(event stripe.Event) error {
|
||||
var session stripe.CheckoutSession
|
||||
if err := json.Unmarshal(event.Data.Raw, &session); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal checkout session from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Extract the user ID from client_reference_id
|
||||
var userID uint
|
||||
if _, err := fmt.Sscanf(session.ClientReferenceID, "%d", &userID); err != nil {
|
||||
log.Error().Str("client_reference_id", session.ClientReferenceID).Msg("Invalid client_reference_id in checkout session")
|
||||
return apperrors.BadRequest("error.invalid_client_reference_id")
|
||||
}
|
||||
|
||||
// Get or create the subscription record
|
||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
||||
if err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Save Stripe customer and subscription IDs
|
||||
if session.Customer != nil {
|
||||
sub.StripeCustomerID = &session.Customer.ID
|
||||
}
|
||||
if session.Subscription != nil {
|
||||
sub.StripeSubscriptionID = &session.Subscription.ID
|
||||
}
|
||||
|
||||
if err := s.subscriptionRepo.Update(sub); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Upgrade to Pro. Use a far-future expiry for now; the invoice.paid event
|
||||
// will set the real period_end once the first invoice is finalized.
|
||||
expiresAt := time.Now().UTC().AddDate(1, 0, 0)
|
||||
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, models.PlatformStripe); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
customerID := ""
|
||||
if session.Customer != nil {
|
||||
customerID = session.Customer.ID
|
||||
}
|
||||
subscriptionID := ""
|
||||
if session.Subscription != nil {
|
||||
subscriptionID = session.Subscription.ID
|
||||
}
|
||||
log.Info().
|
||||
Uint("user_id", userID).
|
||||
Str("stripe_customer_id", customerID).
|
||||
Str("stripe_subscription_id", subscriptionID).
|
||||
Msg("Checkout completed, user upgraded to Pro")
|
||||
|
||||
// TODO: Send push notification to user's devices when subscription activates
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionUpdated processes subscription status changes. It upgrades or
|
||||
// downgrades the user depending on the subscription's current status.
|
||||
func (s *StripeService) handleSubscriptionUpdated(event stripe.Event) error {
|
||||
var subscription stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal subscription from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(subscription.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch subscription.Status {
|
||||
case stripe.SubscriptionStatusActive, stripe.SubscriptionStatusTrialing:
|
||||
// Subscription is healthy, ensure user is Pro
|
||||
expiresAt := time.Unix(subscription.CurrentPeriodEnd, 0).UTC()
|
||||
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("Stripe subscription active")
|
||||
|
||||
case stripe.SubscriptionStatusPastDue:
|
||||
log.Warn().Uint("user_id", sub.UserID).Msg("Stripe subscription past due, waiting for retry")
|
||||
// Don't downgrade yet; Stripe will retry the payment automatically.
|
||||
|
||||
case stripe.SubscriptionStatusCanceled, stripe.SubscriptionStatusUnpaid:
|
||||
// Check if the user has active subscriptions from other sources before downgrading
|
||||
if s.isActiveFromOtherSources(sub) {
|
||||
log.Info().
|
||||
Uint("user_id", sub.UserID).
|
||||
Str("status", string(subscription.Status)).
|
||||
Msg("Stripe subscription ended but user has other active sources, keeping Pro")
|
||||
return nil
|
||||
}
|
||||
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("User downgraded to Free after Stripe subscription ended")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionDeleted processes a subscription that has been fully cancelled
|
||||
// and is no longer active. It downgrades the user unless they have active subscriptions
|
||||
// from other sources (Apple, Google).
|
||||
func (s *StripeService) handleSubscriptionDeleted(event stripe.Event) error {
|
||||
var subscription stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal subscription from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(subscription.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check multi-source before downgrading
|
||||
if s.isActiveFromOtherSources(sub) {
|
||||
log.Info().
|
||||
Uint("user_id", sub.UserID).
|
||||
Msg("Stripe subscription deleted but user has other active sources, keeping Pro")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
log.Info().Uint("user_id", sub.UserID).Msg("User downgraded to Free after Stripe subscription deleted")
|
||||
|
||||
// TODO: Send push notification to user's devices about subscription ending
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoicePaid processes a successful invoice payment. It updates the subscription
|
||||
// expiry to the current billing period's end date and ensures the user is on Pro.
|
||||
func (s *StripeService) handleInvoicePaid(event stripe.Event) error {
|
||||
var invoice stripe.Invoice
|
||||
if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal invoice from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Only process subscription invoices
|
||||
if invoice.Subscription == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(invoice.Subscription.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update expiry from the invoice's period end
|
||||
expiresAt := time.Unix(invoice.PeriodEnd, 0).UTC()
|
||||
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Uint("user_id", sub.UserID).
|
||||
Time("expires_at", expiresAt).
|
||||
Msg("Invoice paid, subscription renewed")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoicePaymentFailed logs a warning when a payment fails. We do not downgrade
|
||||
// the user here because Stripe will automatically retry the payment according to its
|
||||
// Smart Retries schedule.
|
||||
func (s *StripeService) handleInvoicePaymentFailed(event stripe.Event) error {
|
||||
var invoice stripe.Invoice
|
||||
if err := json.Unmarshal(event.Data.Raw, &invoice); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal invoice from webhook")
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
if invoice.Subscription == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sub, err := s.findSubscriptionByStripeID(invoice.Subscription.ID)
|
||||
if err != nil {
|
||||
// If we can't find the subscription, just log and return
|
||||
log.Warn().Str("stripe_subscription_id", invoice.Subscription.ID).Msg("Invoice payment failed for unknown subscription")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Warn().
|
||||
Uint("user_id", sub.UserID).
|
||||
Str("invoice_id", invoice.ID).
|
||||
Msg("Stripe invoice payment failed, Stripe will retry automatically")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isActiveFromOtherSources checks if the user has active subscriptions from Apple or Google
|
||||
// that should prevent a downgrade when the Stripe subscription ends.
|
||||
func (s *StripeService) isActiveFromOtherSources(sub *models.UserSubscription) bool {
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Check Apple subscription
|
||||
if sub.HasAppleSubscription() && sub.Tier == models.TierPro && sub.ExpiresAt != nil && now.Before(*sub.ExpiresAt) && sub.Platform != models.PlatformStripe {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check Google subscription
|
||||
if sub.HasGoogleSubscription() && sub.Tier == models.TierPro && sub.ExpiresAt != nil && now.Before(*sub.ExpiresAt) && sub.Platform != models.PlatformStripe {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check active trial
|
||||
if sub.IsTrialActive() {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getOrCreateStripeCustomer returns the existing Stripe customer ID from the subscription
|
||||
// record, or creates a new Stripe customer and persists the ID.
|
||||
func (s *StripeService) getOrCreateStripeCustomer(sub *models.UserSubscription, user *models.User) (string, error) {
|
||||
// If we already have a Stripe customer, return it
|
||||
if sub.StripeCustomerID != nil && *sub.StripeCustomerID != "" {
|
||||
return *sub.StripeCustomerID, nil
|
||||
}
|
||||
|
||||
// Create a new Stripe customer
|
||||
params := &stripe.CustomerParams{
|
||||
Email: stripe.String(user.Email),
|
||||
Name: stripe.String(user.GetFullName()),
|
||||
}
|
||||
params.AddMetadata("casera_user_id", fmt.Sprintf("%d", user.ID))
|
||||
|
||||
c, err := customer.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create Stripe customer: %w", err)
|
||||
}
|
||||
|
||||
// Save the customer ID to the subscription record
|
||||
sub.StripeCustomerID = &c.ID
|
||||
if err := s.subscriptionRepo.Update(sub); err != nil {
|
||||
return "", fmt.Errorf("failed to save Stripe customer ID: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Uint("user_id", user.ID).
|
||||
Str("stripe_customer_id", c.ID).
|
||||
Msg("Created new Stripe customer")
|
||||
|
||||
return c.ID, nil
|
||||
}
|
||||
|
||||
// findSubscriptionByStripeID looks up a UserSubscription by its Stripe subscription ID.
|
||||
func (s *StripeService) findSubscriptionByStripeID(stripeSubID string) (*models.UserSubscription, error) {
|
||||
sub, err := s.subscriptionRepo.FindByStripeSubscriptionID(stripeSubID)
|
||||
if err != nil {
|
||||
log.Warn().Str("stripe_subscription_id", stripeSubID).Err(err).Msg("Subscription not found for Stripe ID")
|
||||
return nil, apperrors.NotFound("error.subscription_not_found")
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// getTrialDays reads the trial duration from SubscriptionSettings.
|
||||
func (s *StripeService) getTrialDays() (int, error) {
|
||||
settings, err := s.subscriptionRepo.GetSettings()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !settings.TrialEnabled {
|
||||
return 0, nil
|
||||
}
|
||||
return settings.TrialDurationDays, nil
|
||||
}
|
||||
@@ -118,6 +118,20 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Auto-start trial for new users who have never had a trial
|
||||
if !sub.TrialUsed && sub.TrialEnd == nil && settings.TrialEnabled {
|
||||
now := time.Now().UTC()
|
||||
trialEnd := now.Add(time.Duration(settings.TrialDurationDays) * 24 * time.Hour)
|
||||
if err := s.subscriptionRepo.SetTrialDates(userID, now, trialEnd); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
// Re-fetch after starting trial so response reflects the new state
|
||||
sub, err = s.subscriptionRepo.FindByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all tier limits and build a map
|
||||
allLimits, err := s.subscriptionRepo.GetAllTierLimits()
|
||||
if err != nil {
|
||||
@@ -154,6 +168,8 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
|
||||
// Build flattened response (KMM expects subscription fields at top level)
|
||||
resp := &SubscriptionStatusResponse{
|
||||
Tier: string(sub.Tier),
|
||||
IsActive: sub.IsActive(),
|
||||
AutoRenew: sub.AutoRenew,
|
||||
Limits: limitsMap,
|
||||
Usage: usage,
|
||||
@@ -170,6 +186,18 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
resp.ExpiresAt = &t
|
||||
}
|
||||
|
||||
// Populate trial fields
|
||||
if sub.TrialStart != nil {
|
||||
t := sub.TrialStart.Format("2006-01-02T15:04:05Z")
|
||||
resp.TrialStart = &t
|
||||
}
|
||||
if sub.TrialEnd != nil {
|
||||
t := sub.TrialEnd.Format("2006-01-02T15:04:05Z")
|
||||
resp.TrialEnd = &t
|
||||
}
|
||||
resp.TrialActive = sub.IsTrialActive()
|
||||
resp.SubscriptionSource = sub.SubscriptionSource()
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -449,28 +477,48 @@ func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResp
|
||||
return s.GetSubscription(userID)
|
||||
}
|
||||
|
||||
// IsAlreadyProFromOtherPlatform checks if a user already has an active Pro subscription
|
||||
// from a different platform than the one being requested. Returns (conflict, existingPlatform, error).
|
||||
func (s *SubscriptionService) IsAlreadyProFromOtherPlatform(userID uint, requestedPlatform string) (bool, string, error) {
|
||||
sub, err := s.subscriptionRepo.GetOrCreate(userID)
|
||||
if err != nil {
|
||||
return false, "", apperrors.Internal(err)
|
||||
}
|
||||
if !sub.IsPro() {
|
||||
return false, "", nil
|
||||
}
|
||||
if sub.Platform == requestedPlatform {
|
||||
return false, "", nil
|
||||
}
|
||||
return true, sub.Platform, nil
|
||||
}
|
||||
|
||||
// === Response Types ===
|
||||
|
||||
// SubscriptionResponse represents a subscription in API response
|
||||
type SubscriptionResponse struct {
|
||||
Tier string `json:"tier"`
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
CancelledAt *string `json:"cancelled_at"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"is_active"`
|
||||
IsPro bool `json:"is_pro"`
|
||||
Tier string `json:"tier"`
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
CancelledAt *string `json:"cancelled_at"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"is_active"`
|
||||
IsPro bool `json:"is_pro"`
|
||||
TrialActive bool `json:"trial_active"`
|
||||
SubscriptionSource string `json:"subscription_source"`
|
||||
}
|
||||
|
||||
// NewSubscriptionResponse creates a SubscriptionResponse from a model
|
||||
func NewSubscriptionResponse(s *models.UserSubscription) *SubscriptionResponse {
|
||||
resp := &SubscriptionResponse{
|
||||
Tier: string(s.Tier),
|
||||
AutoRenew: s.AutoRenew,
|
||||
Platform: s.Platform,
|
||||
IsActive: s.IsActive(),
|
||||
IsPro: s.IsPro(),
|
||||
Tier: string(s.Tier),
|
||||
AutoRenew: s.AutoRenew,
|
||||
Platform: s.Platform,
|
||||
IsActive: s.IsActive(),
|
||||
IsPro: s.IsPro(),
|
||||
TrialActive: s.IsTrialActive(),
|
||||
SubscriptionSource: s.SubscriptionSource(),
|
||||
}
|
||||
if s.SubscribedAt != nil {
|
||||
t := s.SubscribedAt.Format("2006-01-02T15:04:05Z")
|
||||
@@ -536,11 +584,23 @@ func NewTierLimitsClientResponse(l *models.TierLimits) *TierLimitsClientResponse
|
||||
// SubscriptionStatusResponse represents full subscription status
|
||||
// Fields are flattened to match KMM client expectations
|
||||
type SubscriptionStatusResponse struct {
|
||||
// Tier and active status
|
||||
Tier string `json:"tier"`
|
||||
IsActive bool `json:"is_active"`
|
||||
|
||||
// Flattened subscription fields (KMM expects these at top level)
|
||||
SubscribedAt *string `json:"subscribed_at"`
|
||||
ExpiresAt *string `json:"expires_at"`
|
||||
AutoRenew bool `json:"auto_renew"`
|
||||
|
||||
// Trial fields
|
||||
TrialStart *string `json:"trial_start,omitempty"`
|
||||
TrialEnd *string `json:"trial_end,omitempty"`
|
||||
TrialActive bool `json:"trial_active"`
|
||||
|
||||
// Subscription source
|
||||
SubscriptionSource string `json:"subscription_source"`
|
||||
|
||||
// Other fields
|
||||
Usage *UsageResponse `json:"usage"`
|
||||
Limits map[string]*TierLimitsClientResponse `json:"limits"`
|
||||
@@ -638,5 +698,5 @@ type ProcessPurchaseRequest struct {
|
||||
TransactionID string `json:"transaction_id"` // iOS StoreKit 2 transaction ID
|
||||
PurchaseToken string `json:"purchase_token"` // Android
|
||||
ProductID string `json:"product_id"` // Android (optional, helps identify subscription)
|
||||
Platform string `json:"platform" validate:"required,oneof=ios android"`
|
||||
Platform string `json:"platform" validate:"required,oneof=ios android stripe"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -179,3 +180,94 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
|
||||
}
|
||||
|
||||
func TestIsAlreadyProFromOtherPlatform(t *testing.T) {
|
||||
future := time.Now().UTC().Add(30 * 24 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tier models.SubscriptionTier
|
||||
platform string
|
||||
expiresAt *time.Time
|
||||
trialEnd *time.Time
|
||||
requestedPlatform string
|
||||
wantConflict bool
|
||||
wantPlatform string
|
||||
}{
|
||||
{
|
||||
name: "free user returns no conflict",
|
||||
tier: models.TierFree,
|
||||
platform: "",
|
||||
expiresAt: nil,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "stripe",
|
||||
wantConflict: false,
|
||||
wantPlatform: "",
|
||||
},
|
||||
{
|
||||
name: "pro from ios, requesting ios returns no conflict (same platform)",
|
||||
tier: models.TierPro,
|
||||
platform: "ios",
|
||||
expiresAt: &future,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "ios",
|
||||
wantConflict: false,
|
||||
wantPlatform: "",
|
||||
},
|
||||
{
|
||||
name: "pro from ios, requesting stripe returns conflict",
|
||||
tier: models.TierPro,
|
||||
platform: "ios",
|
||||
expiresAt: &future,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "stripe",
|
||||
wantConflict: true,
|
||||
wantPlatform: "ios",
|
||||
},
|
||||
{
|
||||
name: "pro from stripe, requesting android returns conflict",
|
||||
tier: models.TierPro,
|
||||
platform: "stripe",
|
||||
expiresAt: &future,
|
||||
trialEnd: nil,
|
||||
requestedPlatform: "android",
|
||||
wantConflict: true,
|
||||
wantPlatform: "stripe",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
}
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: tt.tier,
|
||||
Platform: tt.platform,
|
||||
ExpiresAt: tt.expiresAt,
|
||||
TrialEnd: tt.trialEnd,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
conflict, existingPlatform, err := svc.IsAlreadyProFromOtherPlatform(user.ID, tt.requestedPlatform)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantConflict, conflict)
|
||||
assert.Equal(t, tt.wantPlatform, existingPlatform)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user