Add Stripe billing, free trials, and cross-platform subscription guards

- Stripe integration: add StripeService with checkout sessions, customer
  portal, and webhook handling for subscription lifecycle events.
- Free trials: auto-start configurable trial on first subscription check,
  with admin-controllable duration and enable/disable toggle.
- Cross-platform guard: prevent duplicate subscriptions across iOS, Android,
  and Stripe by checking existing platform before allowing purchase.
- Subscription model: add Stripe fields (customer_id, subscription_id,
  price_id), trial fields (trial_start, trial_end, trial_used), and
  SubscriptionSource/IsTrialActive helpers.
- API: add trial and source fields to status response, update OpenAPI spec.
- Clean up stale migration and audit docs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-03-05 11:36:14 -06:00
parent d5bb123cd0
commit 72db9050f8
35 changed files with 1555 additions and 1120 deletions

View File

@@ -262,3 +262,59 @@ func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.P
}
return &promotion, nil
}
// === Stripe Lookups ===
// FindByStripeCustomerID finds a subscription by Stripe customer ID
func (r *SubscriptionRepository) FindByStripeCustomerID(customerID string) (*models.UserSubscription, error) {
var sub models.UserSubscription
err := r.db.Where("stripe_customer_id = ?", customerID).First(&sub).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// FindByStripeSubscriptionID finds a subscription by Stripe subscription ID
func (r *SubscriptionRepository) FindByStripeSubscriptionID(subscriptionID string) (*models.UserSubscription, error) {
var sub models.UserSubscription
err := r.db.Where("stripe_subscription_id = ?", subscriptionID).First(&sub).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// UpdateStripeData updates all three Stripe fields (customer, subscription, price) in one call
func (r *SubscriptionRepository) UpdateStripeData(userID uint, customerID, subscriptionID, priceID string) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"stripe_customer_id": customerID,
"stripe_subscription_id": subscriptionID,
"stripe_price_id": priceID,
}).Error
}
// ClearStripeData clears the Stripe subscription and price IDs (customer ID stays for portal access)
func (r *SubscriptionRepository) ClearStripeData(userID uint) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"stripe_subscription_id": nil,
"stripe_price_id": nil,
}).Error
}
// === Trial Management ===
// SetTrialDates sets the trial start, end, and marks trial as used
func (r *SubscriptionRepository) SetTrialDates(userID uint, trialStart, trialEnd time.Time) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"trial_start": trialStart,
"trial_end": trialEnd,
"trial_used": true,
}).Error
}
// UpdateExpiresAt updates the expires_at field for a user's subscription
func (r *SubscriptionRepository) UpdateExpiresAt(userID uint, expiresAt time.Time) error {
return r.db.Model(&models.UserSubscription{}).Where("user_id = ?", userID).
Update("expires_at", expiresAt).Error
}

View File

@@ -2,9 +2,11 @@ package repositories
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/testutil"
@@ -77,3 +79,150 @@ func TestGetOrCreate_Idempotent(t *testing.T) {
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should have exactly one subscription record after two calls")
}
func TestFindByStripeCustomerID(t *testing.T) {
tests := []struct {
name string
customerID string
seedID string
wantErr bool
}{
{
name: "finds existing subscription by stripe customer ID",
customerID: "cus_test123",
seedID: "cus_test123",
wantErr: false,
},
{
name: "returns error for unknown stripe customer ID",
customerID: "cus_unknown999",
seedID: "cus_test456",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierFree,
StripeCustomerID: &tt.seedID,
}
err := db.Create(sub).Error
require.NoError(t, err)
found, err := repo.FindByStripeCustomerID(tt.customerID)
if tt.wantErr {
assert.Error(t, err)
assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
} else {
require.NoError(t, err)
require.NotNil(t, found)
assert.Equal(t, user.ID, found.UserID)
assert.Equal(t, tt.seedID, *found.StripeCustomerID)
}
})
}
}
func TestUpdateStripeData(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierFree,
}
err := db.Create(sub).Error
require.NoError(t, err)
// Update all three Stripe fields
customerID := "cus_abc123"
subscriptionID := "sub_xyz789"
priceID := "price_monthly"
err = repo.UpdateStripeData(user.ID, customerID, subscriptionID, priceID)
require.NoError(t, err)
// Verify all three fields are set
var updated models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&updated).Error
require.NoError(t, err)
require.NotNil(t, updated.StripeCustomerID)
require.NotNil(t, updated.StripeSubscriptionID)
require.NotNil(t, updated.StripePriceID)
assert.Equal(t, customerID, *updated.StripeCustomerID)
assert.Equal(t, subscriptionID, *updated.StripeSubscriptionID)
assert.Equal(t, priceID, *updated.StripePriceID)
// Now call ClearStripeData
err = repo.ClearStripeData(user.ID)
require.NoError(t, err)
// Verify subscription_id and price_id are cleared, customer_id preserved
var cleared models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&cleared).Error
require.NoError(t, err)
require.NotNil(t, cleared.StripeCustomerID, "customer_id should be preserved after ClearStripeData")
assert.Equal(t, customerID, *cleared.StripeCustomerID)
assert.Nil(t, cleared.StripeSubscriptionID, "subscription_id should be nil after ClearStripeData")
assert.Nil(t, cleared.StripePriceID, "price_id should be nil after ClearStripeData")
}
func TestSetTrialDates(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierFree,
TrialUsed: false,
}
err := db.Create(sub).Error
require.NoError(t, err)
trialStart := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
trialEnd := time.Date(2026, 3, 15, 0, 0, 0, 0, time.UTC)
err = repo.SetTrialDates(user.ID, trialStart, trialEnd)
require.NoError(t, err)
var updated models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&updated).Error
require.NoError(t, err)
require.NotNil(t, updated.TrialStart)
require.NotNil(t, updated.TrialEnd)
assert.True(t, updated.TrialUsed, "trial_used should be set to true")
assert.WithinDuration(t, trialStart, *updated.TrialStart, time.Second, "trial_start should match")
assert.WithinDuration(t, trialEnd, *updated.TrialEnd, time.Second, "trial_end should match")
}
func TestUpdateExpiresAt(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierPro,
}
err := db.Create(sub).Error
require.NoError(t, err)
newExpiry := time.Date(2027, 6, 15, 12, 0, 0, 0, time.UTC)
err = repo.UpdateExpiresAt(user.ID, newExpiry)
require.NoError(t, err)
var updated models.UserSubscription
err = db.Where("user_id = ?", user.ID).First(&updated).Error
require.NoError(t, err)
require.NotNil(t, updated.ExpiresAt)
assert.WithinDuration(t, newExpiry, *updated.ExpiresAt, time.Second, "expires_at should be updated")
}