Files
honeyDueAPI/internal/repositories/subscription_repo.go
T
Trey t c77ff07ce9
Backend CI / Test (push) Has been cancelled
Backend CI / Contract Tests (push) Has been cancelled
Backend CI / Lint (push) Has been cancelled
Backend CI / Secret Scanning (push) Has been cancelled
Backend CI / Build (push) Has been cancelled
fix(security): remediate 2026-05-12 audit findings (Stages 2–5)
Remediation of the 2026-05-12/13 audits (78 findings + cluster gaps),
tracked in deploy-k3s/SECURITY.md, plus fixes from two independent
post-remediation reviews.

Auth & sessions:
- SHA-256 hashed auth-token storage (C1); prior-token cache eviction on
  re-login (MEDIUM-1)
- local Google JWKS verification, iss/aud/exp checks (C2/C3)
- constant-time login + generic errors (L1/LIVE-L11/LIVE-L13)
- per-account login lockout keyed on distinct source IPs (M5/MEDIUM-3)
- verified-email gating, login rate limiting (LIVE-L19, H1-H3)

IAP & webhooks:
- Apple/Google cross-account replay protection (C5/C6/C10/C13, H5/H6)
- migrations 000003-000006 (token hashing, IAP replay, audit_log +
  webhook_event_log table creation, append-only audit log)

Authorization & races:
- file-ownership owner-OR-member fix (C7), atomic share-code join
  (C9/H9), device-token reassignment (C8/LOW-3)

Secrets & deploy:
- secrets file-mounted at /etc/honeydue/secrets, not env (F8); Redis
  password out of the ConfigMap (HIGH-1); B2 keys reconciled
- digest-pinned images, admin ingress hardening, CSP/HSTS, /metrics
  lockdown; kubeconfig 0600, etcd secrets-encryption, fail2ban +
  unattended-upgrades at provision; secret-rotation runbook

Build, vet, and the full test suite (incl. -race) pass; the goose
migration chain is verified against PostgreSQL 16.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-16 22:28:33 -05:00

360 lines
12 KiB
Go

package repositories
import (
"context"
"errors"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/treytartt/honeydue-api/internal/models"
)
// SubscriptionRepository is the data-access layer for user subscriptions and
// the subscription catalogue (tier limits, settings, upgrade triggers, feature
// benefits and promotions). Every query is issued through the wrapped GORM
// handle.
type SubscriptionRepository struct {
	db *gorm.DB // GORM handle; WithContext derives a context-scoped copy
}
// NewSubscriptionRepository returns a SubscriptionRepository that issues its
// queries through db.
func NewSubscriptionRepository(db *gorm.DB) *SubscriptionRepository {
	repo := &SubscriptionRepository{}
	repo.db = db
	return repo
}
// === User Subscription ===

