Files
honeyDueAPI/internal/repositories/subscription_repo.go
T
Trey t bc3da007db
Backend CI / Test (push) Has been cancelled
Backend CI / Contract Tests (push) Has been cancelled
Backend CI / Build (push) Has been cancelled
Backend CI / Lint (push) Has been cancelled
Backend CI / Secret Scanning (push) Has been cancelled
Wire OpenTelemetry tracing — HTTP, B2, APNs, FCM, asynq, GORM (partial)
Step 1 — OTel SDK: cmd/api and cmd/worker initialize a tracer provider
that exports OTLP/HTTP to obs.88oakapps.com (Jaeger all-in-one). Sampling
is AlwaysSample in dev (DEBUG=true) and TraceIDRatioBased(0.1) in prod,
overridable via OTEL_TRACES_SAMPLER_ARG. Service names are honeydue-api
and honeydue-worker. otelecho.Middleware opens a span per HTTP request.

Step 2 — Manual spans: storage_service.Upload now takes ctx and emits
storage.upload + b2.PutObject spans (size_bytes, key, mime_type, bucket,
result attrs). APNs Send/SendWithCategory and FCM sendOne emit per-token
spans with topic, status_code, reason. Asynq middleware emits
asynq.handle:<task_type> per job with retry/payload attrs and records
asynq_job_duration_seconds.

Step 3 — Database: otelgorm plugin registered in database.Connect, so
any SQL emitted via db.WithContext(ctx) attaches to the request span.
Every repository now exposes WithContext(ctx) *XRepository as the
migration helper. TaskService.ListTasks and GetTasksByResidence are
migrated end-to-end (ctx threaded through handler → service → repo);
remaining services adopt the same pattern incrementally — pre-migration
methods still emit untraced SQL via the unchanged db field.

OBS_TRACES_URL and OBS_INGEST_TOKEN flow from deploy/prod.env →
honeydue-secrets → api+worker Deployments via secretKeyRef (optional).
02-setup-secrets.sh sources them from prod.env on next run; manifests
mark both env vars optional so the deployment rolls without traces if
the secret is absent.

ch15 observability doc now lists what produces spans today vs the
remaining migration work, with the explicit per-method pattern.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-25 15:28:05 -05:00

338 lines
11 KiB
Go

package repositories
import (
"context"
"errors"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/treytartt/honeydue-api/internal/models"
)
// SubscriptionRepository handles database operations for subscriptions, tier
// limits, subscription settings, upgrade triggers, feature benefits and
// promotions.
type SubscriptionRepository struct {
	// db is the GORM handle every query in this repository is issued through.
	// WithContext produces a copy whose handle carries a request context.
	db *gorm.DB
}
// NewSubscriptionRepository wraps db in a SubscriptionRepository. The handle is
// stored as-is; nothing is queried at construction time.
func NewSubscriptionRepository(db *gorm.DB) *SubscriptionRepository {
	repo := &SubscriptionRepository{}
	repo.db = db
	return repo
}
// === User Subscription ===

