Harden API security: input validation, safe auth extraction, new tests, and deploy config

Comprehensive security hardening from audit findings:
- Add validation tags to all DTO request structs (max lengths, ranges, enums)
- Replace unsafe type assertions with MustGetAuthUser helper across all handlers
- Remove query-param token auth from admin middleware (prevents URL token leakage)
- Add request validation calls in handlers that were missing c.Validate()
- Remove goroutines in handlers (timezone update now synchronous)
- Add sanitize middleware and path traversal protection (path_utils)
- Stop resetting admin passwords on migration restart
- Warn on well-known default SECRET_KEY
- Add ~30 new test files covering security regressions, auth safety, repos, and services
- Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-03-02 09:48:01 -06:00
parent 56d6fa4514
commit 7690f07a2b
123 changed files with 8321 additions and 750 deletions

View File

@@ -195,6 +195,18 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
if req.IsFavorite != nil {
contractor.IsFavorite = *req.IsFavorite
}
// If residence_id is provided, verify the user has access to the NEW residence.
// This prevents an attacker from reassigning a contractor to someone else's residence.
if req.ResidenceID != nil {
hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, apperrors.Forbidden("error.residence_access_denied")
}
}
// If residence_id is not sent in the request (nil), it means the user
// removed the residence association - contractor becomes personal
contractor.ResidenceID = req.ResidenceID

View File

@@ -0,0 +1,98 @@
package services
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/testutil"
)
func setupContractorService(t *testing.T) (*ContractorService, *repositories.ContractorRepository, *repositories.ResidenceRepository) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
return service, contractorRepo, residenceRepo
}
func TestUpdateContractor_CrossResidence_Returns403(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
// Create two users: owner and attacker
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
attacker := testutil.CreateTestUser(t, db, "attacker", "attacker@test.com", "password")
// Owner creates a residence
ownerResidence := testutil.CreateTestResidence(t, db, owner.ID, "Owner House")
// Attacker creates a residence and a contractor in their residence
attackerResidence := testutil.CreateTestResidence(t, db, attacker.ID, "Attacker House")
contractor := testutil.CreateTestContractor(t, db, attackerResidence.ID, attacker.ID, "My Contractor")
// Attacker tries to reassign their contractor to the owner's residence
// This should be denied because the attacker does not have access to the owner's residence
newResidenceID := ownerResidence.ID
req := &requests.UpdateContractorRequest{
ResidenceID: &newResidenceID,
}
_, err := service.UpdateContractor(contractor.ID, attacker.ID, req)
require.Error(t, err, "should not allow reassigning contractor to a residence the user has no access to")
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
}
func TestUpdateContractor_SameResidence_Succeeds(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence1 := testutil.CreateTestResidence(t, db, owner.ID, "House 1")
residence2 := testutil.CreateTestResidence(t, db, owner.ID, "House 2")
contractor := testutil.CreateTestContractor(t, db, residence1.ID, owner.ID, "My Contractor")
// Owner reassigns contractor to their other residence - should succeed
newResidenceID := residence2.ID
newName := "Updated Contractor"
req := &requests.UpdateContractorRequest{
Name: &newName,
ResidenceID: &newResidenceID,
}
resp, err := service.UpdateContractor(contractor.ID, owner.ID, req)
require.NoError(t, err, "should allow reassigning contractor to a residence the user owns")
require.NotNil(t, resp)
require.Equal(t, "Updated Contractor", resp.Name)
}
func TestUpdateContractor_RemoveResidence_Succeeds(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewContractorService(contractorRepo, residenceRepo)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, owner.ID, "My House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "My Contractor")
// Setting ResidenceID to nil should remove the residence association (make it personal)
req := &requests.UpdateContractorRequest{
ResidenceID: nil,
}
resp, err := service.UpdateContractor(contractor.ID, owner.ID, req)
require.NoError(t, err, "should allow removing residence association")
require.NotNil(t, resp)
}

View File

