Add webhook logging, pagination, middleware, migrations, and prod hardening

- Webhook event logging repo and subscription webhook idempotency
- Pagination helper (echohelpers) with cursor/offset support
- Request ID and structured logging middleware
- Push client improvements (FCM HTTP v1, better error handling)
- Task model version column, business constraint migrations, targeted indexes
- Expanded categorization chain tests
- Email service and config hardening
- CI workflow updates, .gitignore additions, .env.example updates

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
treyt
2026-02-24 21:32:09 -06:00
parent 806bd07f80
commit e26116e2cf
50 changed files with 1681 additions and 97 deletions

View File

@@ -25,6 +25,7 @@ type Config struct {
GoogleAuth GoogleAuthConfig
AppleIAP AppleIAPConfig
GoogleIAP GoogleIAPConfig
Features FeatureFlags
}
type ServerConfig struct {
@@ -126,6 +127,17 @@ type StorageConfig struct {
AllowedTypes string // Comma-separated MIME types
}
// FeatureFlags holds kill switches for major subsystems.
// All default to true (enabled). Set to false via env vars to disable.
type FeatureFlags struct {
PushEnabled bool // FEATURE_PUSH_ENABLED (default: true)
EmailEnabled bool // FEATURE_EMAIL_ENABLED (default: true)
WebhooksEnabled bool // FEATURE_WEBHOOKS_ENABLED (default: true)
OnboardingEmailsEnabled bool // FEATURE_ONBOARDING_EMAILS_ENABLED (default: true)
PDFReportsEnabled bool // FEATURE_PDF_REPORTS_ENABLED (default: true)
WorkerEnabled bool // FEATURE_WORKER_ENABLED (default: true)
}
var cfg *Config
// Load reads configuration from environment variables
@@ -236,6 +248,14 @@ func Load() (*Config, error) {
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
},
Features: FeatureFlags{
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),
WebhooksEnabled: viper.GetBool("FEATURE_WEBHOOKS_ENABLED"),
OnboardingEmailsEnabled: viper.GetBool("FEATURE_ONBOARDING_EMAILS_ENABLED"),
PDFReportsEnabled: viper.GetBool("FEATURE_PDF_REPORTS_ENABLED"),
WorkerEnabled: viper.GetBool("FEATURE_WORKER_ENABLED"),
},
}
// Validate required fields
@@ -302,6 +322,14 @@ func setDefaults() {
viper.SetDefault("APPLE_IAP_SANDBOX", true) // Default to sandbox for safety
// Google IAP defaults - no defaults needed, will fail gracefully if not configured
// Feature flags (all enabled by default)
viper.SetDefault("FEATURE_PUSH_ENABLED", true)
viper.SetDefault("FEATURE_EMAIL_ENABLED", true)
viper.SetDefault("FEATURE_WEBHOOKS_ENABLED", true)
viper.SetDefault("FEATURE_ONBOARDING_EMAILS_ENABLED", true)
viper.SetDefault("FEATURE_PDF_REPORTS_ENABLED", true)
viper.SetDefault("FEATURE_WORKER_ENABLED", true)
}
func validate(cfg *Config) error {

View File

@@ -13,22 +13,38 @@ import (
"github.com/treytartt/casera-api/internal/models"
)
// zerologGormWriter adapts zerolog for GORM's logger interface
type zerologGormWriter struct{}
func (w zerologGormWriter) Printf(format string, args ...interface{}) {
log.Warn().Msgf(format, args...)
}
var db *gorm.DB
// Connect establishes a connection to the PostgreSQL database
func Connect(cfg *config.DatabaseConfig, debug bool) (*gorm.DB, error) {
// Configure GORM logger
// Configure GORM logger with slow query detection
logLevel := logger.Silent
if debug {
logLevel = logger.Info
}
gormLogger := logger.New(
zerologGormWriter{},
logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logLevel,
IgnoreRecordNotFoundError: true,
},
)
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logLevel),
Logger: gormLogger,
NowFunc: func() time.Time {
return time.Now().UTC()
},
PrepareStmt: true, // Cache prepared statements
PrepareStmt: true,
}
// Connect to database

View File

@@ -0,0 +1,32 @@
package echohelpers
import (
"strconv"
"github.com/labstack/echo/v4"
)
// ParsePagination extracts limit and offset from query parameters with bounded defaults.
// maxLimit caps the maximum page size to prevent unbounded queries.
func ParsePagination(c echo.Context, maxLimit int) (limit, offset int) {
const defaultLimit = 50
limit = defaultLimit
if l := c.QueryParam("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
limit = parsed
}
}
if limit > maxLimit {
limit = maxLimit
}
offset = 0
if o := c.QueryParam("offset"); o != "" {
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
offset = parsed
}
}
return limit, offset
}

View File

@@ -0,0 +1,77 @@
package echohelpers
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestParsePagination(t *testing.T) {
tests := []struct {
name string
query string
maxLimit int
expectedLimit int
expectedOffset int
}{
{
name: "Defaults - no query params",
query: "/",
maxLimit: 200,
expectedLimit: 50,
expectedOffset: 0,
},
{
name: "Custom values",
query: "/?limit=20&offset=10",
maxLimit: 200,
expectedLimit: 20,
expectedOffset: 10,
},
{
name: "Max limit capped",
query: "/?limit=500",
maxLimit: 200,
expectedLimit: 200,
expectedOffset: 0,
},
{
name: "Negative offset ignored",
query: "/?offset=-5",
maxLimit: 200,
expectedLimit: 50,
expectedOffset: 0,
},
{
name: "Invalid limit falls back to default",
query: "/?limit=abc",
maxLimit: 200,
expectedLimit: 50,
expectedOffset: 0,
},
{
name: "Zero limit falls back to default",
query: "/?limit=0",
maxLimit: 200,
expectedLimit: 50,
expectedOffset: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, tt.query, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
limit, offset := ParsePagination(c, tt.maxLimit)
assert.Equal(t, tt.expectedLimit, limit, "limit mismatch")
assert.Equal(t, tt.expectedOffset, offset, "offset mismatch")
})
}
}

View File