// FindByUserID loads the subscription row owned by userID. When no row exists,
// the error from GORM's First (record not found) is returned unchanged.
func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscription, error) {
	sub := new(models.UserSubscription)
	if err := r.db.Where("user_id = ?", userID).First(sub).Error; err != nil {
		return nil, err
	}
	return sub, nil
}
// GetOrCreate returns the subscription for userID, creating a free-tier row
// (auto-renew enabled) when none exists.
//
// The lookup and insert run in one transaction, but a transaction alone does
// not stop two concurrent callers from both observing "not found" and both
// inserting. The insert therefore uses ON CONFLICT DO NOTHING. If another
// request created the row first (and user_id is covered by a unique
// constraint), the insert becomes a no-op and the existing row is re-read,
// instead of the request failing with a duplicate-key error.
func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) {
	var sub models.UserSubscription
	txErr := r.db.Transaction(func(tx *gorm.DB) error {
		findErr := tx.Where("user_id = ?", userID).First(&sub).Error
		if findErr == nil {
			return nil // Found existing subscription
		}
		if !errors.Is(findErr, gorm.ErrRecordNotFound) {
			return findErr // Unexpected error
		}

		// Record not found -- create with free tier defaults.
		sub = models.UserSubscription{
			UserID:    userID,
			Tier:      models.TierFree,
			AutoRenew: true,
		}
		res := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&sub)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected > 0 {
			return nil
		}

		// A concurrent request inserted the row between our SELECT and INSERT.
		// Discard the unsaved defaults and load the row that actually exists.
		sub = models.UserSubscription{}
		return tx.Where("user_id = ?", userID).First(&sub).Error
	})
	if txErr != nil {
		return nil, txErr
	}
	return &sub, nil
}
// Update persists sub with GORM's Save, skipping the User field so that the
// associated user record is not written as part of this call.
func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error {
	result := r.db.Omit("User").Save(sub)
	return result.Error
}
// UpgradeToPro moves userID's subscription to the Pro tier.
//
// Inside a single transaction the row is read under a FOR UPDATE lock. Its
// tier, subscribed_at (now, UTC), expires_at, platform and auto_renew columns
// are then overwritten, and cancelled_at is cleared. Holding the row lock
// serializes concurrent mutations of the same subscription. If the user has no
// subscription row, the lookup error is returned and nothing is written.
func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time, platform string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var locked models.UserSubscription
		lockErr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
			Where("user_id = ?", userID).
			First(&locked).Error
		if lockErr != nil {
			return lockErr
		}

		changes := map[string]any{
			"tier":          models.TierPro,
			"subscribed_at": time.Now().UTC(),
			"expires_at":    expiresAt,
			"cancelled_at":  nil,
			"platform":      platform,
			"auto_renew":    true,
		}
		return tx.Model(&locked).Updates(changes).Error
	})
}
// DowngradeToFree moves userID's subscription back to the Free tier.
//
// Within a single transaction the row is first locked (FOR UPDATE). The tier
// is then set to Free, cancelled_at is stamped with the current UTC time, and
// auto_renew is switched off. Other columns (for example expires_at) are left
// untouched. A missing subscription row surfaces as the lookup error.
func (r *SubscriptionRepository) DowngradeToFree(userID uint) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var locked models.UserSubscription
		lockQuery := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("user_id = ?", userID)
		if err := lockQuery.First(&locked).Error; err != nil {
			return err
		}

		return tx.Model(&locked).Updates(map[string]any{
			"tier":         models.TierFree,
			"cancelled_at": time.Now().UTC(),
			"auto_renew":   false,
		}).Error
	})
}
// SetAutoRenew writes autoRenew into the auto_renew column of userID's
// subscription row(s). Matching zero rows is not reported as an error.
func (r *SubscriptionRepository) SetAutoRenew(userID uint, autoRenew bool) error {
	scoped := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return scoped.Update("auto_renew", autoRenew).Error
}
// UpdateReceiptData stores receiptData in the apple_receipt_data column of
// userID's subscription row(s).
func (r *SubscriptionRepository) UpdateReceiptData(userID uint, receiptData string) error {
	scoped := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return scoped.Update("apple_receipt_data", receiptData).Error
}
// UpdatePurchaseToken stores token in the google_purchase_token column of
// userID's subscription row(s).
func (r *SubscriptionRepository) UpdatePurchaseToken(userID uint, token string) error {
	scoped := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return scoped.Update("google_purchase_token", token).Error
}
// FindByAppleReceiptContains finds a subscription by Apple transaction ID.
// Used by webhooks to find the user associated with a transaction.
//
// An empty transactionID is rejected with gorm.ErrRecordNotFound. Otherwise
// the LIKE pattern would be '%%', which matches every row with non-NULL
// receipt data and would return an arbitrary user's subscription.
//
// PERFORMANCE NOTE: This uses a LIKE '%...%' scan on apple_receipt_data which
// cannot use a B-tree index and results in a full table scan. For better
// performance at scale, add a dedicated indexed column:
//
//	AppleTransactionID *string `gorm:"column:apple_transaction_id;size:255;index"`
//
// Then look up by exact match: WHERE apple_transaction_id = ?
func (r *SubscriptionRepository) FindByAppleReceiptContains(transactionID string) (*models.UserSubscription, error) {
	if transactionID == "" {
		return nil, gorm.ErrRecordNotFound
	}

	var sub models.UserSubscription
	// Escape LIKE wildcards in the transaction ID to prevent wildcard injection.
	escaped := escapeLikeWildcards(transactionID)
	err := r.db.Where("apple_receipt_data LIKE ?", "%"+escaped+"%").First(&sub).Error
	if err != nil {
		return nil, err
	}
	return &sub, nil
}
// FindByGoogleToken finds a subscription by Google purchase token.
// Used by webhooks to find the user associated with a purchase.
//
// An empty purchaseToken is rejected with gorm.ErrRecordNotFound rather than
// queried. An empty lookup key must never resolve to whichever row happens to
// store an empty token.
func (r *SubscriptionRepository) FindByGoogleToken(purchaseToken string) (*models.UserSubscription, error) {
	if purchaseToken == "" {
		return nil, gorm.ErrRecordNotFound
	}

	var sub models.UserSubscription
	err := r.db.Where("google_purchase_token = ?", purchaseToken).First(&sub).Error
	if err != nil {
		return nil, err
	}
	return &sub, nil
}
// SetCancelledAt records cancelledAt in the cancelled_at column of userID's
// subscription row(s).
func (r *SubscriptionRepository) SetCancelledAt(userID uint, cancelledAt time.Time) error {
	scoped := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return scoped.Update("cancelled_at", cancelledAt).Error
}
// ClearCancelledAt resets cancelled_at to NULL for userID's subscription
// row(s), for example when the user resubscribes.
func (r *SubscriptionRepository) ClearCancelledAt(userID uint) error {
	scoped := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return scoped.Update("cancelled_at", nil).Error
}
// === Tier Limits ===