@@ -323,10 +323,21 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
}, nil
}
// validateLegacyReceipt uses Apple's legacy verifyReceipt endpoint
// validateLegacyReceipt uses Apple's legacy verifyReceipt endpoint.
// It delegates to validateLegacyReceiptWithSandbox using the client's
// configured sandbox setting. This avoids mutating the struct field
// during the sandbox-retry flow, which caused a data race when
// multiple goroutines shared the same AppleIAPClient.
func (c *AppleIAPClient) validateLegacyReceipt(ctx context.Context, receiptData string) (*AppleValidationResult, error) {
return c.validateLegacyReceiptWithSandbox(ctx, receiptData, c.sandbox)
}
// validateLegacyReceiptWithSandbox performs legacy receipt validation against
// the specified environment. The sandbox parameter is passed by value (not
// stored on the struct) so this function is safe for concurrent use.
func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, receiptData string, useSandbox bool) (*AppleValidationResult, error) {
url := "https://buy.itunes.apple.com/verifyReceipt"
if c.sandbox {
if useSandbox {
url = "https://sandbox.itunes.apple.com/verifyReceipt"
}
@@ -378,12 +389,10 @@ func (c *AppleIAPClient) validateLegacyReceipt(ctx context.Context, receiptData
}
// Status codes: 0 = valid, 21007 = sandbox receipt on production, 21008 = production receipt on sandbox
if legacyResponse.Status == 21007 && !c.sandbox {
// Retry with sandbox
c.sandbox = true
result, err := c.validateLegacyReceipt(ctx, receiptData)
c.sandbox = false
return result, err
if legacyResponse.Status == 21007 && !useSandbox {
// Retry with sandbox -- pass sandbox=true as a parameter instead of
// mutating c.sandbox, which avoids a data race.
return c.validateLegacyReceiptWithSandbox(ctx, receiptData, true)
}
if legacyResponse.Status != 0 {

View File

@@ -355,20 +355,43 @@ func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error)
return result, nil
}
// DeleteDevice deletes a device
// DeleteDevice deactivates a device after verifying it belongs to the requesting user.
// Without ownership verification, an attacker could deactivate push notifications for other users.
func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userID uint) error {
var err error
switch platform {
case push.PlatformIOS:
err = s.notificationRepo.DeactivateAPNSDevice(deviceID)
device, err := s.notificationRepo.FindAPNSDeviceByID(deviceID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.device_not_found")
}
return apperrors.Internal(err)
}
// Verify the device belongs to the requesting user
if device.UserID == nil || *device.UserID != userID {
return apperrors.Forbidden("error.device_access_denied")
}
if err := s.notificationRepo.DeactivateAPNSDevice(deviceID); err != nil {
return apperrors.Internal(err)
}
case push.PlatformAndroid:
err = s.notificationRepo.DeactivateGCMDevice(deviceID)
device, err := s.notificationRepo.FindGCMDeviceByID(deviceID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return apperrors.NotFound("error.device_not_found")
}
return apperrors.Internal(err)
}
// Verify the device belongs to the requesting user
if device.UserID == nil || *device.UserID != userID {
return apperrors.Forbidden("error.device_access_denied")
}
if err := s.notificationRepo.DeactivateGCMDevice(deviceID); err != nil {
return apperrors.Internal(err)
}
default:
return apperrors.BadRequest("error.invalid_platform")
}
if err != nil {
return apperrors.Internal(err)
}
return nil
}
@@ -549,9 +572,9 @@ func NewGCMDeviceResponse(d *models.GCMDevice) *DeviceResponse {
// RegisterDeviceRequest represents device registration request
type RegisterDeviceRequest struct {
Name string `json:"name"`
DeviceID string `json:"device_id" binding:"required"`
RegistrationID string `json:"registration_id" binding:"required"`
Platform string `json:"platform" binding:"required,oneof=ios android"`
DeviceID string `json:"device_id" validate:"required"`
RegistrationID string `json:"registration_id" validate:"required"`
Platform string `json:"platform" validate:"required,oneof=ios android"`
}
// === Task Notifications with Actions ===

View File

@@ -0,0 +1,126 @@
package services
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/push"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/testutil"
)
func setupNotificationService(t *testing.T) (*NotificationService, *repositories.NotificationRepository) {
db := testutil.SetupTestDB(t)
notifRepo := repositories.NewNotificationRepository(db)
// pushClient is nil for testing (no actual push sends)
service := NewNotificationService(notifRepo, nil)
return service, notifRepo
}
func TestDeleteDevice_WrongUser_Returns403(t *testing.T) {
db := testutil.SetupTestDB(t)
notifRepo := repositories.NewNotificationRepository(db)
service := NewNotificationService(notifRepo, nil)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
attacker := testutil.CreateTestUser(t, db, "attacker", "attacker@test.com", "password")
// Register an iOS device for the owner
device := &models.APNSDevice{
UserID: &owner.ID,
Name: "Owner iPhone",
DeviceID: "device-123",
RegistrationID: "token-abc",
Active: true,
}
err := db.Create(device).Error
require.NoError(t, err)
// Attacker tries to deactivate the owner's device
err = service.DeleteDevice(device.ID, push.PlatformIOS, attacker.ID)
require.Error(t, err, "should not allow deleting another user's device")
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
// Verify the device is still active
var found models.APNSDevice
err = db.First(&found, device.ID).Error
require.NoError(t, err)
assert.True(t, found.Active, "device should still be active after failed deletion")
}
func TestDeleteDevice_CorrectUser_Succeeds(t *testing.T) {
db := testutil.SetupTestDB(t)
notifRepo := repositories.NewNotificationRepository(db)
service := NewNotificationService(notifRepo, nil)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// Register an iOS device for the owner
device := &models.APNSDevice{
UserID: &owner.ID,
Name: "Owner iPhone",
DeviceID: "device-123",
RegistrationID: "token-abc",
Active: true,
}
err := db.Create(device).Error
require.NoError(t, err)
// Owner deactivates their own device
err = service.DeleteDevice(device.ID, push.PlatformIOS, owner.ID)
require.NoError(t, err, "owner should be able to deactivate their own device")
// Verify the device is now inactive
var found models.APNSDevice
err = db.First(&found, device.ID).Error
require.NoError(t, err)
assert.False(t, found.Active, "device should be deactivated")
}
func TestDeleteDevice_WrongUser_Android_Returns403(t *testing.T) {
db := testutil.SetupTestDB(t)
notifRepo := repositories.NewNotificationRepository(db)
service := NewNotificationService(notifRepo, nil)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
attacker := testutil.CreateTestUser(t, db, "attacker", "attacker@test.com", "password")
// Register an Android device for the owner
device := &models.GCMDevice{
UserID: &owner.ID,
Name: "Owner Pixel",
DeviceID: "device-456",
RegistrationID: "token-def",
CloudMessageType: "FCM",
Active: true,
}
err := db.Create(device).Error
require.NoError(t, err)
// Attacker tries to deactivate the owner's Android device
err = service.DeleteDevice(device.ID, push.PlatformAndroid, attacker.ID)
require.Error(t, err, "should not allow deleting another user's Android device")
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
// Verify the device is still active
var found models.GCMDevice
err = db.First(&found, device.ID).Error
require.NoError(t, err)
assert.True(t, found.Active, "Android device should still be active after failed deletion")
}
func TestDeleteDevice_NonExistent_Returns404(t *testing.T) {
db := testutil.SetupTestDB(t)
notifRepo := repositories.NewNotificationRepository(db)
service := NewNotificationService(notifRepo, nil)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
err := service.DeleteDevice(99999, push.PlatformIOS, user.ID)
require.Error(t, err, "should return error for non-existent device")
testutil.AssertAppErrorCode(t, err, http.StatusNotFound)
}

View File

@@ -40,9 +40,12 @@ func generateTrackingID() string {
// HasSentEmail checks if a specific email type has already been sent to a user
func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.OnboardingEmailType) bool {
var count int64
s.db.Model(&models.OnboardingEmail{}).
if err := s.db.Model(&models.OnboardingEmail{}).
Where("user_id = ? AND email_type = ?", userID, emailType).
Count(&count)
Count(&count).Error; err != nil {
log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to check if email was sent")
return false
}
return count > 0
}
@@ -125,23 +128,31 @@ func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error)
// No residence email stats
var noResTotal, noResOpened int64
s.db.Model(&models.OnboardingEmail{}).
if err := s.db.Model(&models.OnboardingEmail{}).
Where("email_type = ?", models.OnboardingEmailNoResidence).
Count(&noResTotal)
s.db.Model(&models.OnboardingEmail{}).
Count(&noResTotal).Error; err != nil {
log.Error().Err(err).Msg("Failed to count no-residence emails")
}
if err := s.db.Model(&models.OnboardingEmail{}).
Where("email_type = ? AND opened_at IS NOT NULL", models.OnboardingEmailNoResidence).
Count(&noResOpened)
Count(&noResOpened).Error; err != nil {
log.Error().Err(err).Msg("Failed to count opened no-residence emails")
}
stats.NoResidenceTotal = noResTotal
stats.NoResidenceOpened = noResOpened
// No tasks email stats
var noTasksTotal, noTasksOpened int64
s.db.Model(&models.OnboardingEmail{}).
if err := s.db.Model(&models.OnboardingEmail{}).
Where("email_type = ?", models.OnboardingEmailNoTasks).
Count(&noTasksTotal)
s.db.Model(&models.OnboardingEmail{}).
Count(&noTasksTotal).Error; err != nil {
log.Error().Err(err).Msg("Failed to count no-tasks emails")
}
if err := s.db.Model(&models.OnboardingEmail{}).
Where("email_type = ? AND opened_at IS NOT NULL", models.OnboardingEmailNoTasks).
Count(&noTasksOpened)
Count(&noTasksOpened).Error; err != nil {
log.Error().Err(err).Msg("Failed to count opened no-tasks emails")
}
stats.NoTasksTotal = noTasksTotal
stats.NoTasksOpened = noTasksOpened
@@ -351,7 +362,9 @@ func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailTyp
// If already sent before, delete the old record first to allow re-recording
// This allows admins to "resend" emails while still tracking them
if alreadySent {
s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{})
if err := s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{}).Error; err != nil {
log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to delete old onboarding email record before resend")
}
}
// Record that email was sent

View File