@@ -17,17 +17,19 @@ import (
// ResidenceHandler handles residence-related HTTP requests
type ResidenceHandler struct {
residenceService *services.ResidenceService
pdfService *services.PDFService
emailService *services.EmailService
residenceService *services.ResidenceService
pdfService *services.PDFService
emailService *services.EmailService
pdfReportsEnabled bool
}
// NewResidenceHandler creates a new residence handler
func NewResidenceHandler(residenceService *services.ResidenceService, pdfService *services.PDFService, emailService *services.EmailService) *ResidenceHandler {
func NewResidenceHandler(residenceService *services.ResidenceService, pdfService *services.PDFService, emailService *services.EmailService, pdfReportsEnabled bool) *ResidenceHandler {
return &ResidenceHandler{
residenceService: residenceService,
pdfService: pdfService,
emailService: emailService,
residenceService: residenceService,
pdfService: pdfService,
emailService: emailService,
pdfReportsEnabled: pdfReportsEnabled,
}
}
@@ -283,6 +285,10 @@ func (h *ResidenceHandler) GetResidenceTypes(c echo.Context) error {
// GenerateTasksReport handles POST /api/residences/:id/generate-tasks-report/
// Generates a PDF report of tasks for the residence and emails it
func (h *ResidenceHandler) GenerateTasksReport(c echo.Context) error {
if !h.pdfReportsEnabled {
return apperrors.BadRequest("error.feature_disabled")
}
user := c.Get(middleware.AuthUserKey).(*models.User)
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)

View File

@@ -25,7 +25,7 @@ func setupResidenceHandler(t *testing.T) (*ResidenceHandler, *echo.Echo, *gorm.D
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg)
handler := NewResidenceHandler(residenceService, nil, nil)
handler := NewResidenceHandler(residenceService, nil, nil, true)
e := testutil.SetupTestRouter()
return handler, e, db
}

View File

@@ -26,17 +26,23 @@ import (
type SubscriptionWebhookHandler struct {
subscriptionRepo *repositories.SubscriptionRepository
userRepo *repositories.UserRepository
webhookEventRepo *repositories.WebhookEventRepository
appleRootCerts []*x509.Certificate
enabled bool
}
// NewSubscriptionWebhookHandler creates a new webhook handler
func NewSubscriptionWebhookHandler(
subscriptionRepo *repositories.SubscriptionRepository,
userRepo *repositories.UserRepository,
webhookEventRepo *repositories.WebhookEventRepository,
enabled bool,
) *SubscriptionWebhookHandler {
return &SubscriptionWebhookHandler{
subscriptionRepo: subscriptionRepo,
userRepo: userRepo,
webhookEventRepo: webhookEventRepo,
enabled: enabled,
}
}
@@ -94,6 +100,11 @@ type AppleRenewalInfo struct {
// HandleAppleWebhook handles POST /api/subscription/webhook/apple/
func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
if !h.enabled {
log.Printf("Apple Webhook: webhooks disabled by feature flag")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
}
body, err := io.ReadAll(c.Request().Body)
if err != nil {
log.Printf("Apple Webhook: Failed to read body: %v", err)
@@ -116,6 +127,18 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
log.Printf("Apple Webhook: Received %s (subtype: %s) for bundle %s",
notification.NotificationType, notification.Subtype, notification.Data.BundleID)
// Dedup check using notificationUUID
if notification.NotificationUUID != "" {
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("apple", notification.NotificationUUID)
if err != nil {
log.Printf("Apple Webhook: Failed to check dedup: %v", err)
// Continue processing on dedup check failure (fail-open)
} else if alreadyProcessed {
log.Printf("Apple Webhook: Duplicate event %s, skipping", notification.NotificationUUID)
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
}
}
// Verify bundle ID matches our app
cfg := config.Get()
if cfg != nil && cfg.AppleIAP.BundleID != "" {
@@ -145,6 +168,13 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
// Still return 200 to prevent Apple from retrying
}
// Record processed event for dedup
if notification.NotificationUUID != "" {
if err := h.webhookEventRepo.RecordEvent("apple", notification.NotificationUUID, notification.NotificationType, ""); err != nil {
log.Printf("Apple Webhook: Failed to record event: %v", err)
}
}
// Always return 200 OK to acknowledge receipt
return c.JSON(http.StatusOK, map[string]interface{}{"status": "received"})
}
@@ -450,6 +480,11 @@ const (
// HandleGoogleWebhook handles POST /api/subscription/webhook/google/
func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
if !h.enabled {
log.Printf("Google Webhook: webhooks disabled by feature flag")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
}
body, err := io.ReadAll(c.Request().Body)
if err != nil {
log.Printf("Google Webhook: Failed to read body: %v", err)
@@ -475,6 +510,19 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid developer notification"})
}
// Dedup check using messageId
messageID := notification.Message.MessageID
if messageID != "" {
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("google", messageID)
if err != nil {
log.Printf("Google Webhook: Failed to check dedup: %v", err)
// Continue processing on dedup check failure (fail-open)
} else if alreadyProcessed {
log.Printf("Google Webhook: Duplicate event %s, skipping", messageID)
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
}
}
// Handle test notification
if devNotification.TestNotification != nil {
log.Printf("Google Webhook: Received test notification")
@@ -499,6 +547,17 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
}
}
// Record processed event for dedup
if messageID != "" {
eventType := "unknown"
if devNotification.SubscriptionNotification != nil {
eventType = fmt.Sprintf("subscription_%d", devNotification.SubscriptionNotification.NotificationType)
}
if err := h.webhookEventRepo.RecordEvent("google", messageID, eventType, ""); err != nil {
log.Printf("Google Webhook: Failed to record event: %v", err)
}
}
// Acknowledge the message
return c.JSON(http.StatusOK, map[string]interface{}{"status": "received"})
}

View File

@@ -358,7 +358,7 @@ func TestTaskHandler_UncancelTask(t *testing.T) {
// Cancel first
taskRepo := repositories.NewTaskRepository(db)
taskRepo.Cancel(task.ID)
taskRepo.Cancel(task.ID, task.Version)
authGroup := e.Group("/api/tasks")
authGroup.Use(testutil.MockAuthMiddleware(user))
@@ -418,7 +418,7 @@ func TestTaskHandler_UnarchiveTask(t *testing.T) {
// Archive first
taskRepo := repositories.NewTaskRepository(db)
taskRepo.Archive(task.ID)
taskRepo.Archive(task.ID, task.Version)
authGroup := e.Group("/api/tasks")
authGroup.Use(testutil.MockAuthMiddleware(user))

View File

@@ -16,7 +16,7 @@ var translationFS embed.FS
var Bundle *i18n.Bundle
// SupportedLanguages lists all supported language codes
var SupportedLanguages = []string{"en", "es", "fr", "de", "pt"}
var SupportedLanguages = []string{"en", "es", "fr", "de", "pt", "it", "ja", "ko", "nl", "zh"}
// DefaultLanguage is the fallback language
const DefaultLanguage = "en"

View File

@@ -137,7 +137,7 @@ func setupIntegrationTest(t *testing.T) *TestApp {
// Create handlers
authHandler := handlers.NewAuthHandler(authService, nil, nil)
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil)
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
taskHandler := handlers.NewTaskHandler(taskService, nil)
contractorHandler := handlers.NewContractorHandler(contractorService)
@@ -1621,7 +1621,7 @@ func setupContractorTest(t *testing.T) *TestApp {
// Create handlers
authHandler := handlers.NewAuthHandler(authService, nil, nil)
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil)
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
taskHandler := handlers.NewTaskHandler(taskService, nil)
contractorHandler := handlers.NewContractorHandler(contractorService)

View File

@@ -63,7 +63,7 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
// Create handlers
authHandler := handlers.NewAuthHandler(authService, nil, nil)
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil)
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
// Create router

View File

@@ -0,0 +1,53 @@
package middleware
import (
"time"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"github.com/treytartt/casera-api/internal/models"
)
// StructuredLogger is zerolog-based request logging middleware that includes
// correlation IDs, user IDs, and latency metrics.
func StructuredLogger() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
start := time.Now()
err := next(c)
latency := time.Since(start)
// Build structured log event
event := log.Info()
if c.Response().Status >= 500 {
event = log.Error()
} else if c.Response().Status >= 400 {
event = log.Warn()
}
// Request ID
if reqID := GetRequestID(c); reqID != "" {
event = event.Str("request_id", reqID)
}
// User ID (from auth middleware)
if user, ok := c.Get(AuthUserKey).(*models.User); ok && user != nil {
event = event.Uint("user_id", user.ID)
}
event.
Str("method", c.Request().Method).
Str("path", c.Path()).
Str("uri", c.Request().RequestURI).
Int("status", c.Response().Status).
Int64("latency_ms", latency.Milliseconds()).
Str("remote_ip", c.RealIP()).
Msg("request")
return err
}
}
}

View File

@@ -0,0 +1,43 @@
package middleware
import (
"github.com/google/uuid"
"github.com/labstack/echo/v4"
)
const (
// HeaderXRequestID is the header key for request correlation IDs
HeaderXRequestID = "X-Request-ID"
// ContextKeyRequestID is the echo context key for the request ID
ContextKeyRequestID = "request_id"
)
// RequestIDMiddleware generates a UUID per request, sets it as X-Request-ID header,
// and stores it in the echo context for downstream use.
func RequestIDMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Use existing request ID from header if present, otherwise generate one
reqID := c.Request().Header.Get(HeaderXRequestID)
if reqID == "" {
reqID = uuid.New().String()
}
// Store in context
c.Set(ContextKeyRequestID, reqID)
// Set response header
c.Response().Header().Set(HeaderXRequestID, reqID)
return next(c)
}
}
}
// GetRequestID extracts the request ID from the echo context
func GetRequestID(c echo.Context) string {
if id, ok := c.Get(ContextKeyRequestID).(string); ok {
return id
}
return ""
}