// GetTierLimits returns the stored limits for tier. When no row exists for the
// tier, built-in defaults are returned instead: the free defaults for
// models.TierFree and the pro defaults for every other tier. Any other query
// error is returned as-is.
func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*models.TierLimits, error) {
	var limits models.TierLimits
	err := r.db.Where("tier = ?", tier).First(&limits).Error

	switch {
	case err == nil:
		return &limits, nil
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return nil, err
	case tier == models.TierFree:
		fallback := models.GetDefaultFreeLimits()
		return &fallback, nil
	default:
		fallback := models.GetDefaultProLimits()
		return &fallback, nil
	}
}
// GetAllTierLimits returns every stored tier-limits row. An empty table yields
// an empty (nil) slice and no error.
func (r *SubscriptionRepository) GetAllTierLimits() ([]models.TierLimits, error) {
	var all []models.TierLimits
	result := r.db.Find(&all)
	return all, result.Error
}
// === Subscription Settings (Singleton) ===

// GetSettings returns the first subscription-settings row. If the table is
// empty, a default value with EnableLimitations=false (limitations disabled)
// is returned instead of an error.
func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, error) {
	settings := new(models.SubscriptionSettings)
	err := r.db.First(settings).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return &models.SubscriptionSettings{EnableLimitations: false}, nil
	}
	if err != nil {
		return nil, err
	}
	return settings, nil
}
// === Upgrade Triggers ===

// GetUpgradeTrigger returns the active upgrade trigger whose trigger_key
// equals key. Inactive triggers are never returned.
func (r *SubscriptionRepository) GetUpgradeTrigger(key string) (*models.UpgradeTrigger, error) {
	trigger := new(models.UpgradeTrigger)
	query := r.db.Where("trigger_key = ? AND is_active = ?", key, true)
	if err := query.First(trigger).Error; err != nil {
		return nil, err
	}
	return trigger, nil
}
// GetAllUpgradeTriggers returns every upgrade trigger flagged is_active.
func (r *SubscriptionRepository) GetAllUpgradeTriggers() ([]models.UpgradeTrigger, error) {
	var active []models.UpgradeTrigger
	result := r.db.Where("is_active = ?", true).Find(&active)
	return active, result.Error
}
// === Feature Benefits ===

// GetFeatureBenefits returns every active feature benefit, sorted by
// display_order.
func (r *SubscriptionRepository) GetFeatureBenefits() ([]models.FeatureBenefit, error) {
	var active []models.FeatureBenefit
	result := r.db.
		Where("is_active = ?", true).
		Order("display_order").
		Find(&active)
	return active, result.Error
}
// === Promotions ===

