package repositories import ( "errors" "time" "gorm.io/gorm" "gorm.io/gorm/clause" "github.com/treytartt/honeydue-api/internal/models" ) // SubscriptionRepository handles database operations for subscriptions type SubscriptionRepository struct { db *gorm.DB } // NewSubscriptionRepository creates a new subscription repository func NewSubscriptionRepository(db *gorm.DB) *SubscriptionRepository { return &SubscriptionRepository{db: db} } // === User Subscription === // FindByUserID finds a subscription by user ID func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscription, error) { var sub models.UserSubscription err := r.db.Where("user_id = ?", userID).First(&sub).Error if err != nil { return nil, err } return &sub, nil } // GetOrCreate gets or creates a subscription for a user (defaults to free tier). // Uses a transaction to avoid TOCTOU race conditions on concurrent requests. func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) { var sub models.UserSubscription err := r.db.Transaction(func(tx *gorm.DB) error { err := tx.Where("user_id = ?", userID).First(&sub).Error if err == nil { return nil // Found existing subscription } if !errors.Is(err, gorm.ErrRecordNotFound) { return err // Unexpected error } // Record not found -- create with free tier defaults sub = models.UserSubscription{ UserID: userID, Tier: models.TierFree, AutoRenew: true, } return tx.Create(&sub).Error }) if err != nil { return nil, err } return &sub, nil } // Update updates a subscription func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error { return r.db.Omit("User").Save(sub).Error } // UpgradeToPro upgrades a user to Pro tier using a transaction with row locking // to prevent concurrent subscription mutations from corrupting state. 
func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time, platform string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// Lock the user's row (SELECT ... FOR UPDATE) so concurrent tier changes
		// are serialized. Returns gorm.ErrRecordNotFound if no row exists.
		var sub models.UserSubscription
		if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
			Where("user_id = ?", userID).First(&sub).Error; err != nil {
			return err
		}

		// A fresh upgrade clears any previous cancellation and re-enables renewal.
		now := time.Now().UTC()
		return tx.Model(&sub).Updates(map[string]interface{}{
			"tier":          models.TierPro,
			"subscribed_at": now,
			"expires_at":    expiresAt,
			"cancelled_at":  nil,
			"platform":      platform,
			"auto_renew":    true,
		}).Error
	})
}

// DowngradeToFree downgrades a user to Free tier using a transaction with row
// locking to prevent concurrent subscription mutations from corrupting state.
// It records the current UTC time as cancelled_at and disables auto-renew.
// It returns gorm.ErrRecordNotFound if the user has no subscription row.
func (r *SubscriptionRepository) DowngradeToFree(userID uint) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// Lock the user's row so concurrent tier changes are serialized.
		var sub models.UserSubscription
		if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
			Where("user_id = ?", userID).First(&sub).Error; err != nil {
			return err
		}

		now := time.Now().UTC()
		return tx.Model(&sub).Updates(map[string]interface{}{
			"tier":         models.TierFree,
			"cancelled_at": now,
			"auto_renew":   false,
		}).Error
	})
}

// SetAutoRenew sets the auto-renew flag for userID's subscription.
// No error is returned when no row matches userID.
func (r *SubscriptionRepository) SetAutoRenew(userID uint, autoRenew bool) error {
	return r.db.Model(&models.UserSubscription{}).
		Where("user_id = ?", userID).
		Update("auto_renew", autoRenew).Error
}

// UpdateReceiptData stores the Apple receipt data for userID's subscription.
// No error is returned when no row matches userID.
func (r *SubscriptionRepository) UpdateReceiptData(userID uint, receiptData string) error {
	return r.db.Model(&models.UserSubscription{}).
		Where("user_id = ?", userID).
		Update("apple_receipt_data", receiptData).Error
}

// UpdatePurchaseToken stores the Google purchase token for userID's
// subscription. No error is returned when no row matches userID.
func (r *SubscriptionRepository) UpdatePurchaseToken(userID uint, token string) error {
	return r.db.Model(&models.UserSubscription{}).
		Where("user_id = ?", userID).
		Update("google_purchase_token", token).Error
}