// FindByUserID loads the subscription row owned by userID. If no row exists the
// error from GORM's First (gorm.ErrRecordNotFound) is returned unchanged.
func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscription, error) {
	found := new(models.UserSubscription)
	if err := r.db.Where("user_id = ?", userID).First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetOrCreate returns the subscription row for userID. If none exists, it
// creates one with free-tier defaults (Tier=TierFree, AutoRenew=true).
//
// Concurrency: a plain "SELECT, then INSERT if missing" inside a transaction
// does not stop two concurrent first requests from both seeing "not found".
// Under READ COMMITTED both would INSERT, so one fails with a duplicate-key
// error (or a duplicate row is created if user_id is not unique). The insert
// therefore uses ON CONFLICT DO NOTHING. If it inserted nothing, another
// request won the race, and that request's row is re-read and returned.
// NOTE(review): this relies on a unique constraint on user_id, which is not
// visible from here; without one, behaviour is no worse than before.
func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) {
	var sub models.UserSubscription
	err := r.db.Transaction(func(tx *gorm.DB) error {
		err := tx.Where("user_id = ?", userID).First(&sub).Error
		if err == nil {
			return nil // Found existing subscription
		}
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return err // Unexpected error
		}

		// Record not found -- create with free tier defaults, tolerating a
		// concurrent insert of the same user's row.
		sub = models.UserSubscription{
			UserID:    userID,
			Tier:      models.TierFree,
			AutoRenew: true,
		}
		res := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&sub)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected > 0 {
			return nil // We created the row.
		}

		// Nothing inserted: a concurrent request created the row first. Reset
		// the struct (it holds our unsaved defaults) and load the winner's row.
		sub = models.UserSubscription{}
		return tx.Where("user_id = ?", userID).First(&sub).Error
	})
	if err != nil {
		return nil, err
	}
	return &sub, nil
}
// Update persists all fields of sub with GORM's Save. The User association is
// omitted from the save.
func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error {
	withoutUser := r.db.Omit("User")
	return withoutUser.Save(sub).Error
}
// UpgradeToPro moves userID's subscription to the Pro tier.
//
// Inside one transaction the existing row is selected with a FOR UPDATE lock,
// so concurrent subscription mutations serialise on it. The row is then updated
// with:
//   - tier = Pro
//   - subscribed_at = now (UTC)
//   - expires_at = expiresAt
//   - cancelled_at cleared
//   - platform recorded
//   - auto_renew = true
//
// If the user has no subscription row, the lookup error is returned.
func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time, platform string) error {
	upgrade := func(tx *gorm.DB) error {
		var locked models.UserSubscription
		lockErr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
			Where("user_id = ?", userID).
			First(&locked).Error
		if lockErr != nil {
			return lockErr
		}

		// Timestamp is taken only after the lock is held.
		subscribedAt := time.Now().UTC()
		fields := map[string]interface{}{
			"tier":          models.TierPro,
			"subscribed_at": subscribedAt,
			"expires_at":    expiresAt,
			"cancelled_at":  nil,
			"platform":      platform,
			"auto_renew":    true,
		}
		return tx.Model(&locked).Updates(fields).Error
	}
	return r.db.Transaction(upgrade)
}
// DowngradeToFree moves userID's subscription back to the Free tier.
//
// Inside one transaction the row is selected with a FOR UPDATE lock, then
// updated with tier = Free, cancelled_at = now (UTC) and auto_renew = false.
// Other columns (e.g. expires_at) are left as they are. If the user has no
// subscription row, the lookup error is returned.
func (r *SubscriptionRepository) DowngradeToFree(userID uint) error {
	downgrade := func(tx *gorm.DB) error {
		var locked models.UserSubscription
		lockErr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
			Where("user_id = ?", userID).
			First(&locked).Error
		if lockErr != nil {
			return lockErr
		}

		cancelledAt := time.Now().UTC()
		fields := map[string]interface{}{
			"tier":         models.TierFree,
			"cancelled_at": cancelledAt,
			"auto_renew":   false,
		}
		return tx.Model(&locked).Updates(fields).Error
	}
	return r.db.Transaction(downgrade)
}
// SetAutoRenew writes autoRenew to the auto_renew column of userID's
// subscription row.
// NOTE(review): if no row matches, no error is returned (RowsAffected is not
// checked).
func (r *SubscriptionRepository) SetAutoRenew(userID uint, autoRenew bool) error {
	owned := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return owned.Update("auto_renew", autoRenew).Error
}
// UpdateReceiptData stores receiptData in the apple_receipt_data column of
// userID's subscription row.
func (r *SubscriptionRepository) UpdateReceiptData(userID uint, receiptData string) error {
	owned := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return owned.Update("apple_receipt_data", receiptData).Error
}
// UpdatePurchaseToken stores token in the google_purchase_token column of
// userID's subscription row.
func (r *SubscriptionRepository) UpdatePurchaseToken(userID uint, token string) error {
	owned := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return owned.Update("google_purchase_token", token).Error
}
// FindByAppleReceiptContains finds a subscription whose stored Apple receipt
// data contains transactionID. Used by webhooks to find the user associated
// with a transaction.
//
// An empty transactionID is rejected with gorm.ErrRecordNotFound. Otherwise
// the LIKE pattern would presumably degenerate to "%%", which matches every
// non-NULL receipt and would return an arbitrary user's subscription.
//
// PERFORMANCE NOTE: This uses a LIKE '%...%' scan on apple_receipt_data, which
// cannot use a B-tree index and results in a full table scan. Prefer
// FindByAppleOriginalTransactionID, which does an exact match on the indexed
// apple_original_transaction_id column.
func (r *SubscriptionRepository) FindByAppleReceiptContains(transactionID string) (*models.UserSubscription, error) {
	if transactionID == "" {
		return nil, gorm.ErrRecordNotFound
	}
	var sub models.UserSubscription
	// Escape LIKE wildcards in the transaction ID to prevent wildcard injection.
	escaped := escapeLikeWildcards(transactionID)
	if err := r.db.Where("apple_receipt_data LIKE ?", "%"+escaped+"%").First(&sub).Error; err != nil {
		return nil, err
	}
	return &sub, nil
}
// FindByAppleOriginalTransactionID finds a subscription by the Apple original
// transaction ID (audit C5/C13). It does an exact match on an indexed column
// and replaces the LIKE scan in FindByAppleReceiptContains for both replay
// detection and webhook user lookup.
//
// An empty ID never identifies a transaction, so it is rejected with
// gorm.ErrRecordNotFound. This avoids matching any row that stored an empty
// string via UpdateAppleOriginalTransactionID, which would be a cross-account
// lookup.
func (r *SubscriptionRepository) FindByAppleOriginalTransactionID(originalTransactionID string) (*models.UserSubscription, error) {
	if originalTransactionID == "" {
		return nil, gorm.ErrRecordNotFound
	}
	var sub models.UserSubscription
	if err := r.db.Where("apple_original_transaction_id = ?", originalTransactionID).First(&sub).Error; err != nil {
		return nil, err
	}
	return &sub, nil
}
// UpdateAppleOriginalTransactionID records originalTransactionID on userID's
// subscription row. Per the audit-C5 design, a partial unique index on the
// column allows only one account per transaction, so a second account claiming
// the same ID gets the database's constraint error from this call.
func (r *SubscriptionRepository) UpdateAppleOriginalTransactionID(userID uint, originalTransactionID string) error {
	owned := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return owned.Update("apple_original_transaction_id", originalTransactionID).Error
}
// FindByGoogleToken finds a subscription by Google purchase token.
// Used by webhooks to find the user associated with a purchase.
//
// An empty token never identifies a purchase, so it is rejected with
// gorm.ErrRecordNotFound. This avoids matching a row whose
// google_purchase_token holds an empty string (UpdatePurchaseToken accepts
// any string), which would return another user's subscription.
func (r *SubscriptionRepository) FindByGoogleToken(purchaseToken string) (*models.UserSubscription, error) {
	if purchaseToken == "" {
		return nil, gorm.ErrRecordNotFound
	}
	var sub models.UserSubscription
	if err := r.db.Where("google_purchase_token = ?", purchaseToken).First(&sub).Error; err != nil {
		return nil, err
	}
	return &sub, nil
}
// SetCancelledAt writes cancelledAt to the cancelled_at column of userID's
// subscription row.
func (r *SubscriptionRepository) SetCancelledAt(userID uint, cancelledAt time.Time) error {
	owned := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return owned.Update("cancelled_at", cancelledAt).Error
}
// ClearCancelledAt sets cancelled_at to NULL on userID's subscription row,
// e.g. when the user resubscribes.
func (r *SubscriptionRepository) ClearCancelledAt(userID uint) error {
	owned := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return owned.Update("cancelled_at", nil).Error
}
// === Tier Limits ===