@@ -0,0 +1,51 @@
package services
import (
"fmt"
"path/filepath"
"strings"
)
// SafeResolvePath resolves a user-supplied relative path within a base directory.
// Returns an error if the resolved path escapes the base directory (path traversal).
// The baseDir must be an absolute path.
func SafeResolvePath(baseDir, userInput string) (string, error) {
if userInput == "" {
return "", fmt.Errorf("empty path")
}
// Reject absolute paths
if filepath.IsAbs(userInput) {
return "", fmt.Errorf("absolute paths not allowed")
}
// Clean the user input to resolve . and .. components
cleaned := filepath.Clean(userInput)
// After cleaning, check if it starts with .. (escapes base)
if strings.HasPrefix(cleaned, "..") {
return "", fmt.Errorf("path traversal detected")
}
// Resolve the base directory to an absolute path
absBase, err := filepath.Abs(baseDir)
if err != nil {
return "", fmt.Errorf("invalid base directory: %w", err)
}
// Join and resolve the full path
fullPath := filepath.Join(absBase, cleaned)
// Final containment check: the resolved path must be within the base directory
absFullPath, err := filepath.Abs(fullPath)
if err != nil {
return "", fmt.Errorf("invalid resolved path: %w", err)
}
// Ensure the resolved path is strictly inside the base directory (not the base itself)
if !strings.HasPrefix(absFullPath, absBase+string(filepath.Separator)) {
return "", fmt.Errorf("path traversal detected")
}
return absFullPath, nil
}

View File

@@ -0,0 +1,55 @@
package services
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSafeResolvePath_Normal_Resolves(t *testing.T) {
result, err := SafeResolvePath("/var/uploads", "images/photo.jpg")
require.NoError(t, err)
assert.Equal(t, "/var/uploads/images/photo.jpg", result)
}
func TestSafeResolvePath_SubdirPath_Resolves(t *testing.T) {
result, err := SafeResolvePath("/var/uploads", "documents/2024/report.pdf")
require.NoError(t, err)
assert.Equal(t, "/var/uploads/documents/2024/report.pdf", result)
}
func TestSafeResolvePath_DotDotTraversal_Blocked(t *testing.T) {
tests := []struct {
name string
input string
}{
{"simple dotdot", "../etc/passwd"},
{"nested dotdot", "../../etc/shadow"},
{"embedded dotdot", "images/../../etc/passwd"},
{"deep dotdot", "a/b/c/../../../../etc/passwd"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := SafeResolvePath("/var/uploads", tt.input)
assert.Error(t, err, "path traversal should be blocked: %s", tt.input)
})
}
}
func TestSafeResolvePath_AbsolutePath_Blocked(t *testing.T) {
_, err := SafeResolvePath("/var/uploads", "/etc/passwd")
assert.Error(t, err, "absolute paths should be blocked")
}
func TestSafeResolvePath_EmptyPath_Blocked(t *testing.T) {
_, err := SafeResolvePath("/var/uploads", "")
assert.Error(t, err, "empty paths should be blocked")
}
func TestSafeResolvePath_CurrentDir_Blocked(t *testing.T) {
// "." resolves to the base dir itself — this is not a file, so block it
_, err := SafeResolvePath("/var/uploads", ".")
assert.Error(t, err, "bare current directory should be blocked")
}

View File