// FindByAppleReceiptContains finds a subscription whose stored Apple receipt
// data contains transactionID. Webhooks use it to find the user associated
// with a transaction. It returns gorm.ErrRecordNotFound when nothing matches
// or when transactionID is empty.
//
// PERFORMANCE NOTE: This uses a LIKE '%...%' scan on apple_receipt_data, which
// cannot use a B-tree index and results in a full table scan. For better
// performance at scale, add a dedicated indexed column:
//
//	AppleTransactionID *string `gorm:"column:apple_transaction_id;size:255;index"`
//
// Then look up by exact match: WHERE apple_transaction_id = ?
func (r *SubscriptionRepository) FindByAppleReceiptContains(transactionID string) (*models.UserSubscription, error) {
	if transactionID == "" {
		// An empty ID would yield the pattern "%%". That matches every row with
		// non-NULL receipt data, so a webhook could be attributed to an
		// arbitrary user.
		return nil, gorm.ErrRecordNotFound
	}

	// Escape LIKE wildcards in the transaction ID to prevent wildcard injection.
	// NOTE(review): no ESCAPE clause is given, so this relies on the database's
	// default LIKE escape character matching what escapeLikeWildcards emits
	// (SQLite, for example, has no default escape character). Confirm for the
	// target database.
	escaped := escapeLikeWildcards(transactionID)

	var sub models.UserSubscription
	if err := r.db.Where("apple_receipt_data LIKE ?", "%"+escaped+"%").First(&sub).Error; err != nil {
		return nil, err
	}
	return &sub, nil
}

// FindByGoogleToken finds a subscription by its exact Google purchase token.
// Webhooks use it to find the user associated with a purchase. It returns
// gorm.ErrRecordNotFound when nothing matches or when purchaseToken is empty.
func (r *SubscriptionRepository) FindByGoogleToken(purchaseToken string) (*models.UserSubscription, error) {
	if purchaseToken == "" {
		// An empty token never identifies a purchase. Without this guard it
		// could match a row whose token column holds an empty string.
		return nil, gorm.ErrRecordNotFound
	}

	var sub models.UserSubscription
	if err := r.db.Where("google_purchase_token = ?", purchaseToken).First(&sub).Error; err != nil {
		return nil, err
	}
	return &sub, nil
}

// SetCancelledAt sets the cancellation timestamp for userID's subscription.
// No error is returned when no row matches userID.
func (r *SubscriptionRepository) SetCancelledAt(userID uint, cancelledAt time.Time) error {
	return r.db.Model(&models.UserSubscription{}).
		Where("user_id = ?", userID).
		Update("cancelled_at", cancelledAt).Error
}

// ClearCancelledAt sets cancelled_at to NULL (the user resubscribed).
// No error is returned when no row matches userID.
func (r *SubscriptionRepository) ClearCancelledAt(userID uint) error {
	return r.db.Model(&models.UserSubscription{}).
		Where("user_id = ?", userID).
		Update("cancelled_at", nil).Error
}

// === Tier Limits ===

// GetTierLimits gets the stored limits for a subscription tier. If no row
// exists for the tier, it returns the built-in defaults: the Free defaults for
// models.TierFree and the Pro defaults for any other tier value.
func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*models.TierLimits, error) {
	var limits models.TierLimits
	err := r.db.Where("tier = ?", tier).First(&limits).Error
	if err == nil {
		return &limits, nil
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, err
	}

	if tier == models.TierFree {
		defaults := models.GetDefaultFreeLimits()
		return &defaults, nil
	}
	defaults := models.GetDefaultProLimits()
	return &defaults, nil
}

// GetAllTierLimits gets every stored tier-limits row.
func (r *SubscriptionRepository) GetAllTierLimits() ([]models.TierLimits, error) {
	var limits []models.TierLimits
	err := r.db.Find(&limits).Error
	return limits, err
}

// === Subscription Settings (Singleton) ===

// GetSettings gets the first subscription-settings row. If the table is
// empty, it returns an unsaved default with EnableLimitations set to false
// (limitations disabled).
func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, error) {
	var settings models.SubscriptionSettings
	err := r.db.First(&settings).Error
	if err == nil {
		return &settings, nil
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return &models.SubscriptionSettings{
			EnableLimitations: false,
		}, nil
	}
	return nil, err
}

// === Upgrade Triggers ===

// GetUpgradeTrigger gets an active upgrade trigger by key. It returns
// gorm.ErrRecordNotFound when no active trigger has that key.
func (r *SubscriptionRepository) GetUpgradeTrigger(key string) (*models.UpgradeTrigger, error) {
	var trigger models.UpgradeTrigger
	if err := r.db.Where("trigger_key = ? AND is_active = ?", key, true).First(&trigger).Error; err != nil {
		return nil, err
	}
	return &trigger, nil
}