// GetTierLimits returns the configured limits row for tier. If no row exists,
// it falls back to built-in defaults: the free defaults for TierFree and the
// Pro defaults for any other tier. Other database errors are returned as-is.
func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*models.TierLimits, error) {
	var limits models.TierLimits
	err := r.db.Where("tier = ?", tier).First(&limits).Error
	switch {
	case err == nil:
		return &limits, nil
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return nil, err
	}

	// No configured row: use the compiled-in defaults.
	var fallback models.TierLimits
	if tier == models.TierFree {
		fallback = models.GetDefaultFreeLimits()
	} else {
		fallback = models.GetDefaultProLimits()
	}
	return &fallback, nil
}
// GetAllTierLimits returns every tier-limits row. The slice is returned
// alongside any query error, exactly as GORM populated it.
func (r *SubscriptionRepository) GetAllTierLimits() ([]models.TierLimits, error) {
	var all []models.TierLimits
	result := r.db.Find(&all)
	return all, result.Error
}
// === Subscription Settings (Singleton) ===

// GetSettings returns the first subscription-settings row. If the table is
// empty, it returns a default value with EnableLimitations=false rather than
// an error. Other database errors are returned as-is.
func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, error) {
	settings := &models.SubscriptionSettings{}
	err := r.db.First(settings).Error
	if err == nil {
		return settings, nil
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, err
	}
	// Nothing configured yet: limitations stay disabled.
	return &models.SubscriptionSettings{
		EnableLimitations: false,
	}, nil
}
// === Upgrade Triggers ===

// GetUpgradeTrigger returns the active trigger whose trigger_key equals key.
// Inactive triggers are filtered out, so they surface as GORM's not-found
// error.
func (r *SubscriptionRepository) GetUpgradeTrigger(key string) (*models.UpgradeTrigger, error) {
	trigger := &models.UpgradeTrigger{}
	err := r.db.Where("trigger_key = ? AND is_active = ?", key, true).First(trigger).Error
	if err != nil {
		return nil, err
	}
	return trigger, nil
}
// GetAllUpgradeTriggers returns every trigger with is_active = true, together
// with any query error.
func (r *SubscriptionRepository) GetAllUpgradeTriggers() ([]models.UpgradeTrigger, error) {
	var active []models.UpgradeTrigger
	result := r.db.Where("is_active = ?", true).Find(&active)
	return active, result.Error
}
// === Feature Benefits ===

