Add Stripe billing, free trials, and cross-platform subscription guards
- Stripe integration: add StripeService with checkout sessions, customer portal, and webhook handling for subscription lifecycle events. - Free trials: auto-start configurable trial on first subscription check, with admin-controllable duration and enable/disable toggle. - Cross-platform guard: prevent duplicate subscriptions across iOS, Android, and Stripe by checking existing platform before allowing purchase. - Subscription model: add Stripe fields (customer_id, subscription_id, price_id), trial fields (trial_start, trial_end, trial_used), and SubscriptionSource/IsTrialActive helpers. - API: add trial and source fields to status response, update OpenAPI spec. - Clean up stale migration and audit docs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -2,9 +2,11 @@ package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
@@ -77,3 +79,150 @@ func TestGetOrCreate_Idempotent(t *testing.T) {
|
||||
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should have exactly one subscription record after two calls")
|
||||
}
|
||||
|
||||
func TestFindByStripeCustomerID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
customerID string
|
||||
seedID string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "finds existing subscription by stripe customer ID",
|
||||
customerID: "cus_test123",
|
||||
seedID: "cus_test123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "returns error for unknown stripe customer ID",
|
||||
customerID: "cus_unknown999",
|
||||
seedID: "cus_test456",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierFree,
|
||||
StripeCustomerID: &tt.seedID,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByStripeCustomerID(tt.customerID)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
assert.Equal(t, tt.seedID, *found.StripeCustomerID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateStripeData(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierFree,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update all three Stripe fields
|
||||
customerID := "cus_abc123"
|
||||
subscriptionID := "sub_xyz789"
|
||||
priceID := "price_monthly"
|
||||
|
||||
err = repo.UpdateStripeData(user.ID, customerID, subscriptionID, priceID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all three fields are set
|
||||
var updated models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&updated).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated.StripeCustomerID)
|
||||
require.NotNil(t, updated.StripeSubscriptionID)
|
||||
require.NotNil(t, updated.StripePriceID)
|
||||
assert.Equal(t, customerID, *updated.StripeCustomerID)
|
||||
assert.Equal(t, subscriptionID, *updated.StripeSubscriptionID)
|
||||
assert.Equal(t, priceID, *updated.StripePriceID)
|
||||
|
||||
// Now call ClearStripeData
|
||||
err = repo.ClearStripeData(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify subscription_id and price_id are cleared, customer_id preserved
|
||||
var cleared models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&cleared).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cleared.StripeCustomerID, "customer_id should be preserved after ClearStripeData")
|
||||
assert.Equal(t, customerID, *cleared.StripeCustomerID)
|
||||
assert.Nil(t, cleared.StripeSubscriptionID, "subscription_id should be nil after ClearStripeData")
|
||||
assert.Nil(t, cleared.StripePriceID, "price_id should be nil after ClearStripeData")
|
||||
}
|
||||
|
||||
func TestSetTrialDates(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierFree,
|
||||
TrialUsed: false,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
trialStart := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
|
||||
trialEnd := time.Date(2026, 3, 15, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
err = repo.SetTrialDates(user.ID, trialStart, trialEnd)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&updated).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, updated.TrialStart)
|
||||
require.NotNil(t, updated.TrialEnd)
|
||||
assert.True(t, updated.TrialUsed, "trial_used should be set to true")
|
||||
assert.WithinDuration(t, trialStart, *updated.TrialStart, time.Second, "trial_start should match")
|
||||
assert.WithinDuration(t, trialEnd, *updated.TrialEnd, time.Second, "trial_end should match")
|
||||
}
|
||||
|
||||
func TestUpdateExpiresAt(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierPro,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
newExpiry := time.Date(2027, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
err = repo.UpdateExpiresAt(user.ID, newExpiry)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.UserSubscription
|
||||
err = db.Where("user_id = ?", user.ID).First(&updated).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, updated.ExpiresAt)
|
||||
assert.WithinDuration(t, newExpiry, *updated.ExpiresAt, time.Second, "expires_at should be updated")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user