perf(subscription-status): cache + parallelize + invalidate on mutations
Backend CI / Test (push) Has been cancelled
Backend CI / Contract Tests (push) Has been cancelled
Backend CI / Build (push) Has been cancelled
Backend CI / Lint (push) Has been cancelled
Backend CI / Secret Scanning (push) Has been cancelled

GET /api/subscription/status/ was the slowest endpoint in the API at
p50≈1750ms / p95≈2425ms — about 12× the floor for our cluster→Neon
geography. Jaeger traces showed seven sequential SQL queries each
costing roughly one transatlantic RTT (~110ms), with the actual queries
running in 0.073ms at the database. Pure network serialization, not slow
SQL.

Three changes, in order of leverage:

1. Cache the assembled SubscriptionStatusResponse per-user in Redis with
   a 5-minute TTL. Hot path collapses to a single Redis GET (~5ms) on
   warm reads; the TTL is a safety net against missed invalidations.

2. Parallelize the three independent COUNT queries in getUserUsage
   (task_task / task_contractor / task_document) via golang.org/x/sync
   errgroup. Three RTTs collapse to one. Also dropped the redundant
   residence_residence COUNT — len(residenceIDs) from FindResidenceIDsByOwner
   is the same number, no need to re-query.

3. Wire explicit invalidation into every mutation that could change a
   user's response — residence/task/contractor/document CRUD,
   residence membership changes (JoinWithCode, RemoveUser, DeleteResidence),
   and every subscription tier flip across the IAP/Stripe/webhook surface.
   Residence-scoped invalidations fan out to every user with access via a
   new ResidenceRepository.FindUserIDsByResidence helper, so members of a
   shared residence don't see stale `usage` numbers when another member
   adds a task.

Net effect: warm path goes from ~1350ms to ~5ms (Redis hit). Cold path
goes from ~1350ms to ~250-450ms (5 sequential queries → 2 phases:
residence IDs lookup, then parallel task/contractor/document counts).