// GetFeatureBenefits returns the active feature benefits ordered by
// display_order, together with any query error.
func (r *SubscriptionRepository) GetFeatureBenefits() ([]models.FeatureBenefit, error) {
	var benefits []models.FeatureBenefit
	ordered := r.db.Where("is_active = ?", true).Order("display_order")
	result := ordered.Find(&benefits)
	return benefits, result.Error
}
// === Promotions ===

// GetActivePromotions returns the promotions that meet all of these
// conditions at the current UTC instant:
//   - active
//   - targeting tier
//   - within their start/end window (inclusive)
//
// Results are ordered newest start_date first.
func (r *SubscriptionRepository) GetActivePromotions(tier models.SubscriptionTier) ([]models.Promotion, error) {
	var promotions []models.Promotion
	// A single instant is used for both window bounds.
	at := time.Now().UTC()
	query := r.db.
		Where("is_active = ? AND target_tier = ? AND start_date <= ? AND end_date >= ?", true, tier, at, at).
		Order("start_date DESC")
	result := query.Find(&promotions)
	return promotions, result.Error
}
// GetPromotionByID returns the active promotion whose promotion_id equals
// promotionID. Inactive promotions surface as GORM's not-found error.
func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.Promotion, error) {
	promotion := &models.Promotion{}
	err := r.db.Where("promotion_id = ? AND is_active = ?", promotionID, true).First(promotion).Error
	if err != nil {
		return nil, err
	}
	return promotion, nil
}
// === Stripe Lookups ===

// FindByStripeCustomerID finds a subscription by Stripe customer ID.
//
// An empty customerID never identifies a customer, so it is rejected with
// gorm.ErrRecordNotFound. This avoids matching a row that stored an empty
// string via UpdateStripeData, which would be a cross-account lookup.
func (r *SubscriptionRepository) FindByStripeCustomerID(customerID string) (*models.UserSubscription, error) {
	if customerID == "" {
		return nil, gorm.ErrRecordNotFound
	}
	var sub models.UserSubscription
	if err := r.db.Where("stripe_customer_id = ?", customerID).First(&sub).Error; err != nil {
		return nil, err
	}
	return &sub, nil
}
// FindByStripeSubscriptionID finds a subscription by Stripe subscription ID.
//
// An empty subscriptionID never identifies a subscription, so it is rejected
// with gorm.ErrRecordNotFound. This avoids matching a row that stored an empty
// string via UpdateStripeData, which would be a cross-account lookup.
func (r *SubscriptionRepository) FindByStripeSubscriptionID(subscriptionID string) (*models.UserSubscription, error) {
	if subscriptionID == "" {
		return nil, gorm.ErrRecordNotFound
	}
	var sub models.UserSubscription
	if err := r.db.Where("stripe_subscription_id = ?", subscriptionID).First(&sub).Error; err != nil {
		return nil, err
	}
	return &sub, nil
}
// UpdateStripeData writes the Stripe customer, subscription and price IDs to
// userID's subscription row in a single UPDATE.
func (r *SubscriptionRepository) UpdateStripeData(userID uint, customerID, subscriptionID, priceID string) error {
	fields := map[string]interface{}{
		"stripe_customer_id":     customerID,
		"stripe_subscription_id": subscriptionID,
		"stripe_price_id":        priceID,
	}
	owned := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return owned.Updates(fields).Error
}
// ClearStripeData sets stripe_subscription_id and stripe_price_id to NULL on
// userID's subscription row. stripe_customer_id is kept so the customer can
// still reach the billing portal.
func (r *SubscriptionRepository) ClearStripeData(userID uint) error {
	cleared := map[string]interface{}{
		"stripe_subscription_id": nil,
		"stripe_price_id":        nil,
	}
	owned := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return owned.Updates(cleared).Error
}
// === Trial Management ===

// SetTrialDates records the trial window (trialStart..trialEnd) on userID's
// subscription row and sets trial_used = true.
func (r *SubscriptionRepository) SetTrialDates(userID uint, trialStart, trialEnd time.Time) error {
	fields := map[string]interface{}{
		"trial_start": trialStart,
		"trial_end":   trialEnd,
		"trial_used":  true,
	}
	owned := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return owned.Updates(fields).Error
}
// UpdateExpiresAt writes expiresAt to the expires_at column of userID's
// subscription row.
func (r *SubscriptionRepository) UpdateExpiresAt(userID uint, expiresAt time.Time) error {
	owned := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return owned.Update("expires_at", expiresAt).Error
}
// WithContext returns a new repository whose GORM handle carries ctx; the
// receiver is left untouched. SQL issued through the copy respects ctx's
// cancellation and deadline, and is attached to ctx's trace span when otelgorm
// is registered.
func (r *SubscriptionRepository) WithContext(ctx context.Context) *SubscriptionRepository {
	scoped := r.db.WithContext(ctx)
	return &SubscriptionRepository{db: scoped}
}