View File

@@ -85,6 +85,9 @@ type Task struct {
IsCancelled bool `gorm:"column:is_cancelled;default:false;index" json:"is_cancelled"`
IsArchived bool `gorm:"column:is_archived;default:false;index" json:"is_archived"`
// Optimistic locking version
Version int `gorm:"column:version;not null;default:1" json:"-"`
// Parent task for recurring tasks
ParentTaskID *uint `gorm:"column:parent_task_id;index" json:"parent_task_id"`
ParentTask *Task `gorm:"foreignKey:ParentTaskID" json:"parent_task,omitempty"`

View File

@@ -16,13 +16,14 @@ const (
// Client provides a unified interface for sending push notifications
type Client struct {
apns *APNsClient
fcm *FCMClient
apns *APNsClient
fcm *FCMClient
enabled bool
}
// NewClient creates a new unified push notification client
func NewClient(cfg *config.PushConfig) (*Client, error) {
client := &Client{}
func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) {
client := &Client{enabled: enabled}
// Initialize APNs client (iOS)
if cfg.APNSKeyPath != "" && cfg.APNSKeyID != "" && cfg.APNSTeamID != "" {
@@ -55,6 +56,10 @@ func NewClient(cfg *config.PushConfig) (*Client, error) {
// SendToIOS sends a push notification to iOS devices
func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
if !c.enabled {
log.Debug().Msg("Push notifications disabled by feature flag")
return nil
}
if c.apns == nil {
log.Warn().Msg("APNs client not initialized, skipping iOS push")
return nil
@@ -64,6 +69,10 @@ func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message
// SendToAndroid sends a push notification to Android devices
func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
if !c.enabled {
log.Debug().Msg("Push notifications disabled by feature flag")
return nil
}
if c.fcm == nil {
log.Warn().Msg("FCM client not initialized, skipping Android push")
return nil
@@ -73,6 +82,10 @@ func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, mess
// SendToAll sends a push notification to both iOS and Android devices
func (c *Client) SendToAll(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string) error {
if !c.enabled {
log.Debug().Msg("Push notifications disabled by feature flag")
return nil
}
var lastErr error
if len(iosTokens) > 0 {
@@ -105,6 +118,10 @@ func (c *Client) IsAndroidEnabled() bool {
// SendActionableNotification sends notifications with action button support
// iOS receives a category for actionable notifications, Android handles actions via data payload
func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string, iosCategoryID string) error {
if !c.enabled {
log.Debug().Msg("Push notifications disabled by feature flag")
return nil
}
var lastErr error
if len(iosTokens) > 0 {

View File

@@ -57,12 +57,19 @@ func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error {
return r.db.Save(sub).Error
}
// UpgradeToPro upgrades a user to Pro tier
// UpgradeToPro upgrades a user to Pro tier using a transaction with row locking
// to prevent concurrent subscription mutations from corrupting state.
func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time, platform string) error {
now := time.Now().UTC()
return r.db.Model(&models.UserSubscription{}).
Where("user_id = ?", userID).
Updates(map[string]interface{}{
return r.db.Transaction(func(tx *gorm.DB) error {
// Lock the row for update
var sub models.UserSubscription
if err := tx.Set("gorm:query_option", "FOR UPDATE").
Where("user_id = ?", userID).First(&sub).Error; err != nil {
return err
}
now := time.Now().UTC()
return tx.Model(&sub).Updates(map[string]interface{}{
"tier": models.TierPro,
"subscribed_at": now,
"expires_at": expiresAt,
@@ -70,18 +77,27 @@ func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time,
"platform": platform,
"auto_renew": true,
}).Error
})
}
// DowngradeToFree downgrades a user to Free tier
// DowngradeToFree downgrades a user to Free tier using a transaction with row locking
// to prevent concurrent subscription mutations from corrupting state.
func (r *SubscriptionRepository) DowngradeToFree(userID uint) error {
now := time.Now().UTC()
return r.db.Model(&models.UserSubscription{}).
Where("user_id = ?", userID).
Updates(map[string]interface{}{
return r.db.Transaction(func(tx *gorm.DB) error {
// Lock the row for update
var sub models.UserSubscription
if err := tx.Set("gorm:query_option", "FOR UPDATE").
Where("user_id = ?", userID).First(&sub).Error; err != nil {
return err
}
now := time.Now().UTC()
return tx.Model(&sub).Updates(map[string]interface{}{
"tier": models.TierFree,
"cancelled_at": now,
"auto_renew": false,
}).Error
})
}
// SetAutoRenew sets the auto-renew flag

View File

@@ -1,6 +1,7 @@
package repositories
import (
"errors"
"fmt"
"time"
@@ -11,6 +12,9 @@ import (
"github.com/treytartt/casera-api/internal/task/categorization"
)
// ErrVersionConflict indicates a concurrent modification was detected
var ErrVersionConflict = errors.New("version conflict: task was modified by another request")
// TaskRepository handles database operations for tasks
type TaskRepository struct {
db *gorm.DB
@@ -294,10 +298,39 @@ func (r *TaskRepository) Create(task *models.Task) error {
return r.db.Create(task).Error
}
// Update updates a task
// Uses Omit to exclude associations that shouldn't be updated via Save
// Update updates a task with optimistic locking.
// The update only succeeds if the task's version in the database matches the expected version.
// On success, the local task.Version is incremented to reflect the new version.
func (r *TaskRepository) Update(task *models.Task) error {
return r.db.Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions").Save(task).Error
result := r.db.Model(task).
Where("id = ? AND version = ?", task.ID, task.Version).
Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions").
Updates(map[string]interface{}{
"title": task.Title,
"description": task.Description,
"category_id": task.CategoryID,
"priority_id": task.PriorityID,
"frequency_id": task.FrequencyID,
"custom_interval_days": task.CustomIntervalDays,
"in_progress": task.InProgress,
"assigned_to_id": task.AssignedToID,
"due_date": task.DueDate,
"next_due_date": task.NextDueDate,
"estimated_cost": task.EstimatedCost,
"actual_cost": task.ActualCost,
"contractor_id": task.ContractorID,
"is_cancelled": task.IsCancelled,
"is_archived": task.IsArchived,
"version": gorm.Expr("version + 1"),
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrVersionConflict
}
task.Version++ // Update local copy
return nil
}
// Delete hard-deletes a task
@@ -307,39 +340,89 @@ func (r *TaskRepository) Delete(id uint) error {
// === Task State Operations ===
// MarkInProgress marks a task as in progress
func (r *TaskRepository) MarkInProgress(id uint) error {
return r.db.Model(&models.Task{}).
Where("id = ?", id).
Update("in_progress", true).Error
// MarkInProgress marks a task as in progress with optimistic locking.
func (r *TaskRepository) MarkInProgress(id uint, version int) error {
result := r.db.Model(&models.Task{}).
Where("id = ? AND version = ?", id, version).
Updates(map[string]interface{}{
"in_progress": true,
"version": gorm.Expr("version + 1"),
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrVersionConflict
}
return nil
}
// Cancel cancels a task
func (r *TaskRepository) Cancel(id uint) error {
return r.db.Model(&models.Task{}).
Where("id = ?", id).
Update("is_cancelled", true).Error
// Cancel cancels a task with optimistic locking.
func (r *TaskRepository) Cancel(id uint, version int) error {
result := r.db.Model(&models.Task{}).
Where("id = ? AND version = ?", id, version).
Updates(map[string]interface{}{
"is_cancelled": true,
"version": gorm.Expr("version + 1"),
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrVersionConflict
}
return nil
}
// Uncancel uncancels a task
func (r *TaskRepository) Uncancel(id uint) error {
return r.db.Model(&models.Task{}).
Where("id = ?", id).
Update("is_cancelled", false).Error
// Uncancel uncancels a task with optimistic locking.
func (r *TaskRepository) Uncancel(id uint, version int) error {
result := r.db.Model(&models.Task{}).
Where("id = ? AND version = ?", id, version).
Updates(map[string]interface{}{
"is_cancelled": false,
"version": gorm.Expr("version + 1"),
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrVersionConflict
}
return nil
}
// Archive archives a task
func (r *TaskRepository) Archive(id uint) error {
return r.db.Model(&models.Task{}).
Where("id = ?", id).
Update("is_archived", true).Error
// Archive archives a task with optimistic locking.
func (r *TaskRepository) Archive(id uint, version int) error {
result := r.db.Model(&models.Task{}).
Where("id = ? AND version = ?", id, version).
Updates(map[string]interface{}{
"is_archived": true,
"version": gorm.Expr("version + 1"),
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrVersionConflict
}
return nil
}
// Unarchive unarchives a task
func (r *TaskRepository) Unarchive(id uint) error {
return r.db.Model(&models.Task{}).
Where("id = ?", id).
Update("is_archived", false).Error
// Unarchive unarchives a task with optimistic locking.
func (r *TaskRepository) Unarchive(id uint, version int) error {
result := r.db.Model(&models.Task{}).
Where("id = ? AND version = ?", id, version).
Updates(map[string]interface{}{
"is_archived": false,
"version": gorm.Expr("version + 1"),
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrVersionConflict
}
return nil
}
// === Kanban Board ===

View File

@@ -113,7 +113,7 @@ func TestTaskRepository_Cancel(t *testing.T) {
assert.False(t, task.IsCancelled)
err := repo.Cancel(task.ID)
err := repo.Cancel(task.ID, task.Version)
require.NoError(t, err)
found, err := repo.FindByID(task.ID)
@@ -129,8 +129,8 @@ func TestTaskRepository_Uncancel(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
repo.Cancel(task.ID)
err := repo.Uncancel(task.ID)
repo.Cancel(task.ID, task.Version)
err := repo.Uncancel(task.ID, task.Version+1) // version incremented by Cancel
require.NoError(t, err)
found, err := repo.FindByID(task.ID)
@@ -146,7 +146,7 @@ func TestTaskRepository_Archive(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
err := repo.Archive(task.ID)
err := repo.Archive(task.ID, task.Version)
require.NoError(t, err)
found, err := repo.FindByID(task.ID)
@@ -162,8 +162,8 @@ func TestTaskRepository_Unarchive(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
repo.Archive(task.ID)
err := repo.Unarchive(task.ID)
repo.Archive(task.ID, task.Version)
err := repo.Unarchive(task.ID, task.Version+1) // version incremented by Archive
require.NoError(t, err)
found, err := repo.FindByID(task.ID)
@@ -316,7 +316,7 @@ func TestKanbanBoard_CancelledTasksHiddenFromKanbanBoard(t *testing.T) {
// Create a cancelled task
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Cancelled Task")
repo.Cancel(task.ID)
repo.Cancel(task.ID, task.Version)
board, err := repo.GetKanbanData(residence.ID, 30, time.Now().UTC())
require.NoError(t, err)
@@ -571,7 +571,7 @@ func TestKanbanBoard_ArchivedTasksHiddenFromKanbanBoard(t *testing.T) {
// Create a regular task and an archived task
testutil.CreateTestTask(t, db, residence.ID, user.ID, "Regular Task")
archivedTask := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Archived Task")
repo.Archive(archivedTask.ID)
repo.Archive(archivedTask.ID, archivedTask.Version)
board, err := repo.GetKanbanData(residence.ID, 30, time.Now().UTC())
require.NoError(t, err)
@@ -856,7 +856,7 @@ func TestKanbanBoard_MultipleResidences(t *testing.T) {
// Create a cancelled task in house 1
cancelledTask := testutil.CreateTestTask(t, db, residence1.ID, user.ID, "Cancelled in House 1")
repo.Cancel(cancelledTask.ID)
repo.Cancel(cancelledTask.ID, cancelledTask.Version)
board, err := repo.GetKanbanDataForMultipleResidences([]uint{residence1.ID, residence2.ID}, 30, time.Now().UTC())
require.NoError(t, err)

View File

@@ -0,0 +1,54 @@
package repositories
import (
"time"
"gorm.io/gorm"
)
// WebhookEvent represents a processed webhook event for deduplication
type WebhookEvent struct {
ID uint `gorm:"primaryKey"`
EventID string `gorm:"column:event_id;size:255;not null;uniqueIndex:idx_provider_event_id"`
Provider string `gorm:"column:provider;size:20;not null;uniqueIndex:idx_provider_event_id"`
EventType string `gorm:"column:event_type;size:100;not null"`
ProcessedAt time.Time `gorm:"column:processed_at;autoCreateTime"`
PayloadHash string `gorm:"column:payload_hash;size:64"`
}
func (WebhookEvent) TableName() string {
return "webhook_event_log"
}
// WebhookEventRepository handles webhook event deduplication
type WebhookEventRepository struct {
db *gorm.DB
}
// NewWebhookEventRepository creates a new webhook event repository
func NewWebhookEventRepository(db *gorm.DB) *WebhookEventRepository {
return &WebhookEventRepository{db: db}
}
// HasProcessed checks if an event has already been processed
func (r *WebhookEventRepository) HasProcessed(provider, eventID string) (bool, error) {
var count int64
err := r.db.Model(&WebhookEvent{}).
Where("provider = ? AND event_id = ?", provider, eventID).
Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
// RecordEvent records a processed webhook event
func (r *WebhookEventRepository) RecordEvent(provider, eventID, eventType, payloadHash string) error {
event := &WebhookEvent{
EventID: eventID,
Provider: provider,
EventType: eventType,
PayloadHash: payloadHash,
}
return r.db.Create(event).Error
}

View File

@@ -0,0 +1,104 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// setupWebhookTestDB creates an in-memory SQLite database with the
// WebhookEvent table auto-migrated. This is separate from testutil.SetupTestDB
// because WebhookEvent lives in the repositories package (not models/) and
// only needs its own table for testing.
func setupWebhookTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
err = db.AutoMigrate(&WebhookEvent{})
require.NoError(t, err)
return db
}
func TestWebhookEventRepo_RecordAndCheck(t *testing.T) {
db := setupWebhookTestDB(t)
repo := NewWebhookEventRepository(db)
// Record an event
err := repo.RecordEvent("apple", "evt_001", "INITIAL_BUY", "abc123hash")
require.NoError(t, err)
// HasProcessed should return true for the same provider + event ID
processed, err := repo.HasProcessed("apple", "evt_001")
require.NoError(t, err)
assert.True(t, processed, "expected HasProcessed to return true for a recorded event")
// HasProcessed should return false for a different event ID
processed, err = repo.HasProcessed("apple", "evt_999")
require.NoError(t, err)
assert.False(t, processed, "expected HasProcessed to return false for an unrecorded event ID")
// HasProcessed should return false for a different provider
processed, err = repo.HasProcessed("google", "evt_001")
require.NoError(t, err)
assert.False(t, processed, "expected HasProcessed to return false for a different provider")
}
func TestWebhookEventRepo_DuplicateInsert(t *testing.T) {
db := setupWebhookTestDB(t)
repo := NewWebhookEventRepository(db)
// First insert should succeed
err := repo.RecordEvent("apple", "evt_dup", "RENEWAL", "hash1")
require.NoError(t, err)
// Second insert with the same provider + event ID should fail (unique constraint)
err = repo.RecordEvent("apple", "evt_dup", "RENEWAL", "hash1")
require.Error(t, err, "expected an error when inserting a duplicate provider + event_id")
// Verify only one row exists
var count int64
db.Model(&WebhookEvent{}).Where("provider = ? AND event_id = ?", "apple", "evt_dup").Count(&count)
assert.Equal(t, int64(1), count, "expected exactly one row for the duplicated event")
}
func TestWebhookEventRepo_DifferentProviders(t *testing.T) {
db := setupWebhookTestDB(t)
repo := NewWebhookEventRepository(db)
sharedEventID := "evt_shared_123"
// Record event for "apple" provider
err := repo.RecordEvent("apple", sharedEventID, "INITIAL_BUY", "applehash")
require.NoError(t, err)
// HasProcessed should return true for "apple"
processed, err := repo.HasProcessed("apple", sharedEventID)
require.NoError(t, err)
assert.True(t, processed, "expected HasProcessed to return true for apple provider")
// HasProcessed should return false for "google" with the same event ID
processed, err = repo.HasProcessed("google", sharedEventID)
require.NoError(t, err)
assert.False(t, processed, "expected HasProcessed to return false for google provider with the same event ID")
// Recording the same event ID under "google" should succeed (different provider)
err = repo.RecordEvent("google", sharedEventID, "INITIAL_BUY", "googlehash")
require.NoError(t, err)
// Now both providers should show as processed
processed, err = repo.HasProcessed("apple", sharedEventID)
require.NoError(t, err)
assert.True(t, processed, "expected apple to still be processed")
processed, err = repo.HasProcessed("google", sharedEventID)
require.NoError(t, err)
assert.True(t, processed, "expected google to now be processed")
}

View File

@@ -54,8 +54,13 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
// which don't use trailing slashes. Mobile API routes explicitly include trailing slashes.
// Global middleware
e.Use(custommiddleware.RequestIDMiddleware())
e.Use(utils.EchoRecovery())
e.Use(utils.EchoLogger())
e.Use(custommiddleware.StructuredLogger())
e.Use(middleware.BodyLimit("1M")) // 1MB default for JSON payloads
e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{
Timeout: 30 * time.Second,
}))
e.Use(corsMiddleware(cfg))
e.Use(i18n.Middleware())
@@ -126,8 +131,11 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
taskTemplateService := services.NewTaskTemplateService(taskTemplateRepo)
// Initialize webhook event repo for deduplication
webhookEventRepo := repositories.NewWebhookEventRepository(deps.DB)
// Initialize webhook handler for Apple/Google subscription notifications
subscriptionWebhookHandler := handlers.NewSubscriptionWebhookHandler(subscriptionRepo, userRepo)
subscriptionWebhookHandler := handlers.NewSubscriptionWebhookHandler(subscriptionRepo, userRepo, webhookEventRepo, cfg.Features.WebhooksEnabled)
// Initialize middleware
authMiddleware := custommiddleware.NewAuthMiddleware(deps.DB, deps.Cache)
@@ -141,7 +149,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
authHandler.SetAppleAuthService(appleAuthService)
authHandler.SetGoogleAuthService(googleAuthService)
userHandler := handlers.NewUserHandler(userService)
residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService)
residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService, cfg.Features.PDFReportsEnabled)
taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService)
contractorHandler := handlers.NewContractorHandler(contractorService)
documentHandler := handlers.NewDocumentHandler(documentService, deps.StorageService)

View File

@@ -15,22 +15,28 @@ import (
// EmailService handles sending emails
type EmailService struct {
cfg *config.EmailConfig
dialer *gomail.Dialer
cfg *config.EmailConfig
dialer *gomail.Dialer
enabled bool
}
// NewEmailService creates a new email service
func NewEmailService(cfg *config.EmailConfig) *EmailService {
func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService {
dialer := gomail.NewDialer(cfg.Host, cfg.Port, cfg.User, cfg.Password)
return &EmailService{
cfg: cfg,
dialer: dialer,
cfg: cfg,
dialer: dialer,
enabled: enabled,
}
}
// SendEmail sends an email
func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error {
if !s.enabled {
log.Debug().Msg("Email sending disabled by feature flag")
return nil
}
m := gomail.NewMessage()
m.SetHeader("From", s.cfg.From)
m.SetHeader("To", to)
@@ -64,6 +70,10 @@ type EmbeddedImage struct {
// SendEmailWithAttachment sends an email with an attachment
func (s *EmailService) SendEmailWithAttachment(to, subject, htmlBody, textBody string, attachment *EmailAttachment) error {
if !s.enabled {
log.Debug().Msg("Email sending disabled by feature flag")
return nil
}
m := gomail.NewMessage()
m.SetHeader("From", s.cfg.From)
m.SetHeader("To", to)
@@ -94,6 +104,10 @@ func (s *EmailService) SendEmailWithAttachment(to, subject, htmlBody, textBody s
// SendEmailWithEmbeddedImages sends an email with inline embedded images
func (s *EmailService) SendEmailWithEmbeddedImages(to, subject, htmlBody, textBody string, images []EmbeddedImage) error {
if !s.enabled {
log.Debug().Msg("Email sending disabled by feature flag")
return nil
}
m := gomail.NewMessage()
m.SetHeader("From", s.cfg.From)
m.SetHeader("To", to)

View File

@@ -271,6 +271,9 @@ func (s *TaskService) UpdateTask(taskID, userID uint, req *requests.UpdateTaskRe
}
if err := s.taskRepo.Update(task); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
return nil, apperrors.Conflict("error.version_conflict")
}
return nil, apperrors.Internal(err)
}
@@ -337,7 +340,10 @@ func (s *TaskService) MarkInProgress(taskID, userID uint, now time.Time) (*respo
return nil, apperrors.Forbidden("error.task_access_denied")
}
if err := s.taskRepo.MarkInProgress(taskID); err != nil {
if err := s.taskRepo.MarkInProgress(taskID, task.Version); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
return nil, apperrors.Conflict("error.version_conflict")
}
return nil, apperrors.Internal(err)
}
@@ -377,7 +383,10 @@ func (s *TaskService) CancelTask(taskID, userID uint, now time.Time) (*responses
return nil, apperrors.BadRequest("error.task_already_cancelled")
}
if err := s.taskRepo.Cancel(taskID); err != nil {
if err := s.taskRepo.Cancel(taskID, task.Version); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
return nil, apperrors.Conflict("error.version_conflict")
}
return nil, apperrors.Internal(err)
}
@@ -413,7 +422,10 @@ func (s *TaskService) UncancelTask(taskID, userID uint, now time.Time) (*respons
return nil, apperrors.Forbidden("error.task_access_denied")
}
if err := s.taskRepo.Uncancel(taskID); err != nil {
if err := s.taskRepo.Uncancel(taskID, task.Version); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
return nil, apperrors.Conflict("error.version_conflict")
}
return nil, apperrors.Internal(err)
}
@@ -453,7 +465,10 @@ func (s *TaskService) ArchiveTask(taskID, userID uint, now time.Time) (*response
return nil, apperrors.BadRequest("error.task_already_archived")
}
if err := s.taskRepo.Archive(taskID); err != nil {
if err := s.taskRepo.Archive(taskID, task.Version); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
return nil, apperrors.Conflict("error.version_conflict")
}
return nil, apperrors.Internal(err)
}
@@ -489,7 +504,10 @@ func (s *TaskService) UnarchiveTask(taskID, userID uint, now time.Time) (*respon
return nil, apperrors.Forbidden("error.task_access_denied")
}
if err := s.taskRepo.Unarchive(taskID); err != nil {
if err := s.taskRepo.Unarchive(taskID, task.Version); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
return nil, apperrors.Conflict("error.version_conflict")
}
return nil, apperrors.Internal(err)
}
@@ -581,6 +599,9 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
task.InProgress = false
}
if err := s.taskRepo.Update(task); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
return nil, apperrors.Conflict("error.version_conflict")
}
log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after completion")
}
@@ -702,6 +723,9 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
task.InProgress = false
}
if err := s.taskRepo.Update(task); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
return apperrors.Conflict("error.version_conflict")
}
log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after quick completion")
return apperrors.Internal(err) // Return error so caller knows the update failed
}

View File

@@ -0,0 +1,241 @@
package categorization_test
import (
"math/rand"
"testing"
"time"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/task/categorization"
)
// validColumns is the complete set of KanbanColumn values the chain may return.
var validColumns = map[categorization.KanbanColumn]bool{
categorization.ColumnOverdue: true,
categorization.ColumnDueSoon: true,
categorization.ColumnUpcoming: true,
categorization.ColumnInProgress: true,
categorization.ColumnCompleted: true,
categorization.ColumnCancelled: true,
}
// FuzzCategorizeTask feeds random task states into CategorizeTask and asserts
// that the result is always a non-empty, valid KanbanColumn constant.
func FuzzCategorizeTask(f *testing.F) {
f.Add(false, false, false, false, false, 0, false, 0)
f.Add(true, false, false, false, false, 0, false, 0)
f.Add(false, true, false, false, false, 0, false, 0)
f.Add(false, false, true, false, false, 0, false, 0)
f.Add(false, false, false, true, false, 0, false, 0)
f.Add(false, false, false, false, true, -5, false, 0)
f.Add(false, false, false, false, false, 0, true, -5)
f.Add(false, false, false, false, false, 0, true, 5)
f.Add(false, false, false, false, false, 0, true, 60)
f.Add(true, true, true, true, true, -10, true, -10)
f.Add(false, false, false, false, true, 100, true, 100)
f.Fuzz(func(t *testing.T,
isCancelled, isArchived, inProgress, hasCompletions bool,
hasDueDate bool, dueDateOffsetDays int,
hasNextDueDate bool, nextDueDateOffsetDays int,
) {
now := time.Date(2025, 12, 15, 0, 0, 0, 0, time.UTC)
task := &models.Task{
IsCancelled: isCancelled,
IsArchived: isArchived,
InProgress: inProgress,
}
if hasDueDate {
d := now.AddDate(0, 0, dueDateOffsetDays)
task.DueDate = &d
}
if hasNextDueDate {
d := now.AddDate(0, 0, nextDueDateOffsetDays)
task.NextDueDate = &d
}
if hasCompletions {
task.Completions = []models.TaskCompletion{
{BaseModel: models.BaseModel{ID: 1}},
}
} else {
task.Completions = []models.TaskCompletion{}
}
result := categorization.CategorizeTask(task, 30)
if result == "" {
t.Fatalf("CategorizeTask returned empty string for task %+v", task)
}
if !validColumns[result] {
t.Fatalf("CategorizeTask returned invalid column %q for task %+v", result, task)
}
})
}
// === Property Tests (1000 random tasks) ===
// TestCategorizeTask_PropertyEveryTaskMapsToExactlyOneColumn uses random tasks
// to validate the property that every task maps to exactly one column.
func TestCategorizeTask_PropertyEveryTaskMapsToExactlyOneColumn(t *testing.T) {
rng := rand.New(rand.NewSource(42)) // Deterministic seed for reproducibility
now := time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC)
for i := 0; i < 1000; i++ {
task := randomTask(rng, now)
column := categorization.CategorizeTask(task, 30)
if !validColumns[column] {
t.Fatalf("Task %d mapped to invalid column %q: %+v", i, column, task)
}
}
}
// TestCategorizeTask_CancelledAlwaysWins validates that cancelled takes priority
// over all other states regardless of other flags using randomized tasks.
func TestCategorizeTask_CancelledAlwaysWins(t *testing.T) {
rng := rand.New(rand.NewSource(42))
now := time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC)
for i := 0; i < 500; i++ {
task := randomTask(rng, now)
task.IsCancelled = true
column := categorization.CategorizeTask(task, 30)
if column != categorization.ColumnCancelled {
t.Fatalf("Cancelled task %d mapped to %q instead of cancelled_tasks: %+v", i, column, task)
}
}
}
// === Timezone / DST Boundary Tests ===
// TestCategorizeTask_UTCMidnightBoundary tests task categorization at exactly
// UTC midnight, which is the boundary between days.
func TestCategorizeTask_UTCMidnightBoundary(t *testing.T) {
midnight := time.Date(2025, 3, 9, 0, 0, 0, 0, time.UTC)
dueDate := midnight
task := &models.Task{
DueDate: &dueDate,
}
// At midnight of the due date, task is NOT overdue (due today)
column := categorization.CategorizeTaskWithTime(task, 30, midnight)
if column == categorization.ColumnOverdue {
t.Errorf("Task due today should not be overdue at midnight, got %q", column)
}
// One day later, task IS overdue
nextDay := midnight.AddDate(0, 0, 1)
column = categorization.CategorizeTaskWithTime(task, 30, nextDay)
if column != categorization.ColumnOverdue {
t.Errorf("Task due yesterday should be overdue, got %q", column)
}
}
// TestCategorizeTask_DSTSpringForward tests categorization across DST spring-forward.
// In US Eastern time, 2:00 AM jumps to 3:00 AM on the second Sunday of March.
func TestCategorizeTask_DSTSpringForward(t *testing.T) {
loc, err := time.LoadLocation("America/New_York")
if err != nil {
t.Skip("America/New_York timezone not available")
}
// March 9, 2025 is DST spring-forward in Eastern Time
dueDate := time.Date(2025, 3, 9, 0, 0, 0, 0, time.UTC) // Stored as UTC
task := &models.Task{DueDate: &dueDate}
// Check at start of March 9 in Eastern time
nowET := time.Date(2025, 3, 9, 0, 0, 0, 0, loc)
column := categorization.CategorizeTaskWithTime(task, 30, nowET)
if column == categorization.ColumnOverdue {
t.Errorf("Task due March 9 should not be overdue on March 9 (DST spring-forward), got %q", column)
}
// Check at March 10 - should be overdue now
nextDayET := time.Date(2025, 3, 10, 0, 0, 0, 0, loc)
column = categorization.CategorizeTaskWithTime(task, 30, nextDayET)
if column != categorization.ColumnOverdue {
t.Errorf("Task due March 9 should be overdue on March 10, got %q", column)
}
}
// TestCategorizeTask_DSTFallBack tests categorization across DST fall-back.
// In US Eastern time, 2:00 AM jumps back to 1:00 AM on the first Sunday of November.
func TestCategorizeTask_DSTFallBack(t *testing.T) {
loc, err := time.LoadLocation("America/New_York")
if err != nil {
t.Skip("America/New_York timezone not available")
}
// November 2, 2025 is DST fall-back in Eastern Time
dueDate := time.Date(2025, 11, 2, 0, 0, 0, 0, time.UTC)
task := &models.Task{DueDate: &dueDate}
// On the due date itself - not overdue
nowET := time.Date(2025, 11, 2, 0, 0, 0, 0, loc)
column := categorization.CategorizeTaskWithTime(task, 30, nowET)
if column == categorization.ColumnOverdue {
t.Errorf("Task due Nov 2 should not be overdue on Nov 2 (DST fall-back), got %q", column)
}
// Next day - should be overdue
nextDayET := time.Date(2025, 11, 3, 0, 0, 0, 0, loc)
column = categorization.CategorizeTaskWithTime(task, 30, nextDayET)
if column != categorization.ColumnOverdue {
t.Errorf("Task due Nov 2 should be overdue on Nov 3, got %q", column)
}
}
// TestIsOverdue_UTCMidnightEdge validates the overdue predicate at exact midnight.
func TestIsOverdue_UTCMidnightEdge(t *testing.T) {
dueDate := time.Date(2025, 12, 31, 0, 0, 0, 0, time.UTC)
task := &models.Task{DueDate: &dueDate}
// On due date: NOT overdue
atDueDate := time.Date(2025, 12, 31, 0, 0, 0, 0, time.UTC)
column := categorization.CategorizeTaskWithTime(task, 30, atDueDate)
if column == categorization.ColumnOverdue {
t.Error("Task should not be overdue on its due date")
}
// One second after midnight next day: overdue
afterDueDate := time.Date(2026, 1, 1, 0, 0, 1, 0, time.UTC)
column = categorization.CategorizeTaskWithTime(task, 30, afterDueDate)
if column != categorization.ColumnOverdue {
t.Errorf("Task should be overdue after its due date, got %q", column)
}
}
// === Helper ===
func randomTask(rng *rand.Rand, baseTime time.Time) *models.Task {
task := &models.Task{
IsCancelled: rng.Intn(10) == 0, // 10% chance
IsArchived: rng.Intn(10) == 0, // 10% chance
InProgress: rng.Intn(5) == 0, // 20% chance
}
if rng.Intn(4) > 0 { // 75% have due date
d := baseTime.AddDate(0, 0, rng.Intn(120)-60)
task.DueDate = &d
}
if rng.Intn(3) == 0 { // 33% recurring
d := baseTime.AddDate(0, 0, rng.Intn(120)-60)
task.NextDueDate = &d
}
if rng.Intn(3) == 0 { // 33% have completions
count := rng.Intn(3) + 1
for i := 0; i < count; i++ {
task.Completions = append(task.Completions, models.TaskCompletion{
BaseModel: models.BaseModel{ID: uint(i + 1)},
})
}
}
return task
}

View File

@@ -4,10 +4,14 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/task/categorization"
)
// Ensure assert is used (referenced in fuzz/property tests below)
var _ = assert.Equal
// Helper to create a time pointer
func timePtr(t time.Time) *time.Time {
return &t
@@ -545,3 +549,255 @@ func TestTimezone_MultipleTasksIntoColumns(t *testing.T) {
t.Errorf("Expected task 3 (Jan 15) in due_soon column")
}
}
// ============================================================================
// FUZZ / PROPERTY TESTS
// These tests verify invariants that must hold for ALL possible task states,
// not just specific hand-crafted examples.
//
// validColumns is defined in chain_breakit_test.go and shared across test files
// in the categorization_test package.
// ============================================================================
// FuzzCategorizeTaskExtended feeds random task states into CategorizeTask using
// separate boolean flags for date presence and day-offset integers for date
// values. This complements FuzzCategorizeTask (in chain_breakit_test.go) by
// exercising the nil-date paths more directly.
func FuzzCategorizeTaskExtended(f *testing.F) {
// Seed corpus: cover a representative spread of boolean/date combinations.
// isCancelled, isArchived, inProgress, hasCompletions,
// hasDueDate, dueDateOffsetDays, hasNextDueDate, nextDueDateOffsetDays
f.Add(false, false, false, false, false, 0, false, 0)
f.Add(true, false, false, false, false, 0, false, 0)
f.Add(false, true, false, false, false, 0, false, 0)
f.Add(false, false, true, false, false, 0, false, 0)
f.Add(false, false, false, true, false, 0, false, 0) // completed (no next due, has completions)
f.Add(false, false, false, false, true, -5, false, 0) // overdue via DueDate
f.Add(false, false, false, false, false, 0, true, -5) // overdue via NextDueDate
f.Add(false, false, false, false, false, 0, true, 5) // due soon
f.Add(false, false, false, false, false, 0, true, 60) // upcoming
f.Add(true, true, true, true, true, -10, true, -10) // everything set
f.Add(false, false, false, false, true, 100, true, 100) // far future
f.Fuzz(func(t *testing.T,
isCancelled, isArchived, inProgress, hasCompletions bool,
hasDueDate bool, dueDateOffsetDays int,
hasNextDueDate bool, nextDueDateOffsetDays int,
) {
now := time.Date(2025, 12, 15, 0, 0, 0, 0, time.UTC)
task := &models.Task{
IsCancelled: isCancelled,
IsArchived: isArchived,
InProgress: inProgress,
}
if hasDueDate {
d := now.AddDate(0, 0, dueDateOffsetDays)
task.DueDate = &d
}
if hasNextDueDate {
d := now.AddDate(0, 0, nextDueDateOffsetDays)
task.NextDueDate = &d
}
if hasCompletions {
task.Completions = []models.TaskCompletion{
{BaseModel: models.BaseModel{ID: 1}},
}
} else {
task.Completions = []models.TaskCompletion{}
}
result := categorization.CategorizeTask(task, 30)
// Invariant 1: result must never be the empty string.
if result == "" {
t.Fatalf("CategorizeTask returned empty string for task %+v", task)
}
// Invariant 2: result must be one of the valid KanbanColumn constants.
if !validColumns[result] {
t.Fatalf("CategorizeTask returned invalid column %q for task %+v", result, task)
}
})
}
// TestCategorizeTask_MutuallyExclusive exhaustively enumerates all boolean
// state combinations (IsCancelled, IsArchived, InProgress, hasCompletions)
// crossed with representative date positions (no date, past, today, within
// threshold, beyond threshold) and asserts that every task maps to exactly
// one valid, non-empty column.
func TestCategorizeTask_MutuallyExclusive(t *testing.T) {
now := time.Date(2025, 12, 15, 0, 0, 0, 0, time.UTC)
daysThreshold := 30
// Date scenarios relative to "now" for both DueDate and NextDueDate.
type dateScenario struct {
name string
dueDate *time.Time
nextDue *time.Time
}
past := now.AddDate(0, 0, -5)
today := now
withinThreshold := now.AddDate(0, 0, 10)
beyondThreshold := now.AddDate(0, 0, 60)
dateScenarios := []dateScenario{
{"no dates", nil, nil},
{"DueDate past only", &past, nil},
{"DueDate today only", &today, nil},
{"DueDate within threshold", &withinThreshold, nil},
{"DueDate beyond threshold", &beyondThreshold, nil},
{"NextDueDate past", nil, &past},
{"NextDueDate today", nil, &today},
{"NextDueDate within threshold", nil, &withinThreshold},
{"NextDueDate beyond threshold", nil, &beyondThreshold},
{"both past", &past, &past},
{"DueDate past NextDueDate future", &past, &withinThreshold},
{"both beyond threshold", &beyondThreshold, &beyondThreshold},
}
boolCombos := []struct {
cancelled, archived, inProgress, hasCompletions bool
}{
{false, false, false, false},
{true, false, false, false},
{false, true, false, false},
{false, false, true, false},
{false, false, false, true},
{true, true, false, false},
{true, false, true, false},
{true, false, false, true},
{false, true, true, false},
{false, true, false, true},
{false, false, true, true},
{true, true, true, false},
{true, true, false, true},
{true, false, true, true},
{false, true, true, true},
{true, true, true, true},
}
for _, ds := range dateScenarios {
for _, bc := range boolCombos {
task := &models.Task{
IsCancelled: bc.cancelled,
IsArchived: bc.archived,
InProgress: bc.inProgress,
DueDate: ds.dueDate,
NextDueDate: ds.nextDue,
}
if bc.hasCompletions {
task.Completions = []models.TaskCompletion{
{BaseModel: models.BaseModel{ID: 1}},
}
} else {
task.Completions = []models.TaskCompletion{}
}
result := categorization.CategorizeTaskWithTime(task, daysThreshold, now)
assert.NotEmpty(t, result,
"empty column for dates=%s cancelled=%v archived=%v inProgress=%v completions=%v",
ds.name, bc.cancelled, bc.archived, bc.inProgress, bc.hasCompletions)
assert.True(t, validColumns[result],
"invalid column %q for dates=%s cancelled=%v archived=%v inProgress=%v completions=%v",
result, ds.name, bc.cancelled, bc.archived, bc.inProgress, bc.hasCompletions)
}
}
}
// TestCategorizeTask_CancelledAlwaysCancelled verifies the property that any
// task with IsCancelled=true is always categorized into ColumnCancelled,
// regardless of all other field values.
func TestCategorizeTask_CancelledAlwaysCancelled(t *testing.T) {
now := time.Date(2025, 12, 15, 0, 0, 0, 0, time.UTC)
daysThreshold := 30
past := now.AddDate(0, 0, -5)
future := now.AddDate(0, 0, 10)
farFuture := now.AddDate(0, 0, 60)
dates := []*time.Time{nil, &past, &future, &farFuture}
bools := []bool{true, false}
for _, isArchived := range bools {
for _, inProgress := range bools {
for _, hasCompletions := range bools {
for _, dueDate := range dates {
for _, nextDueDate := range dates {
task := &models.Task{
IsCancelled: true, // always cancelled
IsArchived: isArchived,
InProgress: inProgress,
DueDate: dueDate,
NextDueDate: nextDueDate,
}
if hasCompletions {
task.Completions = []models.TaskCompletion{
{BaseModel: models.BaseModel{ID: 1}},
}
} else {
task.Completions = []models.TaskCompletion{}
}
result := categorization.CategorizeTaskWithTime(task, daysThreshold, now)
assert.Equal(t, categorization.ColumnCancelled, result,
"cancelled task should always map to ColumnCancelled, got %q "+
"(archived=%v inProgress=%v completions=%v dueDate=%v nextDueDate=%v)",
result, isArchived, inProgress, hasCompletions, dueDate, nextDueDate)
}
}
}
}
}
}
// TestCategorizeTask_ArchivedAlwaysArchived verifies the property that any
// task with IsArchived=true and IsCancelled=false is always categorized into
// ColumnCancelled (archived tasks share the cancelled column as both represent
// "inactive" states), regardless of all other field values.
func TestCategorizeTask_ArchivedAlwaysArchived(t *testing.T) {
now := time.Date(2025, 12, 15, 0, 0, 0, 0, time.UTC)
daysThreshold := 30
past := now.AddDate(0, 0, -5)
future := now.AddDate(0, 0, 10)
farFuture := now.AddDate(0, 0, 60)
dates := []*time.Time{nil, &past, &future, &farFuture}
bools := []bool{true, false}
for _, inProgress := range bools {
for _, hasCompletions := range bools {
for _, dueDate := range dates {
for _, nextDueDate := range dates {
task := &models.Task{
IsCancelled: false, // not cancelled
IsArchived: true, // always archived
InProgress: inProgress,
DueDate: dueDate,
NextDueDate: nextDueDate,
}
if hasCompletions {
task.Completions = []models.TaskCompletion{
{BaseModel: models.BaseModel{ID: 1}},
}
} else {
task.Completions = []models.TaskCompletion{}
}
result := categorization.CategorizeTaskWithTime(task, daysThreshold, now)
assert.Equal(t, categorization.ColumnCancelled, result,
"archived (non-cancelled) task should always map to ColumnCancelled, got %q "+
"(inProgress=%v completions=%v dueDate=%v nextDueDate=%v)",
result, inProgress, hasCompletions, dueDate, nextDueDate)
}
}
}
}
}

View File

@@ -208,6 +208,7 @@ func CreateTestTask(t *testing.T, db *gorm.DB, residenceID, createdByID uint, ti
Title: title,
IsCancelled: false,
IsArchived: false,
Version: 1,
}
err := db.Create(task).Error
require.NoError(t, err)

View File

@@ -542,6 +542,11 @@ func NewSendPushTask(userID uint, title, message string, data map[string]string)
// 2. Users who created a residence 5+ days ago but haven't created any tasks
// Each email type is only sent once per user, ever.
func (h *Handler) HandleOnboardingEmails(ctx context.Context, task *asynq.Task) error {
if !h.config.Features.OnboardingEmailsEnabled {
log.Debug().Msg("Onboarding emails disabled by feature flag, skipping")
return nil
}
log.Info().Msg("Processing onboarding emails...")
if h.onboardingService == nil {