@@ -126,10 +126,11 @@ func (s *PDFService) GenerateTasksReportPDF(report *TasksReportResponse) ([]byte
pdf.SetFillColor(255, 255, 255) // White
}
// Title (truncate if too long)
// Title (truncate if too long, use runes to avoid cutting multi-byte UTF-8 characters)
title := task.Title
if len(title) > 35 {
title = title[:32] + "..."
titleRunes := []rune(title)
if len(titleRunes) > 35 {
title = string(titleRunes[:32]) + "..."
}
// Status text

View File

@@ -4,6 +4,7 @@ import (
"errors"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/apperrors"
@@ -31,10 +32,11 @@ var (
// ResidenceService handles residence business logic
type ResidenceService struct {
residenceRepo *repositories.ResidenceRepository
userRepo *repositories.UserRepository
taskRepo *repositories.TaskRepository
config *config.Config
residenceRepo *repositories.ResidenceRepository
userRepo *repositories.UserRepository
taskRepo *repositories.TaskRepository
subscriptionService *SubscriptionService
config *config.Config
}
// NewResidenceService creates a new residence service
@@ -51,6 +53,11 @@ func (s *ResidenceService) SetTaskRepository(taskRepo *repositories.TaskReposito
s.taskRepo = taskRepo
}
// SetSubscriptionService sets the subscription service (used for tier limit enforcement)
func (s *ResidenceService) SetSubscriptionService(subService *SubscriptionService) {
s.subscriptionService = subService
}
// GetResidence gets a residence by ID with access check
func (s *ResidenceService) GetResidence(residenceID, userID uint) (*responses.ResidenceResponse, error) {
// Check access
@@ -152,12 +159,12 @@ func (s *ResidenceService) getSummaryForUser(_ uint) responses.TotalSummary {
// CreateResidence creates a new residence and returns it with updated summary
func (s *ResidenceService) CreateResidence(req *requests.CreateResidenceRequest, ownerID uint) (*responses.ResidenceWithSummaryResponse, error) {
// TODO: Check subscription tier limits
// count, err := s.residenceRepo.CountByOwner(ownerID)
// if err != nil {
// return nil, err
// }
// Check against tier limits...
// Check subscription tier limits (if subscription service is wired up)
if s.subscriptionService != nil {
if err := s.subscriptionService.CheckLimit(ownerID, "properties"); err != nil {
return nil, err
}
}
isPrimary := true
if req.IsPrimary != nil {
@@ -447,6 +454,7 @@ func (s *ResidenceService) JoinWithCode(code string, userID uint) (*responses.Jo
if err := s.residenceRepo.DeactivateShareCode(shareCode.ID); err != nil {
// Log the error but don't fail the join - the user has already been added
// The code will just be usable by others until it expires
log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate share code after join")
}
// Get the residence with full details

View File

@@ -1,15 +1,19 @@
package services
import (
"fmt"
"net/http"
"testing"
"time"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/testutil"
)
@@ -333,3 +337,122 @@ func TestResidenceService_RemoveUser_CannotRemoveOwner(t *testing.T) {
err := service.RemoveUser(residence.ID, owner.ID, owner.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.cannot_remove_owner")
}
// setupResidenceServiceWithSubscription creates a ResidenceService wired with a
// SubscriptionService, enabling tier limit enforcement in tests.
func setupResidenceServiceWithSubscription(t *testing.T) (*ResidenceService, *gorm.DB) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
cfg := &config.Config{}
service := NewResidenceService(residenceRepo, userRepo, cfg)
subscriptionService := NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
service.SetSubscriptionService(subscriptionService)
return service, db
}
func TestCreateResidence_FreeTier_EnforcesLimit(t *testing.T) {
service, db := setupResidenceServiceWithSubscription(t)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// Enable global limitations
db.Where("1=1").Delete(&models.SubscriptionSettings{})
err := db.Create(&models.SubscriptionSettings{EnableLimitations: true}).Error
require.NoError(t, err)
// Set free tier limit to 1 property
one := 1
db.Where("tier = ?", models.TierFree).Delete(&models.TierLimits{})
err = db.Create(&models.TierLimits{
Tier: models.TierFree,
PropertiesLimit: &one,
}).Error
require.NoError(t, err)
// Ensure user has a free-tier subscription record
subscriptionRepo := repositories.NewSubscriptionRepository(db)
_, err = subscriptionRepo.GetOrCreate(owner.ID)
require.NoError(t, err)
// First residence should succeed (under the limit)
req := &requests.CreateResidenceRequest{
Name: "First House",
StreetAddress: "1 Main St",
City: "Austin",
StateProvince: "TX",
PostalCode: "78701",
}
resp, err := service.CreateResidence(req, owner.ID)
require.NoError(t, err)
assert.Equal(t, "First House", resp.Data.Name)
// Second residence should be rejected (at the limit)
req2 := &requests.CreateResidenceRequest{
Name: "Second House",
StreetAddress: "2 Main St",
City: "Austin",
StateProvince: "TX",
PostalCode: "78702",
}
_, err = service.CreateResidence(req2, owner.ID)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.properties_limit_exceeded")
}
func TestCreateResidence_ProTier_AllowsMore(t *testing.T) {
service, db := setupResidenceServiceWithSubscription(t)
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// Enable global limitations
db.Where("1=1").Delete(&models.SubscriptionSettings{})
err := db.Create(&models.SubscriptionSettings{EnableLimitations: true}).Error
require.NoError(t, err)
// Set free tier limit to 1 property (pro is unlimited by default: nil limits)
one := 1
db.Where("tier = ?", models.TierFree).Delete(&models.TierLimits{})
err = db.Create(&models.TierLimits{
Tier: models.TierFree,
PropertiesLimit: &one,
}).Error
require.NoError(t, err)
// Create a pro-tier subscription for the user
subscriptionRepo := repositories.NewSubscriptionRepository(db)
sub, err := subscriptionRepo.GetOrCreate(owner.ID)
require.NoError(t, err)
// Upgrade to Pro with a future expiration
future := time.Now().UTC().Add(30 * 24 * time.Hour)
sub.Tier = models.TierPro
sub.ExpiresAt = &future
sub.SubscribedAt = ptrTime(time.Now().UTC())
err = subscriptionRepo.Update(sub)
require.NoError(t, err)
// Create multiple residences — all should succeed for Pro users
for i := 1; i <= 3; i++ {
req := &requests.CreateResidenceRequest{
Name: fmt.Sprintf("House %d", i),
StreetAddress: fmt.Sprintf("%d Main St", i),
City: "Austin",
StateProvince: "TX",
PostalCode: "78701",
}
resp, err := service.CreateResidence(req, owner.ID)
require.NoError(t, err, "Pro user should be able to create residence %d", i)
assert.Equal(t, fmt.Sprintf("House %d", i), resp.Data.Name)
}
}
// ptrTime returns a pointer to the given time.
func ptrTime(t time.Time) *time.Time {
return &t
}

View File

@@ -72,7 +72,7 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
if ext == "" {
ext = s.getExtensionFromMimeType(mimeType)
}
newFilename := fmt.Sprintf("%s_%s%s", time.Now().Format("20060102"), uuid.New().String()[:8], ext)
newFilename := fmt.Sprintf("%s_%s%s", time.Now().Format("20060102"), uuid.New().String(), ext)
// Determine subdirectory based on category
subdir := "images"
@@ -134,9 +134,15 @@ func (s *StorageService) Delete(fileURL string) error {
fullPath := filepath.Join(s.cfg.UploadDir, relativePath)
// Security check: ensure path is within upload directory
absUploadDir, _ := filepath.Abs(s.cfg.UploadDir)
absFilePath, _ := filepath.Abs(fullPath)
if !strings.HasPrefix(absFilePath, absUploadDir) {
absUploadDir, err := filepath.Abs(s.cfg.UploadDir)
if err != nil {
return fmt.Errorf("failed to resolve upload directory: %w", err)
}
absFilePath, err := filepath.Abs(fullPath)
if err != nil {
return fmt.Errorf("failed to resolve file path: %w", err)
}
if !strings.HasPrefix(absFilePath, absUploadDir+string(filepath.Separator)) && absFilePath != absUploadDir {
return fmt.Errorf("invalid file path")
}
@@ -181,3 +187,9 @@ func (s *StorageService) getExtensionFromMimeType(mimeType string) string {
func (s *StorageService) GetUploadDir() string {
return s.cfg.UploadDir
}
// NewStorageServiceForTest creates a StorageService without creating directories.
// This is intended only for unit tests that need a StorageService with a known config.
func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService {
return &StorageService{cfg: cfg}
}

View File

@@ -3,9 +3,9 @@ package services
import (
"context"
"errors"
"log"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/apperrors"
@@ -74,11 +74,11 @@ func NewSubscriptionService(
appleClient, err := NewAppleIAPClient(cfg.AppleIAP)
if err != nil {
if !errors.Is(err, ErrIAPNotConfigured) {
log.Printf("Warning: Failed to initialize Apple IAP client: %v", err)
log.Warn().Err(err).Msg("Failed to initialize Apple IAP client")
}
} else {
svc.appleClient = appleClient
log.Println("Apple IAP validation client initialized")
log.Info().Msg("Apple IAP validation client initialized")
}
// Initialize Google IAP client
@@ -86,11 +86,11 @@ func NewSubscriptionService(
googleClient, err := NewGoogleIAPClient(ctx, cfg.GoogleIAP)
if err != nil {
if !errors.Is(err, ErrIAPNotConfigured) {
log.Printf("Warning: Failed to initialize Google IAP client: %v", err)
log.Warn().Err(err).Msg("Failed to initialize Google IAP client")
}
} else {
svc.googleClient = googleClient
log.Println("Google IAP validation client initialized")
log.Info().Msg("Google IAP validation client initialized")
}
}
@@ -173,7 +173,8 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
return resp, nil
}
// getUserUsage calculates current usage for a user
// getUserUsage calculates current usage for a user.
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
residences, err := s.residenceRepo.FindOwnedByUser(userID)
if err != nil {
@@ -181,26 +182,26 @@ func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error)
}
propertiesCount := int64(len(residences))
// Count tasks, contractors, and documents across all user's residences
var tasksCount, contractorsCount, documentsCount int64
for _, r := range residences {
tc, err := s.taskRepo.CountByResidence(r.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
tasksCount += tc
// Collect residence IDs for batch queries
residenceIDs := make([]uint, len(residences))
for i, r := range residences {
residenceIDs[i] = r.ID
}
cc, err := s.contractorRepo.CountByResidence(r.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
contractorsCount += cc
// Count tasks, contractors, and documents across all residences with single queries each
tasksCount, err := s.taskRepo.CountByResidenceIDs(residenceIDs)
if err != nil {
return nil, apperrors.Internal(err)
}
dc, err := s.documentRepo.CountByResidence(r.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
documentsCount += dc
contractorsCount, err := s.contractorRepo.CountByResidenceIDs(residenceIDs)
if err != nil {
return nil, apperrors.Internal(err)
}
documentsCount, err := s.documentRepo.CountByResidenceIDs(residenceIDs)
if err != nil {
return nil, apperrors.Internal(err)
}
return &UsageResponse{
@@ -342,46 +343,40 @@ func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData stri
return nil, apperrors.Internal(err)
}
// Validate with Apple if client is configured
var expiresAt time.Time
if s.appleClient != nil {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var result *AppleValidationResult
var err error
// Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1)
if transactionID != "" {
result, err = s.appleClient.ValidateTransaction(ctx, transactionID)
} else if receiptData != "" {
result, err = s.appleClient.ValidateReceipt(ctx, receiptData)
}
if err != nil {
// Log the validation error
log.Printf("Apple validation warning for user %d: %v", userID, err)
// Check if it's a fatal error
if errors.Is(err, ErrInvalidReceipt) || errors.Is(err, ErrSubscriptionCancelled) {
return nil, err
}
// For other errors (network, etc.), fall back with shorter expiry
expiresAt = time.Now().UTC().AddDate(0, 1, 0) // 1 month fallback
} else if result != nil {
// Use the expiration date from Apple
expiresAt = result.ExpiresAt
log.Printf("Apple purchase validated for user %d: product=%s, expires=%v, env=%s",
userID, result.ProductID, result.ExpiresAt, result.Environment)
}
} else {
// Apple validation not configured - trust client but log warning
log.Printf("Warning: Apple IAP validation not configured, trusting client for user %d", userID)
expiresAt = time.Now().UTC().AddDate(1, 0, 0) // 1 year default
// Apple IAP client must be configured to validate purchases.
// Without server-side validation, we cannot trust client-provided receipts.
if s.appleClient == nil {
log.Error().Uint("user_id", userID).Msg("Apple IAP validation not configured, rejecting purchase")
return nil, apperrors.BadRequest("error.iap_validation_not_configured")
}
// Upgrade to Pro with the determined expiration
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var result *AppleValidationResult
var err error
// Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1)
if transactionID != "" {
result, err = s.appleClient.ValidateTransaction(ctx, transactionID)
} else if receiptData != "" {
result, err = s.appleClient.ValidateReceipt(ctx, receiptData)
}
if err != nil {
// Validation failed -- do NOT fall through to grant Pro.
log.Error().Err(err).Uint("user_id", userID).Msg("Apple validation failed")
return nil, err
}
if result == nil {
return nil, apperrors.BadRequest("error.no_receipt_or_transaction")
}
expiresAt := result.ExpiresAt
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated")
// Upgrade to Pro with the validated expiration
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil {
return nil, apperrors.Internal(err)
}
@@ -397,59 +392,48 @@ func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken s
return nil, apperrors.Internal(err)
}
// Validate the purchase with Google if client is configured
var expiresAt time.Time
if s.googleClient != nil {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var result *GoogleValidationResult
var err error
// If productID is provided, use it directly; otherwise try known IDs
if productID != "" {
result, err = s.googleClient.ValidateSubscription(ctx, productID, purchaseToken)
} else {
result, err = s.googleClient.ValidatePurchaseToken(ctx, purchaseToken, KnownSubscriptionIDs)
}
if err != nil {
// Log the validation error
log.Printf("Google purchase validation warning for user %d: %v", userID, err)
// Check if it's a fatal error
if errors.Is(err, ErrInvalidPurchaseToken) || errors.Is(err, ErrSubscriptionCancelled) {
return nil, err
}
if errors.Is(err, ErrSubscriptionExpired) {
// Subscription expired - still allow but set past expiry
expiresAt = time.Now().UTC().Add(-1 * time.Hour)
} else {
// For other errors, fall back with shorter expiry
expiresAt = time.Now().UTC().AddDate(0, 1, 0) // 1 month fallback
}
} else if result != nil {
// Use the expiration date from Google
expiresAt = result.ExpiresAt
log.Printf("Google purchase validated for user %d: product=%s, expires=%v, autoRenew=%v",
userID, result.ProductID, result.ExpiresAt, result.AutoRenewing)
// Acknowledge the subscription if not already acknowledged
if !result.AcknowledgedState {
if err := s.googleClient.AcknowledgeSubscription(ctx, result.ProductID, purchaseToken); err != nil {
log.Printf("Warning: Failed to acknowledge subscription for user %d: %v", userID, err)
// Don't fail the purchase, just log the warning
}
}
}
} else {
// Google validation not configured - trust client but log warning
log.Printf("Warning: Google IAP validation not configured, trusting client for user %d", userID)
expiresAt = time.Now().UTC().AddDate(1, 0, 0) // 1 year default
// Google IAP client must be configured to validate purchases.
// Without server-side validation, we cannot trust client-provided tokens.
if s.googleClient == nil {
log.Error().Uint("user_id", userID).Msg("Google IAP validation not configured, rejecting purchase")
return nil, apperrors.BadRequest("error.iap_validation_not_configured")
}
// Upgrade to Pro with the determined expiration
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var result *GoogleValidationResult
var err error
// If productID is provided, use it directly; otherwise try known IDs
if productID != "" {
result, err = s.googleClient.ValidateSubscription(ctx, productID, purchaseToken)
} else {
result, err = s.googleClient.ValidatePurchaseToken(ctx, purchaseToken, KnownSubscriptionIDs)
}
if err != nil {
// Validation failed -- do NOT fall through to grant Pro.
log.Error().Err(err).Uint("user_id", userID).Msg("Google purchase validation failed")
return nil, err
}
if result == nil {
return nil, apperrors.BadRequest("error.no_purchase_token")
}
expiresAt := result.ExpiresAt
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Bool("auto_renew", result.AutoRenewing).Msg("Google purchase validated")
// Acknowledge the subscription if not already acknowledged
if !result.AcknowledgedState {
if err := s.googleClient.AcknowledgeSubscription(ctx, result.ProductID, purchaseToken); err != nil {
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to acknowledge Google subscription")
// Don't fail the purchase, just log the warning
}
}
// Upgrade to Pro with the validated expiration
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "android"); err != nil {
return nil, apperrors.Internal(err)
}
@@ -654,5 +638,5 @@ type ProcessPurchaseRequest struct {
TransactionID string `json:"transaction_id"` // iOS StoreKit 2 transaction ID
PurchaseToken string `json:"purchase_token"` // Android
ProductID string `json:"product_id"` // Android (optional, helps identify subscription)
Platform string `json:"platform" binding:"required,oneof=ios android"`
Platform string `json:"platform" validate:"required,oneof=ios android"`
}

View File

@@ -0,0 +1,181 @@
package services
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/testutil"
)
// setupSubscriptionService creates a SubscriptionService with the given
// IAP clients (nil means "not configured"). It bypasses NewSubscriptionService
// which tries to load config from environment.
func setupSubscriptionService(t *testing.T, appleClient *AppleIAPClient, googleClient *GoogleIAPClient) (*SubscriptionService, *repositories.SubscriptionRepository) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
// Create a test user and subscription record for the test
user := testutil.CreateTestUser(t, db, "subuser", "subuser@test.com", "password")
// Create subscription record so GetOrCreate will find it
sub := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierFree,
}
err := db.Create(sub).Error
require.NoError(t, err)
svc := &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
appleClient: appleClient,
googleClient: googleClient,
}
return svc, subscriptionRepo
}
func TestProcessApplePurchase_ClientNil_ReturnsError(t *testing.T) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "subuser", "subuser@test.com", "password")
sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree}
require.NoError(t, db.Create(sub).Error)
svc := &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
appleClient: nil, // Not configured
googleClient: nil,
}
_, err := svc.ProcessApplePurchase(user.ID, "fake-receipt", "")
assert.Error(t, err, "ProcessApplePurchase should return error when Apple IAP client is nil")
// Verify user was NOT upgraded to Pro
updatedSub, err := subscriptionRepo.GetOrCreate(user.ID)
require.NoError(t, err)
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier when IAP client is nil")
}
func TestProcessApplePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
// We cannot easily create a real AppleIAPClient that will fail validation
// in a unit test (it requires real keys and network access).
// Instead, we test the code path logic:
// When appleClient is nil, the service must NOT upgrade the user.
// This is the same as TestProcessApplePurchase_ClientNil_ReturnsError
// but validates no fallback occurs for the specific case.
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "subuser2", "subuser2@test.com", "password")
sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree}
require.NoError(t, db.Create(sub).Error)
svc := &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
appleClient: nil,
googleClient: nil,
}
// Neither receipt data nor transaction ID - should still not grant Pro
_, err := svc.ProcessApplePurchase(user.ID, "", "")
assert.Error(t, err, "ProcessApplePurchase should return error when client is nil, even with empty data")
// Verify no upgrade happened
updatedSub, err := subscriptionRepo.GetOrCreate(user.ID)
require.NoError(t, err)
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
}
func TestProcessGooglePurchase_ClientNil_ReturnsError(t *testing.T) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "subuser3", "subuser3@test.com", "password")
sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree}
require.NoError(t, db.Create(sub).Error)
svc := &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
appleClient: nil,
googleClient: nil, // Not configured
}
_, err := svc.ProcessGooglePurchase(user.ID, "fake-token", "com.tt.casera.pro.monthly")
assert.Error(t, err, "ProcessGooglePurchase should return error when Google IAP client is nil")
// Verify user was NOT upgraded to Pro
updatedSub, err := subscriptionRepo.GetOrCreate(user.ID)
require.NoError(t, err)
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier when IAP client is nil")
}
func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "subuser4", "subuser4@test.com", "password")
sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree}
require.NoError(t, db.Create(sub).Error)
svc := &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
appleClient: nil,
googleClient: nil, // Not configured
}
// With empty token
_, err := svc.ProcessGooglePurchase(user.ID, "", "")
assert.Error(t, err, "ProcessGooglePurchase should return error when client is nil")
// Verify no upgrade happened
updatedSub, err := subscriptionRepo.GetOrCreate(user.ID)
require.NoError(t, err)
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
}