Also fixed a pre-existing CheckLimit signature drift in
internal/integration/subscription_is_free_test.go that was blocking the
package build.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-05-01 11:00:23 -07:00
parent 0798ae8d74
commit 9bee436e86
11 changed files with 286 additions and 34 deletions
@@ -1,6 +1,7 @@
package handlers package handlers
import ( import (
"context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
@@ -38,6 +39,7 @@ type SubscriptionWebhookHandler struct {
webhookEventRepo *repositories.WebhookEventRepository webhookEventRepo *repositories.WebhookEventRepository
appleRootCerts []*x509.Certificate appleRootCerts []*x509.Certificate
stripeService *services.StripeService stripeService *services.StripeService
cache *services.CacheService
enabled bool enabled bool
} }
@@ -61,6 +63,24 @@ func (h *SubscriptionWebhookHandler) SetStripeService(stripeService *services.St
h.stripeService = stripeService h.stripeService = stripeService
} }
// SetCacheService wires Redis caching so post-mutation invalidation drops the
// per-user SubscriptionStatusResponse cache after Apple/Google webhooks change
// tier or auto-renew state.
func (h *SubscriptionWebhookHandler) SetCacheService(cache *services.CacheService) {
h.cache = cache
}
// invalidateStatusCache best-effort drops the per-user subscription_status
// cache after a webhook mutation. Background context — webhook handlers run
// in their own request lifecycle, but the cache write itself is fast enough
// that we don't need to bound it.
func (h *SubscriptionWebhookHandler) invalidateStatusCache(userID uint) {
if h.cache == nil {
return
}
_ = h.cache.InvalidateSubscriptionStatusForUsers(context.Background(), userID)
}
// ==================== // ====================
// Apple App Store Server Notifications v2 // Apple App Store Server Notifications v2
// ==================== // ====================
@@ -356,6 +376,7 @@ func (h *SubscriptionWebhookHandler) handleAppleSubscribed(userID uint, tx *Appl
if err := h.subscriptionRepo.SetAutoRenew(userID, autoRenew); err != nil { if err := h.subscriptionRepo.SetAutoRenew(userID, autoRenew); err != nil {
return err return err
} }
h.invalidateStatusCache(userID)
log.Info().Uint("user_id", userID).Time("expires", expiresAt).Bool("auto_renew", autoRenew).Msg("Apple Webhook: User subscribed") log.Info().Uint("user_id", userID).Time("expires", expiresAt).Bool("auto_renew", autoRenew).Msg("Apple Webhook: User subscribed")
return nil return nil
@@ -367,6 +388,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewed(userID uint, tx *AppleTr
if err := h.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil { if err := h.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil {
return err return err
} }
h.invalidateStatusCache(userID)
log.Info().Uint("user_id", userID).Time("expires", expiresAt).Msg("Apple Webhook: User renewed") log.Info().Uint("user_id", userID).Time("expires", expiresAt).Msg("Apple Webhook: User renewed")
return nil return nil
@@ -396,6 +418,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewalStatusChange(userID uint,
} }
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User turned auto-renew back on") log.Info().Uint("user_id", userID).Msg("Apple Webhook: User turned auto-renew back on")
} }
h.invalidateStatusCache(userID)
return nil return nil
} }
@@ -673,6 +696,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRenewed(userID uint, notificati
if err := h.subscriptionRepo.UpgradeToPro(userID, newExpiry, "android"); err != nil { if err := h.subscriptionRepo.UpgradeToPro(userID, newExpiry, "android"); err != nil {
return err return err
} }
h.invalidateStatusCache(userID)
log.Info().Uint("user_id", userID).Time("expires", newExpiry).Msg("Google Webhook: User renewed") log.Info().Uint("user_id", userID).Time("expires", newExpiry).Msg("Google Webhook: User renewed")
return nil return nil
@@ -684,6 +708,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRecovered(userID uint, notifica
if err := h.subscriptionRepo.UpgradeToPro(userID, newExpiry, "android"); err != nil { if err := h.subscriptionRepo.UpgradeToPro(userID, newExpiry, "android"); err != nil {
return err return err
} }
h.invalidateStatusCache(userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription recovered") log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription recovered")
return nil return nil
@@ -698,6 +723,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleCanceled(userID uint, notificat
if err := h.subscriptionRepo.SetAutoRenew(userID, false); err != nil { if err := h.subscriptionRepo.SetAutoRenew(userID, false); err != nil {
return err return err
} }
h.invalidateStatusCache(userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User canceled, will expire at end of period") log.Info().Uint("user_id", userID).Msg("Google Webhook: User canceled, will expire at end of period")
return nil return nil
@@ -727,6 +753,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRestarted(userID uint, notifica
if err := h.subscriptionRepo.SetAutoRenew(userID, true); err != nil { if err := h.subscriptionRepo.SetAutoRenew(userID, true); err != nil {
return err return err
} }
h.invalidateStatusCache(userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User restarted subscription") log.Info().Uint("user_id", userID).Msg("Google Webhook: User restarted subscription")
return nil return nil
@@ -762,7 +789,11 @@ func (h *SubscriptionWebhookHandler) safeDowngradeToFree(userID uint, reason str
sub, err := h.subscriptionRepo.FindByUserID(userID) sub, err := h.subscriptionRepo.FindByUserID(userID)
if err != nil { if err != nil {
log.Warn().Err(err).Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Could not find subscription for multi-source check, proceeding with downgrade") log.Warn().Err(err).Uint("user_id", userID).Str("reason", reason).Msg("Webhook: Could not find subscription for multi-source check, proceeding with downgrade")
return h.subscriptionRepo.DowngradeToFree(userID) if dErr := h.subscriptionRepo.DowngradeToFree(userID); dErr != nil {
return dErr
}
h.invalidateStatusCache(userID)
return nil
} }
// Check if Stripe subscription is still active // Check if Stripe subscription is still active
@@ -789,6 +820,7 @@ func (h *SubscriptionWebhookHandler) safeDowngradeToFree(userID uint, reason str
if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil { if err := h.subscriptionRepo.DowngradeToFree(userID); err != nil {
return err return err
} }
h.invalidateStatusCache(userID)
log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: User downgraded to free (no other active sources)") log.Info().Uint("user_id", userID).Str("reason", reason).Msg("Webhook: User downgraded to free (no other active sources)")
return nil return nil
@@ -2,6 +2,7 @@ package integration
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -238,7 +239,7 @@ func TestIntegration_IsFreeBypassesCheckLimit(t *testing.T) {
// ========== Test 1: Normal free user hits limit ========== // ========== Test 1: Normal free user hits limit ==========
// First property should succeed // First property should succeed
err = app.SubscriptionService.CheckLimit(userID, "properties") err = app.SubscriptionService.CheckLimit(context.Background(), userID, "properties")
assert.NoError(t, err, "First property should be allowed") assert.NoError(t, err, "First property should be allowed")
// Create a property to use up the limit // Create a property to use up the limit
@@ -249,7 +250,7 @@ func TestIntegration_IsFreeBypassesCheckLimit(t *testing.T) {
app.DB.Create(residence) app.DB.Create(residence)
// Second property should fail // Second property should fail
err = app.SubscriptionService.CheckLimit(userID, "properties") err = app.SubscriptionService.CheckLimit(context.Background(), userID, "properties")
assert.Error(t, err, "Second property should be blocked for normal free user") assert.Error(t, err, "Second property should be blocked for normal free user")
var appErr *apperrors.AppError var appErr *apperrors.AppError
require.ErrorAs(t, err, &appErr) require.ErrorAs(t, err, &appErr)
@@ -262,17 +263,17 @@ func TestIntegration_IsFreeBypassesCheckLimit(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// ========== Test 3: IsFree user bypasses limit ========== // ========== Test 3: IsFree user bypasses limit ==========
err = app.SubscriptionService.CheckLimit(userID, "properties") err = app.SubscriptionService.CheckLimit(context.Background(), userID, "properties")
assert.NoError(t, err, "IsFree user should bypass property limits") assert.NoError(t, err, "IsFree user should bypass property limits")
// Should also bypass other limits // Should also bypass other limits
err = app.SubscriptionService.CheckLimit(userID, "tasks") err = app.SubscriptionService.CheckLimit(context.Background(), userID, "tasks")
assert.NoError(t, err, "IsFree user should bypass task limits") assert.NoError(t, err, "IsFree user should bypass task limits")
err = app.SubscriptionService.CheckLimit(userID, "contractors") err = app.SubscriptionService.CheckLimit(context.Background(), userID, "contractors")
assert.NoError(t, err, "IsFree user should bypass contractor limits") assert.NoError(t, err, "IsFree user should bypass contractor limits")
err = app.SubscriptionService.CheckLimit(userID, "documents") err = app.SubscriptionService.CheckLimit(context.Background(), userID, "documents")
assert.NoError(t, err, "IsFree user should bypass document limits") assert.NoError(t, err, "IsFree user should bypass document limits")
} }
@@ -375,6 +376,6 @@ func TestIntegration_IsFreeWhenGlobalLimitationsDisabled(t *testing.T) {
"With IsFree and global limitations disabled, limitations_enabled should be false") "With IsFree and global limitations disabled, limitations_enabled should be false")
// Both cases result in the same outcome - no limitations // Both cases result in the same outcome - no limitations
err = app.SubscriptionService.CheckLimit(userID, "properties") err = app.SubscriptionService.CheckLimit(context.Background(), userID, "properties")
assert.NoError(t, err, "Should bypass limits when global limitations are disabled") assert.NoError(t, err, "Should bypass limits when global limitations are disabled")
} }
+18
View File
@@ -157,6 +157,24 @@ func (r *ResidenceRepository) GetResidenceUsers(residenceID uint) ([]models.User
return users, nil return users, nil
} }
// FindUserIDsByResidence returns the IDs of every user with access to the
// residence (owner + members from residence_residence_users). Lighter than
// GetResidenceUsers — selects only the ID column, no full user records.
// Used to fan out subscription_status cache invalidation when shared data
// (tasks/contractors/documents) changes for a residence.
func (r *ResidenceRepository) FindUserIDsByResidence(residenceID uint) ([]uint, error) {
var ids []uint
err := r.db.Raw(`
SELECT owner_id FROM residence_residence WHERE id = ? AND is_active = true
UNION
SELECT user_id FROM residence_residence_users WHERE residence_id = ?
`, residenceID, residenceID).Scan(&ids).Error
if err != nil {
return nil, err
}
return ids, nil
}
// HasAccess checks if a user has access to a residence // HasAccess checks if a user has access to a residence
func (r *ResidenceRepository) HasAccess(residenceID, userID uint) (bool, error) { func (r *ResidenceRepository) HasAccess(residenceID, userID uint) (bool, error) {
var count int64 var count int64
+1
View File
@@ -225,6 +225,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
// Initialize webhook handler for Apple/Google/Stripe subscription notifications // Initialize webhook handler for Apple/Google/Stripe subscription notifications
subscriptionWebhookHandler := handlers.NewSubscriptionWebhookHandler(subscriptionRepo, userRepo, webhookEventRepo, cfg.Features.WebhooksEnabled) subscriptionWebhookHandler := handlers.NewSubscriptionWebhookHandler(subscriptionRepo, userRepo, webhookEventRepo, cfg.Features.WebhooksEnabled)
subscriptionWebhookHandler.SetStripeService(stripeService) subscriptionWebhookHandler.SetStripeService(stripeService)
subscriptionWebhookHandler.SetCacheService(deps.Cache)
// Initialize middleware // Initialize middleware
authMiddleware := custommiddleware.NewAuthMiddlewareWithConfig(deps.DB, deps.Cache, cfg) authMiddleware := custommiddleware.NewAuthMiddlewareWithConfig(deps.DB, deps.Cache, cfg)
+61
View File
@@ -495,3 +495,64 @@ func (c *CacheService) InvalidateSubscriptionSettings(ctx context.Context) error
} }
return c.Delete(ctx, subscriptionSettingsKey) return c.Delete(ctx, subscriptionSettingsKey)
} }
// === SubscriptionStatus cache (per-user) ===
//
// SubscriptionStatusResponse aggregates subscription tier, all tier limits, and
// per-user usage counts (residences/tasks/contractors/documents). The usage
// part requires 4+ COUNT queries against the transatlantic Neon Postgres at
// ~110ms RTT each — about a second of wall-clock per call before parallelism.
// Caching the assembled response collapses that to a single Redis GET (~5ms).
//
// TTL is short (5 min) so stale state self-heals if any mutation path forgets
// to invalidate. The primary correctness mechanism is explicit invalidation
// via InvalidateSubscriptionStatusForUsers — called from every CRUD on
// residences, tasks, contractors, documents, and subscription itself, fanning
// out to every user with access to the affected residence.
const (
subscriptionStatusKeyPrefix = "sub_status:user:"
subscriptionStatusTTL = 5 * time.Minute
)
// CacheSubscriptionStatus stores the assembled SubscriptionStatusResponse for
// a user. Caller passes any encodable value to keep this package free of
// service-layer types; subscription_service.go marshals/unmarshals.
// Best-effort — Redis errors are returned but not fatal.
func (c *CacheService) CacheSubscriptionStatus(ctx context.Context, userID uint, status interface{}) error {
if c == nil {
return nil
}
key := fmt.Sprintf("%s%d", subscriptionStatusKeyPrefix, userID)
data, err := json.Marshal(status)
if err != nil {
return err
}
return c.client.Set(ctx, key, data, subscriptionStatusTTL).Err()
}
// GetCachedSubscriptionStatus unmarshals the cached response into dest.
// Returns redis.Nil on cache miss so callers can distinguish from genuine errors.
func (c *CacheService) GetCachedSubscriptionStatus(ctx context.Context, userID uint, dest interface{}) error {
if c == nil {
return fmt.Errorf("cache not available")
}
key := fmt.Sprintf("%s%d", subscriptionStatusKeyPrefix, userID)
return c.Get(ctx, key, dest)
}
// InvalidateSubscriptionStatusForUsers drops the cached status for one or more
// users. Used by every mutation that could change a user's usage counts —
// residence create/delete/share, task/contractor/document CRUD, subscription
// purchase/cancel/restore. Membership-changing residence ops fan out to every
// user with access to that residence.
func (c *CacheService) InvalidateSubscriptionStatusForUsers(ctx context.Context, userIDs ...uint) error {
if c == nil || len(userIDs) == 0 {
return nil
}
keys := make([]string, len(userIDs))
for i, id := range userIDs {
keys[i] = fmt.Sprintf("%s%d", subscriptionStatusKeyPrefix, id)
}
return c.Delete(ctx, keys...)
}
+15
View File
@@ -146,6 +146,15 @@ func (s *ContractorService) CreateContractor(ctx context.Context, req *requests.
return nil, apperrors.Internal(reloadErr) return nil, apperrors.Internal(reloadErr)
} }
// contractors_count for every user with access to this residence just
// changed. Contractor without a residence is rare (created via global
// add) and only the creator counts it — drop only their cache then.
if req.ResidenceID != nil {
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, *req.ResidenceID)
} else if s.cache != nil {
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
}
resp := responses.NewContractorResponse(contractor) resp := responses.NewContractorResponse(contractor)
return &resp, nil return &resp, nil
} }
@@ -258,6 +267,12 @@ func (s *ContractorService) DeleteContractor(ctx context.Context, contractorID,
return apperrors.Internal(err) return apperrors.Internal(err)
} }
if contractor.ResidenceID != nil {
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, *contractor.ResidenceID)
} else if s.cache != nil {
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
}
return nil return nil
} }
+4
View File
@@ -178,6 +178,8 @@ func (s *DocumentService) CreateDocument(ctx context.Context, req *requests.Crea
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, req.ResidenceID)
resp := responses.NewDocumentResponse(document) resp := responses.NewDocumentResponse(document)
return &resp, nil return &resp, nil
} }
@@ -282,6 +284,8 @@ func (s *DocumentService) DeleteDocument(ctx context.Context, documentID, userID
return apperrors.Internal(err) return apperrors.Internal(err)
} }
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, document.ResidenceID)
return nil return nil
} }
+14 -3
View File
@@ -266,8 +266,10 @@ func (s *ResidenceService) CreateResidence(ctx context.Context, req *requests.Cr
} }
if s.cache != nil { if s.cache != nil {
// Owner now has a new residence — drop cached IDs so the next // Owner now has a new residence — drop cached IDs so the next
// list-residences call doesn't omit it. // list-residences call doesn't omit it. Also bust the subscription
// status cache so properties_count reflects the new residence.
_ = s.cache.InvalidateResidenceIDsForUsers(ctx, ownerID) _ = s.cache.InvalidateResidenceIDsForUsers(ctx, ownerID)
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, ownerID)
} }
// Reload with relations // Reload with relations
@@ -450,6 +452,10 @@ func (s *ResidenceService) DeleteResidence(ctx context.Context, residenceID, use
} }
if s.cache != nil && len(affectedUserIDs) > 0 { if s.cache != nil && len(affectedUserIDs) > 0 {
_ = s.cache.InvalidateResidenceIDsForUsers(ctx, affectedUserIDs...) _ = s.cache.InvalidateResidenceIDsForUsers(ctx, affectedUserIDs...)
// All counts (properties + tasks/contractors/documents that lived in
// the deleted residence) just dropped for every member, not only the
// owner.
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, affectedUserIDs...)
} }
// Get updated summary // Get updated summary
@@ -578,8 +584,11 @@ func (s *ResidenceService) JoinWithCode(ctx context.Context, code string, userID
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
if s.cache != nil { if s.cache != nil {
// The joining user's residence-IDs cache is now stale. // The joining user's residence-IDs cache is now stale, and their
// subscription status now reflects an extra residence with all of its
// tasks/contractors/documents.
_ = s.cache.InvalidateResidenceIDsForUsers(ctx, userID) _ = s.cache.InvalidateResidenceIDsForUsers(ctx, userID)
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
} }
// Mark share code as used (one-time use) // Mark share code as used (one-time use)
@@ -663,8 +672,10 @@ func (s *ResidenceService) RemoveUser(ctx context.Context, residenceID, userIDTo
return apperrors.Internal(err) return apperrors.Internal(err)
} }
if s.cache != nil { if s.cache != nil {
// The removed user's residence-IDs cache is now stale. // The removed user lost access to one residence and all of its
// tasks/contractors/documents — their counts must be recomputed.
_ = s.cache.InvalidateResidenceIDsForUsers(ctx, userIDToRemove) _ = s.cache.InvalidateResidenceIDsForUsers(ctx, userIDToRemove)
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userIDToRemove)
} }
return nil return nil
+16
View File
@@ -33,6 +33,17 @@ func (s *StripeService) SetCacheService(cache *CacheService) {
s.cache = cache s.cache = cache
} }
// invalidateStatusCache drops the per-user SubscriptionStatusResponse cache
// after any tier-changing webhook so the next /api/subscription/status/ call
// reflects the new state immediately instead of waiting out the 5-min TTL.
// Best-effort: webhook handlers shouldn't fail just because Redis is down.
func (s *StripeService) invalidateStatusCache(userID uint) {
if s.cache == nil {
return
}
_ = s.cache.InvalidateSubscriptionStatusForUsers(context.Background(), userID)
}
// NewStripeService creates a new Stripe service. It initializes the global // NewStripeService creates a new Stripe service. It initializes the global
// Stripe API key from the STRIPE_SECRET_KEY environment variable. If the key // Stripe API key from the STRIPE_SECRET_KEY environment variable. If the key
// is not set, a warning is logged but the service is still returned (matching // is not set, a warning is logged but the service is still returned (matching
@@ -223,6 +234,7 @@ func (s *StripeService) handleCheckoutCompleted(event stripe.Event) error {
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, models.PlatformStripe); err != nil { if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, models.PlatformStripe); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
s.invalidateStatusCache(userID)
customerID := "" customerID := ""
if session.Customer != nil { if session.Customer != nil {
@@ -264,6 +276,7 @@ func (s *StripeService) handleSubscriptionUpdated(event stripe.Event) error {
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil { if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
s.invalidateStatusCache(sub.UserID)
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("Stripe subscription active") log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("Stripe subscription active")
case stripe.SubscriptionStatusPastDue: case stripe.SubscriptionStatusPastDue:
@@ -282,6 +295,7 @@ func (s *StripeService) handleSubscriptionUpdated(event stripe.Event) error {
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil { if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
s.invalidateStatusCache(sub.UserID)
log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("User downgraded to Free after Stripe subscription ended") log.Info().Uint("user_id", sub.UserID).Str("status", string(subscription.Status)).Msg("User downgraded to Free after Stripe subscription ended")
} }
@@ -314,6 +328,7 @@ func (s *StripeService) handleSubscriptionDeleted(event stripe.Event) error {
if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil { if err := s.subscriptionRepo.DowngradeToFree(sub.UserID); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
s.invalidateStatusCache(sub.UserID)
log.Info().Uint("user_id", sub.UserID).Msg("User downgraded to Free after Stripe subscription deleted") log.Info().Uint("user_id", sub.UserID).Msg("User downgraded to Free after Stripe subscription deleted")
@@ -346,6 +361,7 @@ func (s *StripeService) handleInvoicePaid(event stripe.Event) error {
if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil { if err := s.subscriptionRepo.UpgradeToPro(sub.UserID, expiresAt, models.PlatformStripe); err != nil {
return apperrors.Internal(err) return apperrors.Internal(err)
} }
s.invalidateStatusCache(sub.UserID)
log.Info(). log.Info().
Uint("user_id", sub.UserID). Uint("user_id", sub.UserID).
+95 -23
View File
@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/sync/errgroup"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/apperrors" "github.com/treytartt/honeydue-api/internal/apperrors"
@@ -112,8 +113,24 @@ func (s *SubscriptionService) GetSubscription(ctx context.Context, userID uint)
return NewSubscriptionResponse(sub), nil return NewSubscriptionResponse(sub), nil
} }
// GetSubscriptionStatus gets detailed subscription status including limits // GetSubscriptionStatus gets detailed subscription status including limits.
//
// Hot path on the iOS launch screen — runs 7+ sequential SQL queries against
// transatlantic Neon Postgres at ~110ms RTT each (~800ms floor before
// optimization). The assembled response is cached per-user in Redis with a
// 5-minute TTL; mutation paths (residence/task/contractor/document/sub CRUD)
// invalidate via cache.InvalidateSubscriptionStatusForUsers, fanning out to
// every member of a shared residence.
func (s *SubscriptionService) GetSubscriptionStatus(ctx context.Context, userID uint) (*SubscriptionStatusResponse, error) { func (s *SubscriptionService) GetSubscriptionStatus(ctx context.Context, userID uint) (*SubscriptionStatusResponse, error) {
// Cache fast path — only used on warm reads. Cold reads, trial-start
// branch, and the actual mutation paths below all populate fresh.
if s.cache != nil {
var cached SubscriptionStatusResponse
if err := s.cache.GetCachedSubscriptionStatus(ctx, userID, &cached); err == nil {
return &cached, nil
}
}
sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID) sub, err := s.subscriptionRepo.WithContext(ctx).GetOrCreate(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
@@ -204,43 +221,59 @@ func (s *SubscriptionService) GetSubscriptionStatus(ctx context.Context, userID
resp.TrialActive = sub.IsTrialActive() resp.TrialActive = sub.IsTrialActive()
resp.SubscriptionSource = sub.SubscriptionSource() resp.SubscriptionSource = sub.SubscriptionSource()
// Best-effort cache write. Errors are logged at the cache layer, not fatal.
if s.cache != nil {
_ = s.cache.CacheSubscriptionStatus(ctx, userID, resp)
}
return resp, nil return resp, nil
} }
// getUserUsage calculates current usage for a user. // getUserUsage calculates current usage for a user.
// P-10: Uses CountByOwner for properties count instead of loading all owned residences. //
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)). // Performance: residence ID lookup is one query (we use len() for the
// properties count instead of a redundant COUNT). The three IN-clause counts
// against task_task / task_contractor / task_document don't depend on each
// other and run concurrently via errgroup, collapsing 3 transatlantic RTTs
// into 1. With residence IDs that's 2 RTT total instead of the prior 5.
func (s *SubscriptionService) getUserUsage(ctx context.Context, userID uint) (*UsageResponse, error) { func (s *SubscriptionService) getUserUsage(ctx context.Context, userID uint) (*UsageResponse, error) {
// P-10: Use CountByOwner for an efficient COUNT query instead of loading all records // One query — used both for the properties count (len) and as the IN-list
propertiesCount, err := s.residenceRepo.WithContext(ctx).CountByOwner(userID) // for the three downstream counts. Replaces the prior CountByOwner +
if err != nil { // FindResidenceIDsByOwner pair, which queried residence_residence twice
return nil, apperrors.Internal(err) // with the same predicate.
}
// Still need residence IDs for batch counting tasks/contractors/documents
residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByOwner(userID) residenceIDs, err := s.residenceRepo.WithContext(ctx).FindResidenceIDsByOwner(userID)
if err != nil { if err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Count tasks, contractors, and documents across all residences with single queries each var (
tasksCount, err := s.taskRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs) tasksCount int64
if err != nil { contractorsCount int64
return nil, apperrors.Internal(err) documentsCount int64
} )
contractorsCount, err := s.contractorRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs) g, gCtx := errgroup.WithContext(ctx)
if err != nil { g.Go(func() error {
return nil, apperrors.Internal(err) c, err := s.taskRepo.WithContext(gCtx).CountByResidenceIDs(residenceIDs)
} tasksCount = c
return err
documentsCount, err := s.documentRepo.WithContext(ctx).CountByResidenceIDs(residenceIDs) })
if err != nil { g.Go(func() error {
c, err := s.contractorRepo.WithContext(gCtx).CountByResidenceIDs(residenceIDs)
contractorsCount = c
return err
})
g.Go(func() error {
c, err := s.documentRepo.WithContext(gCtx).CountByResidenceIDs(residenceIDs)
documentsCount = c
return err
})
if err := g.Wait(); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
return &UsageResponse{ return &UsageResponse{
PropertiesCount: propertiesCount, PropertiesCount: int64(len(residenceIDs)),
TasksCount: tasksCount, TasksCount: tasksCount,
ContractorsCount: contractorsCount, ContractorsCount: contractorsCount,
DocumentsCount: documentsCount, DocumentsCount: documentsCount,
@@ -416,6 +449,12 @@ func (s *SubscriptionService) ProcessApplePurchase(ctx context.Context, userID u
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Tier flipped — drop cached SubscriptionStatusResponse so the next call
// returns Pro immediately instead of stale Free.
if s.cache != nil {
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
}
return s.GetSubscription(ctx, userID) return s.GetSubscription(ctx, userID)
} }
@@ -473,6 +512,10 @@ func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
if s.cache != nil {
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
}
return s.GetSubscription(ctx, userID) return s.GetSubscription(ctx, userID)
} }
@@ -481,6 +524,10 @@ func (s *SubscriptionService) CancelSubscription(ctx context.Context, userID uin
if err := s.subscriptionRepo.WithContext(ctx).SetAutoRenew(userID, false); err != nil { if err := s.subscriptionRepo.WithContext(ctx).SetAutoRenew(userID, false); err != nil {
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// auto_renew flips a field surfaced in SubscriptionStatusResponse.
if s.cache != nil {
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
}
return s.GetSubscription(ctx, userID) return s.GetSubscription(ctx, userID)
} }
@@ -657,6 +704,31 @@ func NewUpgradeTriggerDataResponse(t *models.UpgradeTrigger) *UpgradeTriggerData
} }
} }
// invalidateSubStatusForResidence drops the per-user subscription_status cache
// for every user with access to a residence (owner + members from
// residence_residence_users). Used by every mutation that changes shared data
// counts — tasks, contractors, documents — so members of a shared residence
// don't see stale `usage` numbers.
//
// Best-effort: failures are logged but never returned. The 5-min cache TTL is
// the safety net if this ever silently fails.
func invalidateSubStatusForResidence(ctx context.Context, cache *CacheService, residenceRepo *repositories.ResidenceRepository, residenceID uint) {
if cache == nil {
return
}
userIDs, err := residenceRepo.FindUserIDsByResidence(residenceID)
if err != nil {
log.Warn().Err(err).Uint("residence_id", residenceID).Msg("sub_status invalidation: residence lookup failed")
return
}
if len(userIDs) == 0 {
return
}
if err := cache.InvalidateSubscriptionStatusForUsers(ctx, userIDs...); err != nil {
log.Warn().Err(err).Uint("residence_id", residenceID).Msg("sub_status invalidation: redis delete failed")
}
}
// FeatureBenefitResponse represents a feature benefit // FeatureBenefitResponse represents a feature benefit
type FeatureBenefitResponse struct { type FeatureBenefitResponse struct {
FeatureName string `json:"feature_name"` FeatureName string `json:"feature_name"`
+21
View File
@@ -197,6 +197,9 @@ func (s *TaskService) CreateTask(ctx context.Context, req *requests.CreateTaskRe
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// tasks_count for every member of this residence just changed.
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, req.ResidenceID)
return &responses.TaskWithSummaryResponse{ return &responses.TaskWithSummaryResponse{
Data: responses.NewTaskResponseWithTime(task, 30, now), Data: responses.NewTaskResponseWithTime(task, 30, now),
Summary: s.getSummaryForUser(userID), Summary: s.getSummaryForUser(userID),
@@ -273,6 +276,10 @@ func (s *TaskService) BulkCreateTasks(ctx context.Context, req *requests.BulkCre
created = append(created, responses.NewTaskResponseWithTime(t, 30, now)) created = append(created, responses.NewTaskResponseWithTime(t, 30, now))
} }
// One residence per batch, so a single fanout invalidation covers all
// affected users.
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, req.ResidenceID)
return &responses.BulkCreateTasksResponse{ return &responses.BulkCreateTasksResponse{
Tasks: created, Tasks: created,
Summary: s.getSummaryForUser(userID), Summary: s.getSummaryForUser(userID),
@@ -385,6 +392,8 @@ func (s *TaskService) DeleteTask(ctx context.Context, taskID, userID uint) (*res
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, task.ResidenceID)
return &responses.DeleteWithSummaryResponse{ return &responses.DeleteWithSummaryResponse{
Data: "task deleted", Data: "task deleted",
Summary: s.getSummaryForUser(userID), Summary: s.getSummaryForUser(userID),
@@ -469,6 +478,9 @@ func (s *TaskService) CancelTask(ctx context.Context, taskID, userID uint, now t
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// CountByResidenceIDs filters out is_cancelled, so this drops tasks_count.
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, task.ResidenceID)
return &responses.TaskWithSummaryResponse{ return &responses.TaskWithSummaryResponse{
Data: responses.NewTaskResponseWithTime(task, 30, now), Data: responses.NewTaskResponseWithTime(task, 30, now),
Summary: s.getSummaryForUser(userID), Summary: s.getSummaryForUser(userID),
@@ -508,6 +520,9 @@ func (s *TaskService) UncancelTask(ctx context.Context, taskID, userID uint, now
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Reverse of Cancel — tasks_count goes back up.
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, task.ResidenceID)
return &responses.TaskWithSummaryResponse{ return &responses.TaskWithSummaryResponse{
Data: responses.NewTaskResponseWithTime(task, 30, now), Data: responses.NewTaskResponseWithTime(task, 30, now),
Summary: s.getSummaryForUser(userID), Summary: s.getSummaryForUser(userID),
@@ -551,6 +566,9 @@ func (s *TaskService) ArchiveTask(ctx context.Context, taskID, userID uint, now
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Same as Cancel — CountByResidenceIDs filters is_archived too.
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, task.ResidenceID)
return &responses.TaskWithSummaryResponse{ return &responses.TaskWithSummaryResponse{
Data: responses.NewTaskResponseWithTime(task, 30, now), Data: responses.NewTaskResponseWithTime(task, 30, now),
Summary: s.getSummaryForUser(userID), Summary: s.getSummaryForUser(userID),
@@ -590,6 +608,9 @@ func (s *TaskService) UnarchiveTask(ctx context.Context, taskID, userID uint, no
return nil, apperrors.Internal(err) return nil, apperrors.Internal(err)
} }
// Reverse of Archive — tasks_count goes back up.
invalidateSubStatusForResidence(ctx, s.cache, s.residenceRepo, task.ResidenceID)
return &responses.TaskWithSummaryResponse{ return &responses.TaskWithSummaryResponse{
Data: responses.NewTaskResponseWithTime(task, 30, now), Data: responses.NewTaskResponseWithTime(task, 30, now),
Summary: s.getSummaryForUser(userID), Summary: s.getSummaryForUser(userID),