// GetAllUpgradeTriggers gets all active upgrade triggers.
func (r *SubscriptionRepository) GetAllUpgradeTriggers() ([]models.UpgradeTrigger, error) {
	var triggers []models.UpgradeTrigger
	err := r.db.Where("is_active = ?", true).Find(&triggers).Error
	return triggers, err
}

// === Feature Benefits ===

// GetFeatureBenefits gets all active feature benefits ordered by display_order.
func (r *SubscriptionRepository) GetFeatureBenefits() ([]models.FeatureBenefit, error) {
	var benefits []models.FeatureBenefit
	err := r.db.Where("is_active = ?", true).Order("display_order").Find(&benefits).Error
	return benefits, err
}

// === Promotions ===

// GetActivePromotions gets the promotions for tier that are active and whose
// [start_date, end_date] window contains the current UTC time. Results are
// ordered by most recent start date first.
func (r *SubscriptionRepository) GetActivePromotions(tier models.SubscriptionTier) ([]models.Promotion, error) {
	now := time.Now().UTC()
	var promotions []models.Promotion
	err := r.db.Where("is_active = ? AND target_tier = ? AND start_date <= ? AND end_date >= ?", true, tier, now, now).
		Order("start_date DESC").
		Find(&promotions).Error
	return promotions, err
}

// GetPromotionByID gets an active promotion by its promotion_id. It returns
// gorm.ErrRecordNotFound when no active promotion has that ID.
func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.Promotion, error) {
	var promotion models.Promotion
	if err := r.db.Where("promotion_id = ? AND is_active = ?", promotionID, true).First(&promotion).Error; err != nil {
		return nil, err
	}
	return &promotion, nil
}

// === Stripe Lookups ===

// FindByStripeCustomerID finds a subscription by Stripe customer ID. It
// returns gorm.ErrRecordNotFound when nothing matches or customerID is empty.
func (r *SubscriptionRepository) FindByStripeCustomerID(customerID string) (*models.UserSubscription, error) {
	if customerID == "" {
		// An empty ID never identifies a customer. Without this guard it could
		// match a row whose column holds an empty string.
		return nil, gorm.ErrRecordNotFound
	}

	var sub models.UserSubscription
	if err := r.db.Where("stripe_customer_id = ?", customerID).First(&sub).Error; err != nil {
		return nil, err
	}
	return &sub, nil
}

// FindByStripeSubscriptionID finds a subscription by Stripe subscription ID.
// It returns gorm.ErrRecordNotFound when nothing matches or subscriptionID is
// empty.
func (r *SubscriptionRepository) FindByStripeSubscriptionID(subscriptionID string) (*models.UserSubscription, error) {
	if subscriptionID == "" {
		// An empty ID never identifies a subscription. Without this guard it
		// could match a row whose column holds an empty string.
		return nil, gorm.ErrRecordNotFound
	}

	var sub models.UserSubscription
	if err := r.db.Where("stripe_subscription_id = ?", subscriptionID).First(&sub).Error; err != nil {
		return nil, err
	}
	return &sub, nil
}

// UpdateStripeData updates all three Stripe fields (customer, subscription,
// price) in one statement. No error is returned when no row matches userID.
func (r *SubscriptionRepository) UpdateStripeData(userID uint, customerID, subscriptionID, priceID string) error {
	return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
		"stripe_customer_id":     customerID,
		"stripe_subscription_id": subscriptionID,
		"stripe_price_id":        priceID,
	}).Error
}

// ClearStripeData sets the Stripe subscription and price IDs to NULL. The
// customer ID is kept so the user retains billing-portal access.
func (r *SubscriptionRepository) ClearStripeData(userID uint) error {
	return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
		"stripe_subscription_id": nil,
		"stripe_price_id":        nil,
	}).Error
}

// === Trial Management ===

// SetTrialDates sets the trial start and end and marks the trial as used.
// No error is returned when no row matches userID.
func (r *SubscriptionRepository) SetTrialDates(userID uint, trialStart, trialEnd time.Time) error {
	return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
		"trial_start": trialStart,
		"trial_end":   trialEnd,
		"trial_used":  true,
	}).Error
}

// UpdateExpiresAt updates the expires_at field for a user's subscription.
// No error is returned when no row matches userID.
func (r *SubscriptionRepository) UpdateExpiresAt(userID uint, expiresAt time.Time) error {
	return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).
		Update("expires_at", expiresAt).Error
}