package repositories import ( "time" "gorm.io/gorm" "github.com/treytartt/casera-api/internal/models" ) // SubscriptionRepository handles database operations for subscriptions type SubscriptionRepository struct { db *gorm.DB } // NewSubscriptionRepository creates a new subscription repository func NewSubscriptionRepository(db *gorm.DB) *SubscriptionRepository { return &SubscriptionRepository{db: db} } // === User Subscription === // FindByUserID finds a subscription by user ID func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscription, error) { var sub models.UserSubscription err := r.db.Where("user_id = ?", userID).First(&sub).Error if err != nil { return nil, err } return &sub, nil } // GetOrCreate gets or creates a subscription for a user (defaults to free tier) func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) { sub, err := r.FindByUserID(userID) if err == nil { return sub, nil } if err == gorm.ErrRecordNotFound { sub = &models.UserSubscription{ UserID: userID, Tier: models.TierFree, AutoRenew: true, } if err := r.db.Create(sub).Error; err != nil { return nil, err } return sub, nil } return nil, err } // Update updates a subscription func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error { return r.db.Save(sub).Error } // UpgradeToPro upgrades a user to Pro tier func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time, platform string) error { now := time.Now().UTC() return r.db.Model(&models.UserSubscription{}). Where("user_id = ?", userID). Updates(map[string]interface{}{ "tier": models.TierPro, "subscribed_at": now, "expires_at": expiresAt, "cancelled_at": nil, "platform": platform, "auto_renew": true, }).Error } // DowngradeToFree downgrades a user to Free tier func (r *SubscriptionRepository) DowngradeToFree(userID uint) error { now := time.Now().UTC() return r.db.Model(&models.UserSubscription{}). Where("user_id = ?", userID). Updates(map[string]interface{}{ "tier": models.TierFree, "cancelled_at": now, "auto_renew": false, }).Error } // SetAutoRenew sets the auto-renew flag func (r *SubscriptionRepository) SetAutoRenew(userID uint, autoRenew bool) error { return r.db.Model(&models.UserSubscription{}). Where("user_id = ?", userID). Update("auto_renew", autoRenew).Error } // UpdateReceiptData updates the Apple receipt data func (r *SubscriptionRepository) UpdateReceiptData(userID uint, receiptData string) error { return r.db.Model(&models.UserSubscription{}). Where("user_id = ?", userID). Update("apple_receipt_data", receiptData).Error } // UpdatePurchaseToken updates the Google purchase token func (r *SubscriptionRepository) UpdatePurchaseToken(userID uint, token string) error { return r.db.Model(&models.UserSubscription{}). Where("user_id = ?", userID). Update("google_purchase_token", token).Error } // === Tier Limits === // GetTierLimits gets the limits for a subscription tier func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*models.TierLimits, error) { var limits models.TierLimits err := r.db.Where("tier = ?", tier).First(&limits).Error if err != nil { if err == gorm.ErrRecordNotFound { // Return defaults if tier == models.TierFree { defaults := models.GetDefaultFreeLimits() return &defaults, nil } defaults := models.GetDefaultProLimits() return &defaults, nil } return nil, err } return &limits, nil } // GetAllTierLimits gets all tier limits func (r *SubscriptionRepository) GetAllTierLimits() ([]models.TierLimits, error) { var limits []models.TierLimits err := r.db.Find(&limits).Error return limits, err } // === Subscription Settings (Singleton) === // GetSettings gets the subscription settings func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, error) { var settings models.SubscriptionSettings err := r.db.First(&settings).Error if err != nil { if err == gorm.ErrRecordNotFound { // Return default settings (limitations disabled) return &models.SubscriptionSettings{ EnableLimitations: false, }, nil } return nil, err } return &settings, nil } // === Upgrade Triggers === // GetUpgradeTrigger gets an upgrade trigger by key func (r *SubscriptionRepository) GetUpgradeTrigger(key string) (*models.UpgradeTrigger, error) { var trigger models.UpgradeTrigger err := r.db.Where("trigger_key = ? AND is_active = ?", key, true).First(&trigger).Error if err != nil { return nil, err } return &trigger, nil } // GetAllUpgradeTriggers gets all active upgrade triggers func (r *SubscriptionRepository) GetAllUpgradeTriggers() ([]models.UpgradeTrigger, error) { var triggers []models.UpgradeTrigger err := r.db.Where("is_active = ?", true).Find(&triggers).Error return triggers, err } // === Feature Benefits === // GetFeatureBenefits gets all active feature benefits func (r *SubscriptionRepository) GetFeatureBenefits() ([]models.FeatureBenefit, error) { var benefits []models.FeatureBenefit err := r.db.Where("is_active = ?", true).Order("display_order").Find(&benefits).Error return benefits, err } // === Promotions === // GetActivePromotions gets all currently active promotions for a tier func (r *SubscriptionRepository) GetActivePromotions(tier models.SubscriptionTier) ([]models.Promotion, error) { now := time.Now().UTC() var promotions []models.Promotion err := r.db.Where("is_active = ? AND target_tier = ? AND start_date <= ? AND end_date >= ?", true, tier, now, now). Order("start_date DESC"). Find(&promotions).Error return promotions, err } // GetPromotionByID gets a promotion by ID func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.Promotion, error) { var promotion models.Promotion err := r.db.Where("promotion_id = ? AND is_active = ?", promotionID, true).First(&promotion).Error if err != nil { return nil, err } return &promotion, nil }