// GetActivePromotions returns the active promotions targeting tier whose
// [start_date, end_date] window contains the current UTC time, newest start
// date first.
func (r *SubscriptionRepository) GetActivePromotions(tier models.SubscriptionTier) ([]models.Promotion, error) {
	var current []models.Promotion
	at := time.Now().UTC()
	result := r.db.
		Where("is_active = ? AND target_tier = ? AND start_date <= ? AND end_date >= ?", true, tier, at, at).
		Order("start_date DESC").
		Find(&current)
	return current, result.Error
}
// GetPromotionByID returns the active promotion whose promotion_id equals
// promotionID. Inactive promotions are treated as missing.
func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.Promotion, error) {
	promo := new(models.Promotion)
	query := r.db.Where("promotion_id = ? AND is_active = ?", promotionID, true)
	if err := query.First(promo).Error; err != nil {
		return nil, err
	}
	return promo, nil
}
// === Stripe Lookups ===

// FindByStripeCustomerID finds a subscription by Stripe customer ID.
//
// An empty customerID is rejected with gorm.ErrRecordNotFound rather than
// queried. An empty lookup key must never resolve to whichever row happens to
// store an empty customer ID.
func (r *SubscriptionRepository) FindByStripeCustomerID(customerID string) (*models.UserSubscription, error) {
	if customerID == "" {
		return nil, gorm.ErrRecordNotFound
	}

	var sub models.UserSubscription
	err := r.db.Where("stripe_customer_id = ?", customerID).First(&sub).Error
	if err != nil {
		return nil, err
	}
	return &sub, nil
}
// FindByStripeSubscriptionID finds a subscription by Stripe subscription ID.
//
// An empty subscriptionID is rejected with gorm.ErrRecordNotFound rather than
// queried. An empty lookup key must never resolve to whichever row happens to
// store an empty subscription ID.
func (r *SubscriptionRepository) FindByStripeSubscriptionID(subscriptionID string) (*models.UserSubscription, error) {
	if subscriptionID == "" {
		return nil, gorm.ErrRecordNotFound
	}

	var sub models.UserSubscription
	err := r.db.Where("stripe_subscription_id = ?", subscriptionID).First(&sub).Error
	if err != nil {
		return nil, err
	}
	return &sub, nil
}
// UpdateStripeData writes the Stripe customer, subscription and price IDs for
// userID's subscription in a single UPDATE statement.
func (r *SubscriptionRepository) UpdateStripeData(userID uint, customerID, subscriptionID, priceID string) error {
	fields := map[string]any{
		"stripe_customer_id":     customerID,
		"stripe_subscription_id": subscriptionID,
		"stripe_price_id":        priceID,
	}
	return r.db.Model(&models.UserSubscription{}).
		Where("user_id = ?", userID).
		Updates(fields).Error
}
// ClearStripeData sets the Stripe subscription and price IDs to NULL for
// userID's subscription. The customer ID is intentionally kept so the user can
// still reach the Stripe customer portal.
func (r *SubscriptionRepository) ClearStripeData(userID uint) error {
	fields := map[string]any{
		"stripe_subscription_id": nil,
		"stripe_price_id":        nil,
	}
	return r.db.Model(&models.UserSubscription{}).
		Where("user_id = ?", userID).
		Updates(fields).Error
}
// === Trial Management ===

// SetTrialDates records the trial window (trial_start, trial_end) for userID's
// subscription and sets trial_used to true in the same UPDATE.
func (r *SubscriptionRepository) SetTrialDates(userID uint, trialStart, trialEnd time.Time) error {
	fields := map[string]any{
		"trial_start": trialStart,
		"trial_end":   trialEnd,
		"trial_used":  true,
	}
	return r.db.Model(&models.UserSubscription{}).
		Where("user_id = ?", userID).
		Updates(fields).Error
}
// UpdateExpiresAt stores expiresAt in the expires_at column of userID's
// subscription row(s).
func (r *SubscriptionRepository) UpdateExpiresAt(userID uint, expiresAt time.Time) error {
	scoped := r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID)
	return scoped.Update("expires_at", expiresAt).Error
}
// WithContext returns a new repository whose *gorm.DB is bound to ctx; the
// receiver itself is left unchanged. Queries issued through the returned copy
// respect ctx cancellation and deadlines. When otelgorm is registered, they are
// also attached to ctx's trace span.
func (r *SubscriptionRepository) WithContext(ctx context.Context) *SubscriptionRepository {
	scoped := r.db.WithContext(ctx)
	return &SubscriptionRepository{db: scoped}
}