Security: - Replace all binding: tags with validate: + c.Validate() in admin handlers - Add rate limiting to auth endpoints (login, register, password reset) - Add security headers (HSTS, XSS protection, nosniff, frame options) - Wire Google Pub/Sub token verification into webhook handler - Replace ParseUnverified with proper OIDC/JWKS key verification - Verify inner Apple JWS signatures in webhook handler - Add io.LimitReader (1MB) to all webhook body reads - Add ownership verification to file deletion - Move hardcoded admin credentials to env vars - Add uniqueIndex to User.Email - Hide ConfirmationCode from JSON serialization - Mask confirmation codes in admin responses - Use http.DetectContentType for upload validation - Fix path traversal in storage service - Replace os.Getenv with Viper in stripe service - Sanitize Redis URLs before logging - Separate DEBUG_FIXED_CODES from DEBUG flag - Reject weak SECRET_KEY in production - Add host check on /_next/* proxy routes - Use explicit localhost CORS origins in debug mode - Replace err.Error() with generic messages in all admin error responses Critical fixes: - Rewrite FCM to HTTP v1 API with OAuth 2.0 service account auth - Fix user_customuser -> auth_user table names in raw SQL - Fix dashboard verified query to use UserProfile model - Add escapeLikeWildcards() to prevent SQL wildcard injection Bug fixes: - Add bounds checks for days/expiring_soon query params (1-3650) - Add receipt_data/transaction_id empty-check to RestoreSubscription - Change Active bool -> *bool in device handler - Check all unchecked GORM/FindByIDWithProfile errors - Add validation for notification hour fields (0-23) - Add max=10000 validation on task description updates Transactions & data integrity: - Wrap registration flow in transaction - Wrap QuickComplete in transaction - Move image creation inside completion transaction - Wrap SetSpecialties in transaction - Wrap GetOrCreateToken in transaction - Wrap completion+image deletion in transaction Performance: - Batch completion summaries (2 queries vs 2N) - Reuse single http.Client in IAP validation - Cache dashboard counts (30s TTL) - Batch COUNT queries in admin user list - Add Limit(500) to document queries - Add reminder_stage+due_date filters to reminder queries - Parse AllowedTypes once at init - In-memory user cache in auth middleware (30s TTL) - Timezone change detection cache - Optimize P95 with per-endpoint sorted buffers - Replace crypto/md5 with hash/fnv for ETags Code quality: - Add sync.Once to all monitoring Stop()/Close() methods - Replace 8 fmt.Printf with zerolog in auth service - Log previously discarded errors - Standardize delete response shapes - Route hardcoded English through i18n - Remove FileURL from DocumentResponse (keep MediaURL only) - Thread user timezone through kanban board responses - Initialize empty slices to prevent null JSON - Extract shared field map for task Update/UpdateTx - Delete unused SoftDeleteModel, min(), formatCron, legacy handlers Worker & jobs: - Wire Asynq email infrastructure into worker - Register HandleReminderLogCleanup with daily 3AM cron - Use per-user timezone in HandleSmartReminder - Replace direct DB queries with repository calls - Delete legacy reminder handlers (~200 lines) - Delete unused task type constants Dependencies: - Replace archived jung-kurt/gofpdf with go-pdf/fpdf - Replace unmaintained gomail.v2 with wneessen/go-mail - Add TODO for Echo jwt v3 transitive dep removal Test infrastructure: - Fix MakeRequest/SeedLookupData error handling - Replace os.Exit(0) with t.Skip() in scope/consistency tests - Add 11 new FCM v1 tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
330 lines
11 KiB
Go
330 lines
11 KiB
Go
package repositories
|
|
|
|
import (
|
|
"errors"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
|
|
"github.com/treytartt/honeydue-api/internal/models"
|
|
)
|
|
|
|
// SubscriptionRepository handles database operations for subscriptions
|
|
type SubscriptionRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewSubscriptionRepository creates a new subscription repository
|
|
func NewSubscriptionRepository(db *gorm.DB) *SubscriptionRepository {
|
|
return &SubscriptionRepository{db: db}
|
|
}
|
|
|
|
// === User Subscription ===
|
|
|
|
// FindByUserID finds a subscription by user ID
|
|
func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscription, error) {
|
|
var sub models.UserSubscription
|
|
err := r.db.Where("user_id = ?", userID).First(&sub).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &sub, nil
|
|
}
|
|
|
|
// GetOrCreate gets or creates a subscription for a user (defaults to free tier).
|
|
// Uses a transaction to avoid TOCTOU race conditions on concurrent requests.
|
|
func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) {
|
|
var sub models.UserSubscription
|
|
|
|
err := r.db.Transaction(func(tx *gorm.DB) error {
|
|
err := tx.Where("user_id = ?", userID).First(&sub).Error
|
|
if err == nil {
|
|
return nil // Found existing subscription
|
|
}
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err // Unexpected error
|
|
}
|
|
|
|
// Record not found -- create with free tier defaults
|
|
sub = models.UserSubscription{
|
|
UserID: userID,
|
|
Tier: models.TierFree,
|
|
AutoRenew: true,
|
|
}
|
|
return tx.Create(&sub).Error
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &sub, nil
|
|
}
|
|
|
|
// Update updates a subscription
|
|
func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error {
|
|
return r.db.Omit("User").Save(sub).Error
|
|
}
|
|
|
|
// UpgradeToPro upgrades a user to Pro tier using a transaction with row locking
|
|
// to prevent concurrent subscription mutations from corrupting state.
|
|
func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time, platform string) error {
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|
// Lock the row for update
|
|
var sub models.UserSubscription
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("user_id = ?", userID).First(&sub).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
return tx.Model(&sub).Updates(map[string]interface{}{
|
|
"tier": models.TierPro,
|
|
"subscribed_at": now,
|
|
"expires_at": expiresAt,
|
|
"cancelled_at": nil,
|
|
"platform": platform,
|
|
"auto_renew": true,
|
|
}).Error
|
|
})
|
|
}
|
|
|
|
// DowngradeToFree downgrades a user to Free tier using a transaction with row locking
|
|
// to prevent concurrent subscription mutations from corrupting state.
|
|
func (r *SubscriptionRepository) DowngradeToFree(userID uint) error {
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|
// Lock the row for update
|
|
var sub models.UserSubscription
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("user_id = ?", userID).First(&sub).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
return tx.Model(&sub).Updates(map[string]interface{}{
|
|
"tier": models.TierFree,
|
|
"cancelled_at": now,
|
|
"auto_renew": false,
|
|
}).Error
|
|
})
|
|
}
|
|
|
|
// SetAutoRenew sets the auto-renew flag
|
|
func (r *SubscriptionRepository) SetAutoRenew(userID uint, autoRenew bool) error {
|
|
return r.db.Model(&models.UserSubscription{}).
|
|
Where("user_id = ?", userID).
|
|
Update("auto_renew", autoRenew).Error
|
|
}
|
|
|
|
// UpdateReceiptData updates the Apple receipt data
|
|
func (r *SubscriptionRepository) UpdateReceiptData(userID uint, receiptData string) error {
|
|
return r.db.Model(&models.UserSubscription{}).
|
|
Where("user_id = ?", userID).
|
|
Update("apple_receipt_data", receiptData).Error
|
|
}
|
|
|
|
// UpdatePurchaseToken updates the Google purchase token
|
|
func (r *SubscriptionRepository) UpdatePurchaseToken(userID uint, token string) error {
|
|
return r.db.Model(&models.UserSubscription{}).
|
|
Where("user_id = ?", userID).
|
|
Update("google_purchase_token", token).Error
|
|
}
|
|
|
|
// FindByAppleReceiptContains finds a subscription by Apple transaction ID.
|
|
// Used by webhooks to find the user associated with a transaction.
|
|
//
|
|
// PERFORMANCE NOTE: This uses a LIKE '%...%' scan on apple_receipt_data which
|
|
// cannot use a B-tree index and results in a full table scan. For better
|
|
// performance at scale, add a dedicated indexed column:
|
|
//
|
|
// AppleTransactionID *string `gorm:"column:apple_transaction_id;size:255;index"`
|
|
//
|
|
// Then look up by exact match: WHERE apple_transaction_id = ?
|
|
func (r *SubscriptionRepository) FindByAppleReceiptContains(transactionID string) (*models.UserSubscription, error) {
|
|
var sub models.UserSubscription
|
|
// Escape LIKE wildcards in the transaction ID to prevent wildcard injection
|
|
escaped := escapeLikeWildcards(transactionID)
|
|
err := r.db.Where("apple_receipt_data LIKE ?", "%"+escaped+"%").First(&sub).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &sub, nil
|
|
}
|
|
|
|
// FindByGoogleToken finds a subscription by Google purchase token
|
|
// Used by webhooks to find the user associated with a purchase
|
|
func (r *SubscriptionRepository) FindByGoogleToken(purchaseToken string) (*models.UserSubscription, error) {
|
|
var sub models.UserSubscription
|
|
err := r.db.Where("google_purchase_token = ?", purchaseToken).First(&sub).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &sub, nil
|
|
}
|
|
|
|
// SetCancelledAt sets the cancellation timestamp
|
|
func (r *SubscriptionRepository) SetCancelledAt(userID uint, cancelledAt time.Time) error {
|
|
return r.db.Model(&models.UserSubscription{}).
|
|
Where("user_id = ?", userID).
|
|
Update("cancelled_at", cancelledAt).Error
|
|
}
|
|
|
|
// ClearCancelledAt clears the cancellation timestamp (user resubscribed)
|
|
func (r *SubscriptionRepository) ClearCancelledAt(userID uint) error {
|
|
return r.db.Model(&models.UserSubscription{}).
|
|
Where("user_id = ?", userID).
|
|
Update("cancelled_at", nil).Error
|
|
}
|
|
|
|
// === Tier Limits ===
|
|
|
|
// GetTierLimits gets the limits for a subscription tier
|
|
func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*models.TierLimits, error) {
|
|
var limits models.TierLimits
|
|
err := r.db.Where("tier = ?", tier).First(&limits).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Return defaults
|
|
if tier == models.TierFree {
|
|
defaults := models.GetDefaultFreeLimits()
|
|
return &defaults, nil
|
|
}
|
|
defaults := models.GetDefaultProLimits()
|
|
return &defaults, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return &limits, nil
|
|
}
|
|
|
|
// GetAllTierLimits gets all tier limits
|
|
func (r *SubscriptionRepository) GetAllTierLimits() ([]models.TierLimits, error) {
|
|
var limits []models.TierLimits
|
|
err := r.db.Find(&limits).Error
|
|
return limits, err
|
|
}
|
|
|
|
// === Subscription Settings (Singleton) ===
|
|
|
|
// GetSettings gets the subscription settings
|
|
func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, error) {
|
|
var settings models.SubscriptionSettings
|
|
err := r.db.First(&settings).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Return default settings (limitations disabled)
|
|
return &models.SubscriptionSettings{
|
|
EnableLimitations: false,
|
|
}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return &settings, nil
|
|
}
|
|
|
|
// === Upgrade Triggers ===
|
|
|
|
// GetUpgradeTrigger gets an upgrade trigger by key
|
|
func (r *SubscriptionRepository) GetUpgradeTrigger(key string) (*models.UpgradeTrigger, error) {
|
|
var trigger models.UpgradeTrigger
|
|
err := r.db.Where("trigger_key = ? AND is_active = ?", key, true).First(&trigger).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &trigger, nil
|
|
}
|
|
|
|
// GetAllUpgradeTriggers gets all active upgrade triggers
|
|
func (r *SubscriptionRepository) GetAllUpgradeTriggers() ([]models.UpgradeTrigger, error) {
|
|
var triggers []models.UpgradeTrigger
|
|
err := r.db.Where("is_active = ?", true).Find(&triggers).Error
|
|
return triggers, err
|
|
}
|
|
|
|
// === Feature Benefits ===
|
|
|
|
// GetFeatureBenefits gets all active feature benefits
|
|
func (r *SubscriptionRepository) GetFeatureBenefits() ([]models.FeatureBenefit, error) {
|
|
var benefits []models.FeatureBenefit
|
|
err := r.db.Where("is_active = ?", true).Order("display_order").Find(&benefits).Error
|
|
return benefits, err
|
|
}
|
|
|
|
// === Promotions ===
|
|
|
|
// GetActivePromotions gets all currently active promotions for a tier
|
|
func (r *SubscriptionRepository) GetActivePromotions(tier models.SubscriptionTier) ([]models.Promotion, error) {
|
|
now := time.Now().UTC()
|
|
var promotions []models.Promotion
|
|
err := r.db.Where("is_active = ? AND target_tier = ? AND start_date <= ? AND end_date >= ?",
|
|
true, tier, now, now).
|
|
Order("start_date DESC").
|
|
Find(&promotions).Error
|
|
return promotions, err
|
|
}
|
|
|
|
// GetPromotionByID gets a promotion by ID
|
|
func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.Promotion, error) {
|
|
var promotion models.Promotion
|
|
err := r.db.Where("promotion_id = ? AND is_active = ?", promotionID, true).First(&promotion).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &promotion, nil
|
|
}
|
|
|
|
// === Stripe Lookups ===
|
|
|
|
// FindByStripeCustomerID finds a subscription by Stripe customer ID
|
|
func (r *SubscriptionRepository) FindByStripeCustomerID(customerID string) (*models.UserSubscription, error) {
|
|
var sub models.UserSubscription
|
|
err := r.db.Where("stripe_customer_id = ?", customerID).First(&sub).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &sub, nil
|
|
}
|
|
|
|
// FindByStripeSubscriptionID finds a subscription by Stripe subscription ID
|
|
func (r *SubscriptionRepository) FindByStripeSubscriptionID(subscriptionID string) (*models.UserSubscription, error) {
|
|
var sub models.UserSubscription
|
|
err := r.db.Where("stripe_subscription_id = ?", subscriptionID).First(&sub).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &sub, nil
|
|
}
|
|
|
|
// UpdateStripeData updates all three Stripe fields (customer, subscription, price) in one call
|
|
func (r *SubscriptionRepository) UpdateStripeData(userID uint, customerID, subscriptionID, priceID string) error {
|
|
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
|
"stripe_customer_id": customerID,
|
|
"stripe_subscription_id": subscriptionID,
|
|
"stripe_price_id": priceID,
|
|
}).Error
|
|
}
|
|
|
|
// ClearStripeData clears the Stripe subscription and price IDs (customer ID stays for portal access)
|
|
func (r *SubscriptionRepository) ClearStripeData(userID uint) error {
|
|
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
|
"stripe_subscription_id": nil,
|
|
"stripe_price_id": nil,
|
|
}).Error
|
|
}
|
|
|
|
// === Trial Management ===
|
|
|
|
// SetTrialDates sets the trial start, end, and marks trial as used
|
|
func (r *SubscriptionRepository) SetTrialDates(userID uint, trialStart, trialEnd time.Time) error {
|
|
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
|
"trial_start": trialStart,
|
|
"trial_end": trialEnd,
|
|
"trial_used": true,
|
|
}).Error
|
|
}
|
|
|
|
// UpdateExpiresAt updates the expires_at field for a user's subscription
|
|
func (r *SubscriptionRepository) UpdateExpiresAt(userID uint, expiresAt time.Time) error {
|
|
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).
|
|
Update("expires_at", expiresAt).Error
|
|
}
|