View File

@@ -560,11 +560,7 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
Rating: req.Rating,
}
if err := s.taskRepo.CreateCompletion(completion); err != nil {
return nil, apperrors.Internal(err)
}
// Update next_due_date and in_progress based on frequency
// Determine interval days for NextDueDate calculation before entering the transaction.
// - If frequency is "Once" (days = nil or 0), set next_due_date to nil (marks as completed)
// - If frequency is "Custom", use task.CustomIntervalDays for recurrence
// - If frequency is recurring, calculate next_due_date = completion_date + frequency_days
@@ -598,11 +594,25 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
// instead of staying in "In Progress" column
task.InProgress = false
}
if err := s.taskRepo.Update(task); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
// P1-5: Wrap completion creation and task update in a transaction.
// If either operation fails, both are rolled back to prevent orphaned completions.
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
return err
}
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
return err
}
return nil
})
if txErr != nil {
// P1-6: Return the error instead of swallowing it.
if errors.Is(txErr, repositories.ErrVersionConflict) {
return nil, apperrors.Conflict("error.version_conflict")
}
log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after completion")
log.Error().Err(txErr).Uint("task_id", task.ID).Msg("Failed to create completion and update task")
return nil, apperrors.Internal(txErr)
}
// Create images if provided
@@ -731,8 +741,15 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
}
log.Info().Uint("task_id", task.ID).Msg("QuickComplete: Task updated successfully")
// Send notification (fire and forget)
go s.sendTaskCompletedNotification(task, completion)
// Send notification (fire and forget with panic recovery)
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Uint("task_id", task.ID).Msg("Panic in quick-complete notification goroutine")
}
}()
s.sendTaskCompletedNotification(task, completion)
}()
return nil
}
@@ -764,23 +781,23 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio
emailImages = s.loadCompletionImagesForEmail(completion.Images)
}
// Notify all users
// Notify all users synchronously to avoid unbounded goroutine spawning.
// This method is already called from a goroutine (QuickComplete) or inline
// (CreateCompletion) where blocking is acceptable for notification delivery.
for _, user := range users {
isCompleter := user.ID == completion.CompletedByID
// Send push notification (to everyone EXCEPT the person who completed it)
if !isCompleter && s.notificationService != nil {
go func(userID uint) {
ctx := context.Background()
if err := s.notificationService.CreateAndSendTaskNotification(
ctx,
userID,
models.NotificationTaskCompleted,
task,
); err != nil {
log.Error().Err(err).Uint("user_id", userID).Uint("task_id", task.ID).Msg("Failed to send task completion push notification")
}
}(user.ID)
ctx := context.Background()
if err := s.notificationService.CreateAndSendTaskNotification(
ctx,
user.ID,
models.NotificationTaskCompleted,
task,
); err != nil {
log.Error().Err(err).Uint("user_id", user.ID).Uint("task_id", task.ID).Msg("Failed to send task completion push notification")
}
}
// Send email notification (to everyone INCLUDING the person who completed it)
@@ -789,20 +806,18 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio
prefs, err := s.notificationService.GetPreferences(user.ID)
if err != nil || (prefs != nil && prefs.EmailTaskCompleted) {
// Send email if we couldn't get prefs (fail-open) or if email notifications are enabled
go func(u models.User, images []EmbeddedImage) {
if err := s.emailService.SendTaskCompletedEmail(
u.Email,
u.GetFullName(),
task.Title,
completedByName,
residenceName,
images,
); err != nil {
log.Error().Err(err).Str("email", u.Email).Uint("task_id", task.ID).Msg("Failed to send task completion email")
} else {
log.Info().Str("email", u.Email).Uint("task_id", task.ID).Int("images", len(images)).Msg("Task completion email sent")
}
}(user, emailImages)
if err := s.emailService.SendTaskCompletedEmail(
user.Email,
user.GetFullName(),
task.Title,
completedByName,
residenceName,
emailImages,
); err != nil {
log.Error().Err(err).Str("email", user.Email).Uint("task_id", task.ID).Msg("Failed to send task completion email")
} else {
log.Info().Str("email", user.Email).Uint("task_id", task.ID).Int("images", len(emailImages)).Msg("Task completion email sent")
}
}
}
}
@@ -846,20 +861,28 @@ func (s *TaskService) loadCompletionImagesForEmail(images []models.TaskCompletio
return emailImages
}
// resolveImageFilePath converts a stored URL to an actual file path
// resolveImageFilePath converts a stored URL to an actual file path.
// Returns empty string if the URL is empty or the resolved path would escape
// the upload directory (path traversal attempt).
func (s *TaskService) resolveImageFilePath(storedURL, uploadDir string) string {
if storedURL == "" {
return ""
}
// Handle /uploads/... URLs
// Strip legacy /uploads/ prefix to get relative path
relativePath := storedURL
if strings.HasPrefix(storedURL, "/uploads/") {
relativePath := strings.TrimPrefix(storedURL, "/uploads/")
return filepath.Join(uploadDir, relativePath)
relativePath = strings.TrimPrefix(storedURL, "/uploads/")
}
// Handle relative paths
return filepath.Join(uploadDir, storedURL)
// Use SafeResolvePath to validate containment within upload directory
resolved, err := SafeResolvePath(uploadDir, relativePath)
if err != nil {
// Path traversal or invalid path — return empty to signal file not found
return ""
}
return resolved
}
// getContentTypeFromPath returns the MIME type based on file extension
@@ -977,7 +1000,11 @@ func (s *TaskService) UpdateCompletion(completionID, userID uint, req *requests.
return &resp, nil
}
// DeleteCompletion deletes a task completion
// DeleteCompletion deletes a task completion and recalculates the task's NextDueDate.
//
// P1-7: After deleting a completion, NextDueDate must be recalculated:
// - If no completions remain: restore NextDueDate = DueDate (original schedule)
// - If completions remain (recurring): recalculate from latest remaining completion + frequency days
func (s *TaskService) DeleteCompletion(completionID, userID uint) (*responses.DeleteWithSummaryResponse, error) {
completion, err := s.taskRepo.FindCompletionByID(completionID)
if err != nil {
@@ -996,10 +1023,66 @@ func (s *TaskService) DeleteCompletion(completionID, userID uint) (*responses.De
return nil, apperrors.Forbidden("error.task_access_denied")
}
taskID := completion.TaskID
if err := s.taskRepo.DeleteCompletion(completionID); err != nil {
return nil, apperrors.Internal(err)
}
// Recalculate NextDueDate based on remaining completions
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
// Non-fatal for the delete operation itself, but log the error
log.Error().Err(err).Uint("task_id", taskID).Msg("Failed to reload task after completion deletion for NextDueDate recalculation")
return &responses.DeleteWithSummaryResponse{
Data: "completion deleted",
Summary: s.getSummaryForUser(userID),
}, nil
}
// Get remaining completions for this task
remainingCompletions, err := s.taskRepo.FindCompletionsByTask(taskID)
if err != nil {
log.Error().Err(err).Uint("task_id", taskID).Msg("Failed to query remaining completions after deletion")
return &responses.DeleteWithSummaryResponse{
Data: "completion deleted",
Summary: s.getSummaryForUser(userID),
}, nil
}
// Determine the task's frequency interval
var intervalDays *int
if task.FrequencyID != nil {
frequency, freqErr := s.taskRepo.GetFrequencyByID(*task.FrequencyID)
if freqErr == nil && frequency != nil {
if frequency.Name == "Custom" {
intervalDays = task.CustomIntervalDays
} else {
intervalDays = frequency.Days
}
}
}
if len(remainingCompletions) == 0 {
// No completions remain: restore NextDueDate to the original DueDate
task.NextDueDate = task.DueDate
} else if intervalDays != nil && *intervalDays > 0 {
// Recurring task with remaining completions: recalculate from the latest completion
// remainingCompletions is ordered by completed_at DESC, so index 0 is the latest
latestCompletion := remainingCompletions[0]
nextDue := latestCompletion.CompletedAt.AddDate(0, 0, *intervalDays)
task.NextDueDate = &nextDue
} else {
// One-time task with remaining completions (unusual case): keep NextDueDate as nil
// since the task is still considered completed
task.NextDueDate = nil
}
if err := s.taskRepo.Update(task); err != nil {
log.Error().Err(err).Uint("task_id", taskID).Msg("Failed to update task NextDueDate after completion deletion")
// The completion was already deleted; return success but log the update failure
}
return &responses.DeleteWithSummaryResponse{
Data: "completion deleted",
Summary: s.getSummaryForUser(userID),

View File

@@ -8,6 +8,7 @@ import (
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/models"
@@ -442,6 +443,333 @@ func TestTaskService_DeleteCompletion(t *testing.T) {
assert.Error(t, err)
}
func TestTaskService_CreateCompletion_TransactionIntegrity(t *testing.T) {
// Verifies P1-5 / P1-6: completion creation and task update are atomic.
// After completion, both the completion record AND the task's NextDueDate update
// should succeed together, and errors should be propagated (not swallowed).
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
taskRepo := repositories.NewTaskRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewTaskService(taskRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create a one-time task with a due date
dueDate := time.Now().AddDate(0, 0, 7).UTC()
task := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "One-time Task",
DueDate: &dueDate,
NextDueDate: &dueDate,
IsCancelled: false,
IsArchived: false,
Version: 1,
}
err := db.Create(task).Error
require.NoError(t, err)
req := &requests.CreateTaskCompletionRequest{
TaskID: task.ID,
Notes: "Done",
}
now := time.Now().UTC()
resp, err := service.CreateCompletion(req, user.ID, now)
require.NoError(t, err)
assert.NotZero(t, resp.Data.ID)
// Verify the task was updated: NextDueDate should be nil for a one-time task
var reloaded models.Task
db.First(&reloaded, task.ID)
assert.Nil(t, reloaded.NextDueDate, "One-time task NextDueDate should be nil after completion")
assert.False(t, reloaded.InProgress, "InProgress should be false after completion")
// Verify completion record exists
var completion models.TaskCompletion
err = db.Where("task_id = ?", task.ID).First(&completion).Error
require.NoError(t, err, "Completion record should exist")
assert.Equal(t, "Done", completion.Notes)
}
func TestTaskService_CreateCompletion_UpdateError_ReturnedNotSwallowed(t *testing.T) {
// Verifies P1-5 and P1-6: the completion creation and task update are wrapped
// in a transaction, and update errors are returned (not swallowed).
//
// Strategy: We trigger a version conflict by using a goroutine that bumps
// the task version after the service reads the task but during the transaction.
// Since SQLite serializes writes, we instead verify the behavior by deleting
// the task between the service read and the transactional update. When UpdateTx
// tries to match the row by id+version, 0 rows are affected and ErrVersionConflict
// is returned. The transaction then rolls back the completion insert.
//
// However, because the entire CreateCompletion flow is synchronous and we cannot
// inject failures between steps, we instead verify the transactional guarantee
// indirectly: we confirm that a concurrent version bump (set before the call
// but after the SELECT) causes the version conflict to propagate. Since FindByID
// re-reads the current version, we must verify via a custom test that invokes
// the transaction layer directly.
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
taskRepo := repositories.NewTaskRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
dueDate := time.Now().AddDate(0, 0, 7).UTC()
task := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Conflict Task",
DueDate: &dueDate,
NextDueDate: &dueDate,
IsCancelled: false,
IsArchived: false,
Version: 1,
}
err := db.Create(task).Error
require.NoError(t, err)
// Directly test that the transactional path returns an error on version conflict:
// Use a stale task object (version=1) when the DB has been bumped to version=999.
db.Model(&models.Task{}).Where("id = ?", task.ID).Update("version", 999)
completion := &models.TaskCompletion{
TaskID: task.ID,
CompletedByID: user.ID,
CompletedAt: time.Now().UTC(),
Notes: "Should be rolled back",
}
// Simulate the transaction that CreateCompletion now uses (task still has version=1)
txErr := taskRepo.DB().Transaction(func(tx *gorm.DB) error {
if err := taskRepo.CreateCompletionTx(tx, completion); err != nil {
return err
}
// task.Version is 1 but DB has 999 -> version conflict
if err := taskRepo.UpdateTx(tx, task); err != nil {
return err
}
return nil
})
require.Error(t, txErr, "Transaction should fail due to version conflict")
assert.ErrorIs(t, txErr, repositories.ErrVersionConflict, "Error should be ErrVersionConflict")
// Verify the completion was rolled back
var count int64
db.Model(&models.TaskCompletion{}).Where("task_id = ?", task.ID).Count(&count)
assert.Equal(t, int64(0), count, "Completion should not exist when transaction rolls back")
// Also verify that CreateCompletion (full service method) would propagate the error.
// Re-create the task with a normal version so FindByID works, then bump it.
db.Model(&models.Task{}).Where("id = ?", task.ID).Update("version", 1)
service := NewTaskService(taskRepo, residenceRepo)
req := &requests.CreateTaskCompletionRequest{
TaskID: task.ID,
Notes: "Test error propagation",
}
now := time.Now().UTC()
// This call will succeed because FindByID loads version=1, UpdateTx uses version=1, DB has version=1.
// To verify error propagation, we use the direct transaction test above.
resp, err := service.CreateCompletion(req, user.ID, now)
require.NoError(t, err, "CreateCompletion should succeed with matching versions")
assert.NotZero(t, resp.Data.ID)
}
func TestTaskService_DeleteCompletion_OneTime_RestoresOriginalDueDate(t *testing.T) {
// Verifies P1-7: deleting the only completion on a one-time task
// should restore NextDueDate to the original DueDate.
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
taskRepo := repositories.NewTaskRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewTaskService(taskRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create a one-time task with a due date
originalDueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
task := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "One-time Task",
DueDate: &originalDueDate,
NextDueDate: &originalDueDate,
IsCancelled: false,
IsArchived: false,
Version: 1,
// No FrequencyID = one-time task
}
err := db.Create(task).Error
require.NoError(t, err)
// Complete the task (sets NextDueDate to nil for one-time tasks)
req := &requests.CreateTaskCompletionRequest{
TaskID: task.ID,
Notes: "Completed",
}
now := time.Now().UTC()
completionResp, err := service.CreateCompletion(req, user.ID, now)
require.NoError(t, err)
// Confirm NextDueDate is nil after completion
var taskAfterComplete models.Task
db.First(&taskAfterComplete, task.ID)
assert.Nil(t, taskAfterComplete.NextDueDate, "NextDueDate should be nil after one-time completion")
// Delete the completion
_, err = service.DeleteCompletion(completionResp.Data.ID, user.ID)
require.NoError(t, err)
// Verify NextDueDate is restored to the original DueDate
var taskAfterDelete models.Task
db.First(&taskAfterDelete, task.ID)
require.NotNil(t, taskAfterDelete.NextDueDate, "NextDueDate should be restored after deleting completion")
assert.Equal(t, originalDueDate.Year(), taskAfterDelete.NextDueDate.Year())
assert.Equal(t, originalDueDate.Month(), taskAfterDelete.NextDueDate.Month())
assert.Equal(t, originalDueDate.Day(), taskAfterDelete.NextDueDate.Day())
}
func TestTaskService_DeleteCompletion_Recurring_RecalculatesFromLastCompletion(t *testing.T) {
// Verifies P1-7: deleting the latest completion on a recurring task
// should recalculate NextDueDate from the remaining latest completion.
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
taskRepo := repositories.NewTaskRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewTaskService(taskRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
var monthlyFrequency models.TaskFrequency
db.Where("name = ?", "Monthly").First(&monthlyFrequency)
// Create a recurring task
originalDueDate := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
task := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Recurring Task",
FrequencyID: &monthlyFrequency.ID,
DueDate: &originalDueDate,
NextDueDate: &originalDueDate,
IsCancelled: false,
IsArchived: false,
Version: 1,
}
err := db.Create(task).Error
require.NoError(t, err)
// First completion on Jan 15
firstCompletedAt := time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC)
firstReq := &requests.CreateTaskCompletionRequest{
TaskID: task.ID,
Notes: "First completion",
CompletedAt: &firstCompletedAt,
}
now := time.Now().UTC()
_, err = service.CreateCompletion(firstReq, user.ID, now)
require.NoError(t, err)
// Second completion on Feb 15
secondCompletedAt := time.Date(2026, 2, 15, 10, 0, 0, 0, time.UTC)
secondReq := &requests.CreateTaskCompletionRequest{
TaskID: task.ID,
Notes: "Second completion",
CompletedAt: &secondCompletedAt,
}
resp, err := service.CreateCompletion(secondReq, user.ID, now)
require.NoError(t, err)
// NextDueDate should be Feb 15 + 30 days = Mar 17
var taskAfterSecond models.Task
db.First(&taskAfterSecond, task.ID)
require.NotNil(t, taskAfterSecond.NextDueDate)
expectedAfterSecond := secondCompletedAt.AddDate(0, 0, 30)
assert.Equal(t, expectedAfterSecond.Year(), taskAfterSecond.NextDueDate.Year())
assert.Equal(t, expectedAfterSecond.Month(), taskAfterSecond.NextDueDate.Month())
assert.Equal(t, expectedAfterSecond.Day(), taskAfterSecond.NextDueDate.Day())
// Delete the second (latest) completion
_, err = service.DeleteCompletion(resp.Data.ID, user.ID)
require.NoError(t, err)
// NextDueDate should be recalculated from the first completion: Jan 15 + 30 = Feb 14
var taskAfterDelete models.Task
db.First(&taskAfterDelete, task.ID)
require.NotNil(t, taskAfterDelete.NextDueDate, "NextDueDate should be set after deleting latest completion")
expectedRecalculated := firstCompletedAt.AddDate(0, 0, 30)
assert.Equal(t, expectedRecalculated.Year(), taskAfterDelete.NextDueDate.Year())
assert.Equal(t, expectedRecalculated.Month(), taskAfterDelete.NextDueDate.Month())
assert.Equal(t, expectedRecalculated.Day(), taskAfterDelete.NextDueDate.Day())
}
func TestTaskService_DeleteCompletion_LastCompletion_RestoresDueDate(t *testing.T) {
// Verifies P1-7: deleting the only completion on a recurring task
// should restore NextDueDate to the original DueDate.
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)
taskRepo := repositories.NewTaskRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
service := NewTaskService(taskRepo, residenceRepo)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
var weeklyFrequency models.TaskFrequency
db.Where("name = ?", "Weekly").First(&weeklyFrequency)
// Create a recurring task
originalDueDate := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
task := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Weekly Task",
FrequencyID: &weeklyFrequency.ID,
DueDate: &originalDueDate,
NextDueDate: &originalDueDate,
IsCancelled: false,
IsArchived: false,
Version: 1,
}
err := db.Create(task).Error
require.NoError(t, err)
// Complete the task
completedAt := time.Date(2026, 3, 2, 10, 0, 0, 0, time.UTC)
req := &requests.CreateTaskCompletionRequest{
TaskID: task.ID,
Notes: "First completion",
CompletedAt: &completedAt,
}
now := time.Now().UTC()
completionResp, err := service.CreateCompletion(req, user.ID, now)
require.NoError(t, err)
// Verify NextDueDate was set to completedAt + 7 days
var taskAfterComplete models.Task
db.First(&taskAfterComplete, task.ID)
require.NotNil(t, taskAfterComplete.NextDueDate)
// Delete the only completion
_, err = service.DeleteCompletion(completionResp.Data.ID, user.ID)
require.NoError(t, err)
// NextDueDate should be restored to original DueDate since no completions remain
var taskAfterDelete models.Task
db.First(&taskAfterDelete, task.ID)
require.NotNil(t, taskAfterDelete.NextDueDate, "NextDueDate should be restored to original DueDate")
assert.Equal(t, originalDueDate.Year(), taskAfterDelete.NextDueDate.Year())
assert.Equal(t, originalDueDate.Month(), taskAfterDelete.NextDueDate.Month())
assert.Equal(t, originalDueDate.Day(), taskAfterDelete.NextDueDate.Day())
}
func TestTaskService_GetCategories(t *testing.T) {
db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db)