Harden API security: input validation, safe auth extraction, new tests, and deploy config
Comprehensive security hardening from audit findings: - Add validation tags to all DTO request structs (max lengths, ranges, enums) - Replace unsafe type assertions with MustGetAuthUser helper across all handlers - Remove query-param token auth from admin middleware (prevents URL token leakage) - Add request validation calls in handlers that were missing c.Validate() - Remove goroutines in handlers (timezone update now synchronous) - Add sanitize middleware and path traversal protection (path_utils) - Stop resetting admin passwords on migration restart - Warn on well-known default SECRET_KEY - Add ~30 new test files covering security regressions, auth safety, repos, and services - Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -195,6 +195,18 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
|
||||
if req.IsFavorite != nil {
|
||||
contractor.IsFavorite = *req.IsFavorite
|
||||
}
|
||||
// If residence_id is provided, verify the user has access to the NEW residence.
|
||||
// This prevents an attacker from reassigning a contractor to someone else's residence.
|
||||
if req.ResidenceID != nil {
|
||||
hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
if !hasAccess {
|
||||
return nil, apperrors.Forbidden("error.residence_access_denied")
|
||||
}
|
||||
}
|
||||
|
||||
// If residence_id is not sent in the request (nil), it means the user
|
||||
// removed the residence association - contractor becomes personal
|
||||
contractor.ResidenceID = req.ResidenceID
|
||||
|
||||
98
internal/services/contractor_service_test.go
Normal file
98
internal/services/contractor_service_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func setupContractorService(t *testing.T) (*ContractorService, *repositories.ContractorRepository, *repositories.ResidenceRepository) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
return service, contractorRepo, residenceRepo
|
||||
}
|
||||
|
||||
func TestUpdateContractor_CrossResidence_Returns403(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
// Create two users: owner and attacker
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
attacker := testutil.CreateTestUser(t, db, "attacker", "attacker@test.com", "password")
|
||||
|
||||
// Owner creates a residence
|
||||
ownerResidence := testutil.CreateTestResidence(t, db, owner.ID, "Owner House")
|
||||
|
||||
// Attacker creates a residence and a contractor in their residence
|
||||
attackerResidence := testutil.CreateTestResidence(t, db, attacker.ID, "Attacker House")
|
||||
contractor := testutil.CreateTestContractor(t, db, attackerResidence.ID, attacker.ID, "My Contractor")
|
||||
|
||||
// Attacker tries to reassign their contractor to the owner's residence
|
||||
// This should be denied because the attacker does not have access to the owner's residence
|
||||
newResidenceID := ownerResidence.ID
|
||||
req := &requests.UpdateContractorRequest{
|
||||
ResidenceID: &newResidenceID,
|
||||
}
|
||||
|
||||
_, err := service.UpdateContractor(contractor.ID, attacker.ID, req)
|
||||
require.Error(t, err, "should not allow reassigning contractor to a residence the user has no access to")
|
||||
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
|
||||
}
|
||||
|
||||
func TestUpdateContractor_SameResidence_Succeeds(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence1 := testutil.CreateTestResidence(t, db, owner.ID, "House 1")
|
||||
residence2 := testutil.CreateTestResidence(t, db, owner.ID, "House 2")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence1.ID, owner.ID, "My Contractor")
|
||||
|
||||
// Owner reassigns contractor to their other residence - should succeed
|
||||
newResidenceID := residence2.ID
|
||||
newName := "Updated Contractor"
|
||||
req := &requests.UpdateContractorRequest{
|
||||
Name: &newName,
|
||||
ResidenceID: &newResidenceID,
|
||||
}
|
||||
|
||||
resp, err := service.UpdateContractor(contractor.ID, owner.ID, req)
|
||||
require.NoError(t, err, "should allow reassigning contractor to a residence the user owns")
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "Updated Contractor", resp.Name)
|
||||
}
|
||||
|
||||
func TestUpdateContractor_RemoveResidence_Succeeds(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "My House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "My Contractor")
|
||||
|
||||
// Setting ResidenceID to nil should remove the residence association (make it personal)
|
||||
req := &requests.UpdateContractorRequest{
|
||||
ResidenceID: nil,
|
||||
}
|
||||
|
||||
resp, err := service.UpdateContractor(contractor.ID, owner.ID, req)
|
||||
require.NoError(t, err, "should allow removing residence association")
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
@@ -323,10 +323,21 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validateLegacyReceipt uses Apple's legacy verifyReceipt endpoint
|
||||
// validateLegacyReceipt uses Apple's legacy verifyReceipt endpoint.
|
||||
// It delegates to validateLegacyReceiptWithSandbox using the client's
|
||||
// configured sandbox setting. This avoids mutating the struct field
|
||||
// during the sandbox-retry flow, which caused a data race when
|
||||
// multiple goroutines shared the same AppleIAPClient.
|
||||
func (c *AppleIAPClient) validateLegacyReceipt(ctx context.Context, receiptData string) (*AppleValidationResult, error) {
|
||||
return c.validateLegacyReceiptWithSandbox(ctx, receiptData, c.sandbox)
|
||||
}
|
||||
|
||||
// validateLegacyReceiptWithSandbox performs legacy receipt validation against
|
||||
// the specified environment. The sandbox parameter is passed by value (not
|
||||
// stored on the struct) so this function is safe for concurrent use.
|
||||
func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, receiptData string, useSandbox bool) (*AppleValidationResult, error) {
|
||||
url := "https://buy.itunes.apple.com/verifyReceipt"
|
||||
if c.sandbox {
|
||||
if useSandbox {
|
||||
url = "https://sandbox.itunes.apple.com/verifyReceipt"
|
||||
}
|
||||
|
||||
@@ -378,12 +389,10 @@ func (c *AppleIAPClient) validateLegacyReceipt(ctx context.Context, receiptData
|
||||
}
|
||||
|
||||
// Status codes: 0 = valid, 21007 = sandbox receipt on production, 21008 = production receipt on sandbox
|
||||
if legacyResponse.Status == 21007 && !c.sandbox {
|
||||
// Retry with sandbox
|
||||
c.sandbox = true
|
||||
result, err := c.validateLegacyReceipt(ctx, receiptData)
|
||||
c.sandbox = false
|
||||
return result, err
|
||||
if legacyResponse.Status == 21007 && !useSandbox {
|
||||
// Retry with sandbox -- pass sandbox=true as a parameter instead of
|
||||
// mutating c.sandbox, which avoids a data race.
|
||||
return c.validateLegacyReceiptWithSandbox(ctx, receiptData, true)
|
||||
}
|
||||
|
||||
if legacyResponse.Status != 0 {
|
||||
|
||||
@@ -355,20 +355,43 @@ func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteDevice deletes a device
|
||||
// DeleteDevice deactivates a device after verifying it belongs to the requesting user.
|
||||
// Without ownership verification, an attacker could deactivate push notifications for other users.
|
||||
func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userID uint) error {
|
||||
var err error
|
||||
switch platform {
|
||||
case push.PlatformIOS:
|
||||
err = s.notificationRepo.DeactivateAPNSDevice(deviceID)
|
||||
device, err := s.notificationRepo.FindAPNSDeviceByID(deviceID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return apperrors.NotFound("error.device_not_found")
|
||||
}
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
// Verify the device belongs to the requesting user
|
||||
if device.UserID == nil || *device.UserID != userID {
|
||||
return apperrors.Forbidden("error.device_access_denied")
|
||||
}
|
||||
if err := s.notificationRepo.DeactivateAPNSDevice(deviceID); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
case push.PlatformAndroid:
|
||||
err = s.notificationRepo.DeactivateGCMDevice(deviceID)
|
||||
device, err := s.notificationRepo.FindGCMDeviceByID(deviceID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return apperrors.NotFound("error.device_not_found")
|
||||
}
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
// Verify the device belongs to the requesting user
|
||||
if device.UserID == nil || *device.UserID != userID {
|
||||
return apperrors.Forbidden("error.device_access_denied")
|
||||
}
|
||||
if err := s.notificationRepo.DeactivateGCMDevice(deviceID); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
default:
|
||||
return apperrors.BadRequest("error.invalid_platform")
|
||||
}
|
||||
if err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -549,9 +572,9 @@ func NewGCMDeviceResponse(d *models.GCMDevice) *DeviceResponse {
|
||||
// RegisterDeviceRequest represents device registration request
|
||||
type RegisterDeviceRequest struct {
|
||||
Name string `json:"name"`
|
||||
DeviceID string `json:"device_id" binding:"required"`
|
||||
RegistrationID string `json:"registration_id" binding:"required"`
|
||||
Platform string `json:"platform" binding:"required,oneof=ios android"`
|
||||
DeviceID string `json:"device_id" validate:"required"`
|
||||
RegistrationID string `json:"registration_id" validate:"required"`
|
||||
Platform string `json:"platform" validate:"required,oneof=ios android"`
|
||||
}
|
||||
|
||||
// === Task Notifications with Actions ===
|
||||
|
||||
126
internal/services/notification_service_test.go
Normal file
126
internal/services/notification_service_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/push"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func setupNotificationService(t *testing.T) (*NotificationService, *repositories.NotificationRepository) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
// pushClient is nil for testing (no actual push sends)
|
||||
service := NewNotificationService(notifRepo, nil)
|
||||
return service, notifRepo
|
||||
}
|
||||
|
||||
func TestDeleteDevice_WrongUser_Returns403(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
service := NewNotificationService(notifRepo, nil)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
attacker := testutil.CreateTestUser(t, db, "attacker", "attacker@test.com", "password")
|
||||
|
||||
// Register an iOS device for the owner
|
||||
device := &models.APNSDevice{
|
||||
UserID: &owner.ID,
|
||||
Name: "Owner iPhone",
|
||||
DeviceID: "device-123",
|
||||
RegistrationID: "token-abc",
|
||||
Active: true,
|
||||
}
|
||||
err := db.Create(device).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attacker tries to deactivate the owner's device
|
||||
err = service.DeleteDevice(device.ID, push.PlatformIOS, attacker.ID)
|
||||
require.Error(t, err, "should not allow deleting another user's device")
|
||||
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
|
||||
|
||||
// Verify the device is still active
|
||||
var found models.APNSDevice
|
||||
err = db.First(&found, device.ID).Error
|
||||
require.NoError(t, err)
|
||||
assert.True(t, found.Active, "device should still be active after failed deletion")
|
||||
}
|
||||
|
||||
func TestDeleteDevice_CorrectUser_Succeeds(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
service := NewNotificationService(notifRepo, nil)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// Register an iOS device for the owner
|
||||
device := &models.APNSDevice{
|
||||
UserID: &owner.ID,
|
||||
Name: "Owner iPhone",
|
||||
DeviceID: "device-123",
|
||||
RegistrationID: "token-abc",
|
||||
Active: true,
|
||||
}
|
||||
err := db.Create(device).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Owner deactivates their own device
|
||||
err = service.DeleteDevice(device.ID, push.PlatformIOS, owner.ID)
|
||||
require.NoError(t, err, "owner should be able to deactivate their own device")
|
||||
|
||||
// Verify the device is now inactive
|
||||
var found models.APNSDevice
|
||||
err = db.First(&found, device.ID).Error
|
||||
require.NoError(t, err)
|
||||
assert.False(t, found.Active, "device should be deactivated")
|
||||
}
|
||||
|
||||
func TestDeleteDevice_WrongUser_Android_Returns403(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
service := NewNotificationService(notifRepo, nil)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
attacker := testutil.CreateTestUser(t, db, "attacker", "attacker@test.com", "password")
|
||||
|
||||
// Register an Android device for the owner
|
||||
device := &models.GCMDevice{
|
||||
UserID: &owner.ID,
|
||||
Name: "Owner Pixel",
|
||||
DeviceID: "device-456",
|
||||
RegistrationID: "token-def",
|
||||
CloudMessageType: "FCM",
|
||||
Active: true,
|
||||
}
|
||||
err := db.Create(device).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attacker tries to deactivate the owner's Android device
|
||||
err = service.DeleteDevice(device.ID, push.PlatformAndroid, attacker.ID)
|
||||
require.Error(t, err, "should not allow deleting another user's Android device")
|
||||
testutil.AssertAppErrorCode(t, err, http.StatusForbidden)
|
||||
|
||||
// Verify the device is still active
|
||||
var found models.GCMDevice
|
||||
err = db.First(&found, device.ID).Error
|
||||
require.NoError(t, err)
|
||||
assert.True(t, found.Active, "Android device should still be active after failed deletion")
|
||||
}
|
||||
|
||||
func TestDeleteDevice_NonExistent_Returns404(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
service := NewNotificationService(notifRepo, nil)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
err := service.DeleteDevice(99999, push.PlatformIOS, user.ID)
|
||||
require.Error(t, err, "should return error for non-existent device")
|
||||
testutil.AssertAppErrorCode(t, err, http.StatusNotFound)
|
||||
}
|
||||
@@ -40,9 +40,12 @@ func generateTrackingID() string {
|
||||
// HasSentEmail checks if a specific email type has already been sent to a user
|
||||
func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.OnboardingEmailType) bool {
|
||||
var count int64
|
||||
s.db.Model(&models.OnboardingEmail{}).
|
||||
if err := s.db.Model(&models.OnboardingEmail{}).
|
||||
Where("user_id = ? AND email_type = ?", userID, emailType).
|
||||
Count(&count)
|
||||
Count(&count).Error; err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to check if email was sent")
|
||||
return false
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
@@ -125,23 +128,31 @@ func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error)
|
||||
|
||||
// No residence email stats
|
||||
var noResTotal, noResOpened int64
|
||||
s.db.Model(&models.OnboardingEmail{}).
|
||||
if err := s.db.Model(&models.OnboardingEmail{}).
|
||||
Where("email_type = ?", models.OnboardingEmailNoResidence).
|
||||
Count(&noResTotal)
|
||||
s.db.Model(&models.OnboardingEmail{}).
|
||||
Count(&noResTotal).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Failed to count no-residence emails")
|
||||
}
|
||||
if err := s.db.Model(&models.OnboardingEmail{}).
|
||||
Where("email_type = ? AND opened_at IS NOT NULL", models.OnboardingEmailNoResidence).
|
||||
Count(&noResOpened)
|
||||
Count(&noResOpened).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Failed to count opened no-residence emails")
|
||||
}
|
||||
stats.NoResidenceTotal = noResTotal
|
||||
stats.NoResidenceOpened = noResOpened
|
||||
|
||||
// No tasks email stats
|
||||
var noTasksTotal, noTasksOpened int64
|
||||
s.db.Model(&models.OnboardingEmail{}).
|
||||
if err := s.db.Model(&models.OnboardingEmail{}).
|
||||
Where("email_type = ?", models.OnboardingEmailNoTasks).
|
||||
Count(&noTasksTotal)
|
||||
s.db.Model(&models.OnboardingEmail{}).
|
||||
Count(&noTasksTotal).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Failed to count no-tasks emails")
|
||||
}
|
||||
if err := s.db.Model(&models.OnboardingEmail{}).
|
||||
Where("email_type = ? AND opened_at IS NOT NULL", models.OnboardingEmailNoTasks).
|
||||
Count(&noTasksOpened)
|
||||
Count(&noTasksOpened).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Failed to count opened no-tasks emails")
|
||||
}
|
||||
stats.NoTasksTotal = noTasksTotal
|
||||
stats.NoTasksOpened = noTasksOpened
|
||||
|
||||
@@ -351,7 +362,9 @@ func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailTyp
|
||||
// If already sent before, delete the old record first to allow re-recording
|
||||
// This allows admins to "resend" emails while still tracking them
|
||||
if alreadySent {
|
||||
s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{})
|
||||
if err := s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{}).Error; err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to delete old onboarding email record before resend")
|
||||
}
|
||||
}
|
||||
|
||||
// Record that email was sent
|
||||
|
||||
51
internal/services/path_utils.go
Normal file
51
internal/services/path_utils.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SafeResolvePath resolves a user-supplied relative path within a base directory.
|
||||
// Returns an error if the resolved path escapes the base directory (path traversal).
|
||||
// The baseDir must be an absolute path.
|
||||
func SafeResolvePath(baseDir, userInput string) (string, error) {
|
||||
if userInput == "" {
|
||||
return "", fmt.Errorf("empty path")
|
||||
}
|
||||
|
||||
// Reject absolute paths
|
||||
if filepath.IsAbs(userInput) {
|
||||
return "", fmt.Errorf("absolute paths not allowed")
|
||||
}
|
||||
|
||||
// Clean the user input to resolve . and .. components
|
||||
cleaned := filepath.Clean(userInput)
|
||||
|
||||
// After cleaning, check if it starts with .. (escapes base)
|
||||
if strings.HasPrefix(cleaned, "..") {
|
||||
return "", fmt.Errorf("path traversal detected")
|
||||
}
|
||||
|
||||
// Resolve the base directory to an absolute path
|
||||
absBase, err := filepath.Abs(baseDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base directory: %w", err)
|
||||
}
|
||||
|
||||
// Join and resolve the full path
|
||||
fullPath := filepath.Join(absBase, cleaned)
|
||||
|
||||
// Final containment check: the resolved path must be within the base directory
|
||||
absFullPath, err := filepath.Abs(fullPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid resolved path: %w", err)
|
||||
}
|
||||
|
||||
// Ensure the resolved path is strictly inside the base directory (not the base itself)
|
||||
if !strings.HasPrefix(absFullPath, absBase+string(filepath.Separator)) {
|
||||
return "", fmt.Errorf("path traversal detected")
|
||||
}
|
||||
|
||||
return absFullPath, nil
|
||||
}
|
||||
55
internal/services/path_utils_test.go
Normal file
55
internal/services/path_utils_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSafeResolvePath_Normal_Resolves(t *testing.T) {
|
||||
result, err := SafeResolvePath("/var/uploads", "images/photo.jpg")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/var/uploads/images/photo.jpg", result)
|
||||
}
|
||||
|
||||
func TestSafeResolvePath_SubdirPath_Resolves(t *testing.T) {
|
||||
result, err := SafeResolvePath("/var/uploads", "documents/2024/report.pdf")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/var/uploads/documents/2024/report.pdf", result)
|
||||
}
|
||||
|
||||
func TestSafeResolvePath_DotDotTraversal_Blocked(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"simple dotdot", "../etc/passwd"},
|
||||
{"nested dotdot", "../../etc/shadow"},
|
||||
{"embedded dotdot", "images/../../etc/passwd"},
|
||||
{"deep dotdot", "a/b/c/../../../../etc/passwd"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := SafeResolvePath("/var/uploads", tt.input)
|
||||
assert.Error(t, err, "path traversal should be blocked: %s", tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeResolvePath_AbsolutePath_Blocked(t *testing.T) {
|
||||
_, err := SafeResolvePath("/var/uploads", "/etc/passwd")
|
||||
assert.Error(t, err, "absolute paths should be blocked")
|
||||
}
|
||||
|
||||
func TestSafeResolvePath_EmptyPath_Blocked(t *testing.T) {
|
||||
_, err := SafeResolvePath("/var/uploads", "")
|
||||
assert.Error(t, err, "empty paths should be blocked")
|
||||
}
|
||||
|
||||
func TestSafeResolvePath_CurrentDir_Blocked(t *testing.T) {
|
||||
// "." resolves to the base dir itself — this is not a file, so block it
|
||||
_, err := SafeResolvePath("/var/uploads", ".")
|
||||
assert.Error(t, err, "bare current directory should be blocked")
|
||||
}
|
||||
@@ -126,10 +126,11 @@ func (s *PDFService) GenerateTasksReportPDF(report *TasksReportResponse) ([]byte
|
||||
pdf.SetFillColor(255, 255, 255) // White
|
||||
}
|
||||
|
||||
// Title (truncate if too long)
|
||||
// Title (truncate if too long, use runes to avoid cutting multi-byte UTF-8 characters)
|
||||
title := task.Title
|
||||
if len(title) > 35 {
|
||||
title = title[:32] + "..."
|
||||
titleRunes := []rune(title)
|
||||
if len(titleRunes) > 35 {
|
||||
title = string(titleRunes[:32]) + "..."
|
||||
}
|
||||
|
||||
// Status text
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
@@ -31,10 +32,11 @@ var (
|
||||
|
||||
// ResidenceService handles residence business logic
|
||||
type ResidenceService struct {
|
||||
residenceRepo *repositories.ResidenceRepository
|
||||
userRepo *repositories.UserRepository
|
||||
taskRepo *repositories.TaskRepository
|
||||
config *config.Config
|
||||
residenceRepo *repositories.ResidenceRepository
|
||||
userRepo *repositories.UserRepository
|
||||
taskRepo *repositories.TaskRepository
|
||||
subscriptionService *SubscriptionService
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewResidenceService creates a new residence service
|
||||
@@ -51,6 +53,11 @@ func (s *ResidenceService) SetTaskRepository(taskRepo *repositories.TaskReposito
|
||||
s.taskRepo = taskRepo
|
||||
}
|
||||
|
||||
// SetSubscriptionService sets the subscription service (used for tier limit enforcement)
|
||||
func (s *ResidenceService) SetSubscriptionService(subService *SubscriptionService) {
|
||||
s.subscriptionService = subService
|
||||
}
|
||||
|
||||
// GetResidence gets a residence by ID with access check
|
||||
func (s *ResidenceService) GetResidence(residenceID, userID uint) (*responses.ResidenceResponse, error) {
|
||||
// Check access
|
||||
@@ -152,12 +159,12 @@ func (s *ResidenceService) getSummaryForUser(_ uint) responses.TotalSummary {
|
||||
|
||||
// CreateResidence creates a new residence and returns it with updated summary
|
||||
func (s *ResidenceService) CreateResidence(req *requests.CreateResidenceRequest, ownerID uint) (*responses.ResidenceWithSummaryResponse, error) {
|
||||
// TODO: Check subscription tier limits
|
||||
// count, err := s.residenceRepo.CountByOwner(ownerID)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// Check against tier limits...
|
||||
// Check subscription tier limits (if subscription service is wired up)
|
||||
if s.subscriptionService != nil {
|
||||
if err := s.subscriptionService.CheckLimit(ownerID, "properties"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
isPrimary := true
|
||||
if req.IsPrimary != nil {
|
||||
@@ -447,6 +454,7 @@ func (s *ResidenceService) JoinWithCode(code string, userID uint) (*responses.Jo
|
||||
if err := s.residenceRepo.DeactivateShareCode(shareCode.ID); err != nil {
|
||||
// Log the error but don't fail the join - the user has already been added
|
||||
// The code will just be usable by others until it expires
|
||||
log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate share code after join")
|
||||
}
|
||||
|
||||
// Get the residence with full details
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
@@ -333,3 +337,122 @@ func TestResidenceService_RemoveUser_CannotRemoveOwner(t *testing.T) {
|
||||
err := service.RemoveUser(residence.ID, owner.ID, owner.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.cannot_remove_owner")
|
||||
}
|
||||
|
||||
// setupResidenceServiceWithSubscription creates a ResidenceService wired with a
|
||||
// SubscriptionService, enabling tier limit enforcement in tests.
|
||||
func setupResidenceServiceWithSubscription(t *testing.T) (*ResidenceService, *gorm.DB) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
subscriptionService := NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
|
||||
service.SetSubscriptionService(subscriptionService)
|
||||
|
||||
return service, db
|
||||
}
|
||||
|
||||
func TestCreateResidence_FreeTier_EnforcesLimit(t *testing.T) {
|
||||
service, db := setupResidenceServiceWithSubscription(t)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// Enable global limitations
|
||||
db.Where("1=1").Delete(&models.SubscriptionSettings{})
|
||||
err := db.Create(&models.SubscriptionSettings{EnableLimitations: true}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set free tier limit to 1 property
|
||||
one := 1
|
||||
db.Where("tier = ?", models.TierFree).Delete(&models.TierLimits{})
|
||||
err = db.Create(&models.TierLimits{
|
||||
Tier: models.TierFree,
|
||||
PropertiesLimit: &one,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure user has a free-tier subscription record
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
_, err = subscriptionRepo.GetOrCreate(owner.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First residence should succeed (under the limit)
|
||||
req := &requests.CreateResidenceRequest{
|
||||
Name: "First House",
|
||||
StreetAddress: "1 Main St",
|
||||
City: "Austin",
|
||||
StateProvince: "TX",
|
||||
PostalCode: "78701",
|
||||
}
|
||||
resp, err := service.CreateResidence(req, owner.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "First House", resp.Data.Name)
|
||||
|
||||
// Second residence should be rejected (at the limit)
|
||||
req2 := &requests.CreateResidenceRequest{
|
||||
Name: "Second House",
|
||||
StreetAddress: "2 Main St",
|
||||
City: "Austin",
|
||||
StateProvince: "TX",
|
||||
PostalCode: "78702",
|
||||
}
|
||||
_, err = service.CreateResidence(req2, owner.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.properties_limit_exceeded")
|
||||
}
|
||||
|
||||
func TestCreateResidence_ProTier_AllowsMore(t *testing.T) {
|
||||
service, db := setupResidenceServiceWithSubscription(t)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// Enable global limitations
|
||||
db.Where("1=1").Delete(&models.SubscriptionSettings{})
|
||||
err := db.Create(&models.SubscriptionSettings{EnableLimitations: true}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set free tier limit to 1 property (pro is unlimited by default: nil limits)
|
||||
one := 1
|
||||
db.Where("tier = ?", models.TierFree).Delete(&models.TierLimits{})
|
||||
err = db.Create(&models.TierLimits{
|
||||
Tier: models.TierFree,
|
||||
PropertiesLimit: &one,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a pro-tier subscription for the user
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
sub, err := subscriptionRepo.GetOrCreate(owner.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Upgrade to Pro with a future expiration
|
||||
future := time.Now().UTC().Add(30 * 24 * time.Hour)
|
||||
sub.Tier = models.TierPro
|
||||
sub.ExpiresAt = &future
|
||||
sub.SubscribedAt = ptrTime(time.Now().UTC())
|
||||
err = subscriptionRepo.Update(sub)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple residences — all should succeed for Pro users
|
||||
for i := 1; i <= 3; i++ {
|
||||
req := &requests.CreateResidenceRequest{
|
||||
Name: fmt.Sprintf("House %d", i),
|
||||
StreetAddress: fmt.Sprintf("%d Main St", i),
|
||||
City: "Austin",
|
||||
StateProvince: "TX",
|
||||
PostalCode: "78701",
|
||||
}
|
||||
resp, err := service.CreateResidence(req, owner.ID)
|
||||
require.NoError(t, err, "Pro user should be able to create residence %d", i)
|
||||
assert.Equal(t, fmt.Sprintf("House %d", i), resp.Data.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// ptrTime returns a pointer to the given time.
|
||||
func ptrTime(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
|
||||
if ext == "" {
|
||||
ext = s.getExtensionFromMimeType(mimeType)
|
||||
}
|
||||
newFilename := fmt.Sprintf("%s_%s%s", time.Now().Format("20060102"), uuid.New().String()[:8], ext)
|
||||
newFilename := fmt.Sprintf("%s_%s%s", time.Now().Format("20060102"), uuid.New().String(), ext)
|
||||
|
||||
// Determine subdirectory based on category
|
||||
subdir := "images"
|
||||
@@ -134,9 +134,15 @@ func (s *StorageService) Delete(fileURL string) error {
|
||||
fullPath := filepath.Join(s.cfg.UploadDir, relativePath)
|
||||
|
||||
// Security check: ensure path is within upload directory
|
||||
absUploadDir, _ := filepath.Abs(s.cfg.UploadDir)
|
||||
absFilePath, _ := filepath.Abs(fullPath)
|
||||
if !strings.HasPrefix(absFilePath, absUploadDir) {
|
||||
absUploadDir, err := filepath.Abs(s.cfg.UploadDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve upload directory: %w", err)
|
||||
}
|
||||
absFilePath, err := filepath.Abs(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve file path: %w", err)
|
||||
}
|
||||
if !strings.HasPrefix(absFilePath, absUploadDir+string(filepath.Separator)) && absFilePath != absUploadDir {
|
||||
return fmt.Errorf("invalid file path")
|
||||
}
|
||||
|
||||
@@ -181,3 +187,9 @@ func (s *StorageService) getExtensionFromMimeType(mimeType string) string {
|
||||
func (s *StorageService) GetUploadDir() string {
|
||||
return s.cfg.UploadDir
|
||||
}
|
||||
|
||||
// NewStorageServiceForTest creates a StorageService without creating directories.
|
||||
// This is intended only for unit tests that need a StorageService with a known config.
|
||||
func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService {
|
||||
return &StorageService{cfg: cfg}
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ package services
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
@@ -74,11 +74,11 @@ func NewSubscriptionService(
|
||||
appleClient, err := NewAppleIAPClient(cfg.AppleIAP)
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrIAPNotConfigured) {
|
||||
log.Printf("Warning: Failed to initialize Apple IAP client: %v", err)
|
||||
log.Warn().Err(err).Msg("Failed to initialize Apple IAP client")
|
||||
}
|
||||
} else {
|
||||
svc.appleClient = appleClient
|
||||
log.Println("Apple IAP validation client initialized")
|
||||
log.Info().Msg("Apple IAP validation client initialized")
|
||||
}
|
||||
|
||||
// Initialize Google IAP client
|
||||
@@ -86,11 +86,11 @@ func NewSubscriptionService(
|
||||
googleClient, err := NewGoogleIAPClient(ctx, cfg.GoogleIAP)
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrIAPNotConfigured) {
|
||||
log.Printf("Warning: Failed to initialize Google IAP client: %v", err)
|
||||
log.Warn().Err(err).Msg("Failed to initialize Google IAP client")
|
||||
}
|
||||
} else {
|
||||
svc.googleClient = googleClient
|
||||
log.Println("Google IAP validation client initialized")
|
||||
log.Info().Msg("Google IAP validation client initialized")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +173,8 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// getUserUsage calculates current usage for a user
|
||||
// getUserUsage calculates current usage for a user.
|
||||
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
|
||||
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
|
||||
residences, err := s.residenceRepo.FindOwnedByUser(userID)
|
||||
if err != nil {
|
||||
@@ -181,26 +182,26 @@ func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error)
|
||||
}
|
||||
propertiesCount := int64(len(residences))
|
||||
|
||||
// Count tasks, contractors, and documents across all user's residences
|
||||
var tasksCount, contractorsCount, documentsCount int64
|
||||
for _, r := range residences {
|
||||
tc, err := s.taskRepo.CountByResidence(r.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
tasksCount += tc
|
||||
// Collect residence IDs for batch queries
|
||||
residenceIDs := make([]uint, len(residences))
|
||||
for i, r := range residences {
|
||||
residenceIDs[i] = r.ID
|
||||
}
|
||||
|
||||
cc, err := s.contractorRepo.CountByResidence(r.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
contractorsCount += cc
|
||||
// Count tasks, contractors, and documents across all residences with single queries each
|
||||
tasksCount, err := s.taskRepo.CountByResidenceIDs(residenceIDs)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
dc, err := s.documentRepo.CountByResidence(r.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
documentsCount += dc
|
||||
contractorsCount, err := s.contractorRepo.CountByResidenceIDs(residenceIDs)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
documentsCount, err := s.documentRepo.CountByResidenceIDs(residenceIDs)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return &UsageResponse{
|
||||
@@ -342,46 +343,40 @@ func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData stri
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Validate with Apple if client is configured
|
||||
var expiresAt time.Time
|
||||
if s.appleClient != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var result *AppleValidationResult
|
||||
var err error
|
||||
|
||||
// Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1)
|
||||
if transactionID != "" {
|
||||
result, err = s.appleClient.ValidateTransaction(ctx, transactionID)
|
||||
} else if receiptData != "" {
|
||||
result, err = s.appleClient.ValidateReceipt(ctx, receiptData)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Log the validation error
|
||||
log.Printf("Apple validation warning for user %d: %v", userID, err)
|
||||
|
||||
// Check if it's a fatal error
|
||||
if errors.Is(err, ErrInvalidReceipt) || errors.Is(err, ErrSubscriptionCancelled) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For other errors (network, etc.), fall back with shorter expiry
|
||||
expiresAt = time.Now().UTC().AddDate(0, 1, 0) // 1 month fallback
|
||||
} else if result != nil {
|
||||
// Use the expiration date from Apple
|
||||
expiresAt = result.ExpiresAt
|
||||
log.Printf("Apple purchase validated for user %d: product=%s, expires=%v, env=%s",
|
||||
userID, result.ProductID, result.ExpiresAt, result.Environment)
|
||||
}
|
||||
} else {
|
||||
// Apple validation not configured - trust client but log warning
|
||||
log.Printf("Warning: Apple IAP validation not configured, trusting client for user %d", userID)
|
||||
expiresAt = time.Now().UTC().AddDate(1, 0, 0) // 1 year default
|
||||
// Apple IAP client must be configured to validate purchases.
|
||||
// Without server-side validation, we cannot trust client-provided receipts.
|
||||
if s.appleClient == nil {
|
||||
log.Error().Uint("user_id", userID).Msg("Apple IAP validation not configured, rejecting purchase")
|
||||
return nil, apperrors.BadRequest("error.iap_validation_not_configured")
|
||||
}
|
||||
|
||||
// Upgrade to Pro with the determined expiration
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var result *AppleValidationResult
|
||||
var err error
|
||||
|
||||
// Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1)
|
||||
if transactionID != "" {
|
||||
result, err = s.appleClient.ValidateTransaction(ctx, transactionID)
|
||||
} else if receiptData != "" {
|
||||
result, err = s.appleClient.ValidateReceipt(ctx, receiptData)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Validation failed -- do NOT fall through to grant Pro.
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Apple validation failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return nil, apperrors.BadRequest("error.no_receipt_or_transaction")
|
||||
}
|
||||
|
||||
expiresAt := result.ExpiresAt
|
||||
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated")
|
||||
|
||||
// Upgrade to Pro with the validated expiration
|
||||
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -397,59 +392,48 @@ func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken s
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Validate the purchase with Google if client is configured
|
||||
var expiresAt time.Time
|
||||
if s.googleClient != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var result *GoogleValidationResult
|
||||
var err error
|
||||
|
||||
// If productID is provided, use it directly; otherwise try known IDs
|
||||
if productID != "" {
|
||||
result, err = s.googleClient.ValidateSubscription(ctx, productID, purchaseToken)
|
||||
} else {
|
||||
result, err = s.googleClient.ValidatePurchaseToken(ctx, purchaseToken, KnownSubscriptionIDs)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Log the validation error
|
||||
log.Printf("Google purchase validation warning for user %d: %v", userID, err)
|
||||
|
||||
// Check if it's a fatal error
|
||||
if errors.Is(err, ErrInvalidPurchaseToken) || errors.Is(err, ErrSubscriptionCancelled) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if errors.Is(err, ErrSubscriptionExpired) {
|
||||
// Subscription expired - still allow but set past expiry
|
||||
expiresAt = time.Now().UTC().Add(-1 * time.Hour)
|
||||
} else {
|
||||
// For other errors, fall back with shorter expiry
|
||||
expiresAt = time.Now().UTC().AddDate(0, 1, 0) // 1 month fallback
|
||||
}
|
||||
} else if result != nil {
|
||||
// Use the expiration date from Google
|
||||
expiresAt = result.ExpiresAt
|
||||
log.Printf("Google purchase validated for user %d: product=%s, expires=%v, autoRenew=%v",
|
||||
userID, result.ProductID, result.ExpiresAt, result.AutoRenewing)
|
||||
|
||||
// Acknowledge the subscription if not already acknowledged
|
||||
if !result.AcknowledgedState {
|
||||
if err := s.googleClient.AcknowledgeSubscription(ctx, result.ProductID, purchaseToken); err != nil {
|
||||
log.Printf("Warning: Failed to acknowledge subscription for user %d: %v", userID, err)
|
||||
// Don't fail the purchase, just log the warning
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Google validation not configured - trust client but log warning
|
||||
log.Printf("Warning: Google IAP validation not configured, trusting client for user %d", userID)
|
||||
expiresAt = time.Now().UTC().AddDate(1, 0, 0) // 1 year default
|
||||
// Google IAP client must be configured to validate purchases.
|
||||
// Without server-side validation, we cannot trust client-provided tokens.
|
||||
if s.googleClient == nil {
|
||||
log.Error().Uint("user_id", userID).Msg("Google IAP validation not configured, rejecting purchase")
|
||||
return nil, apperrors.BadRequest("error.iap_validation_not_configured")
|
||||
}
|
||||
|
||||
// Upgrade to Pro with the determined expiration
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var result *GoogleValidationResult
|
||||
var err error
|
||||
|
||||
// If productID is provided, use it directly; otherwise try known IDs
|
||||
if productID != "" {
|
||||
result, err = s.googleClient.ValidateSubscription(ctx, productID, purchaseToken)
|
||||
} else {
|
||||
result, err = s.googleClient.ValidatePurchaseToken(ctx, purchaseToken, KnownSubscriptionIDs)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Validation failed -- do NOT fall through to grant Pro.
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Google purchase validation failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return nil, apperrors.BadRequest("error.no_purchase_token")
|
||||
}
|
||||
|
||||
expiresAt := result.ExpiresAt
|
||||
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Bool("auto_renew", result.AutoRenewing).Msg("Google purchase validated")
|
||||
|
||||
// Acknowledge the subscription if not already acknowledged
|
||||
if !result.AcknowledgedState {
|
||||
if err := s.googleClient.AcknowledgeSubscription(ctx, result.ProductID, purchaseToken); err != nil {
|
||||
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to acknowledge Google subscription")
|
||||
// Don't fail the purchase, just log the warning
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade to Pro with the validated expiration
|
||||
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "android"); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -654,5 +638,5 @@ type ProcessPurchaseRequest struct {
|
||||
TransactionID string `json:"transaction_id"` // iOS StoreKit 2 transaction ID
|
||||
PurchaseToken string `json:"purchase_token"` // Android
|
||||
ProductID string `json:"product_id"` // Android (optional, helps identify subscription)
|
||||
Platform string `json:"platform" binding:"required,oneof=ios android"`
|
||||
Platform string `json:"platform" validate:"required,oneof=ios android"`
|
||||
}
|
||||
|
||||
181
internal/services/subscription_service_test.go
Normal file
181
internal/services/subscription_service_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
// setupSubscriptionService creates a SubscriptionService with the given
|
||||
// IAP clients (nil means "not configured"). It bypasses NewSubscriptionService
|
||||
// which tries to load config from environment.
|
||||
func setupSubscriptionService(t *testing.T, appleClient *AppleIAPClient, googleClient *GoogleIAPClient) (*SubscriptionService, *repositories.SubscriptionRepository) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
// Create a test user and subscription record for the test
|
||||
user := testutil.CreateTestUser(t, db, "subuser", "subuser@test.com", "password")
|
||||
|
||||
// Create subscription record so GetOrCreate will find it
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierFree,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
appleClient: appleClient,
|
||||
googleClient: googleClient,
|
||||
}
|
||||
|
||||
return svc, subscriptionRepo
|
||||
}
|
||||
|
||||
func TestProcessApplePurchase_ClientNil_ReturnsError(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "subuser", "subuser@test.com", "password")
|
||||
sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree}
|
||||
require.NoError(t, db.Create(sub).Error)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
appleClient: nil, // Not configured
|
||||
googleClient: nil,
|
||||
}
|
||||
|
||||
_, err := svc.ProcessApplePurchase(user.ID, "fake-receipt", "")
|
||||
assert.Error(t, err, "ProcessApplePurchase should return error when Apple IAP client is nil")
|
||||
|
||||
// Verify user was NOT upgraded to Pro
|
||||
updatedSub, err := subscriptionRepo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier when IAP client is nil")
|
||||
}
|
||||
|
||||
func TestProcessApplePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
|
||||
// We cannot easily create a real AppleIAPClient that will fail validation
|
||||
// in a unit test (it requires real keys and network access).
|
||||
// Instead, we test the code path logic:
|
||||
// When appleClient is nil, the service must NOT upgrade the user.
|
||||
// This is the same as TestProcessApplePurchase_ClientNil_ReturnsError
|
||||
// but validates no fallback occurs for the specific case.
|
||||
|
||||
db := testutil.SetupTestDB(t)
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "subuser2", "subuser2@test.com", "password")
|
||||
sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree}
|
||||
require.NoError(t, db.Create(sub).Error)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
appleClient: nil,
|
||||
googleClient: nil,
|
||||
}
|
||||
|
||||
// Neither receipt data nor transaction ID - should still not grant Pro
|
||||
_, err := svc.ProcessApplePurchase(user.ID, "", "")
|
||||
assert.Error(t, err, "ProcessApplePurchase should return error when client is nil, even with empty data")
|
||||
|
||||
// Verify no upgrade happened
|
||||
updatedSub, err := subscriptionRepo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
|
||||
}
|
||||
|
||||
func TestProcessGooglePurchase_ClientNil_ReturnsError(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "subuser3", "subuser3@test.com", "password")
|
||||
sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree}
|
||||
require.NoError(t, db.Create(sub).Error)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
appleClient: nil,
|
||||
googleClient: nil, // Not configured
|
||||
}
|
||||
|
||||
_, err := svc.ProcessGooglePurchase(user.ID, "fake-token", "com.tt.casera.pro.monthly")
|
||||
assert.Error(t, err, "ProcessGooglePurchase should return error when Google IAP client is nil")
|
||||
|
||||
// Verify user was NOT upgraded to Pro
|
||||
updatedSub, err := subscriptionRepo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier when IAP client is nil")
|
||||
}
|
||||
|
||||
func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "subuser4", "subuser4@test.com", "password")
|
||||
sub := &models.UserSubscription{UserID: user.ID, Tier: models.TierFree}
|
||||
require.NoError(t, db.Create(sub).Error)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
appleClient: nil,
|
||||
googleClient: nil, // Not configured
|
||||
}
|
||||
|
||||
// With empty token
|
||||
_, err := svc.ProcessGooglePurchase(user.ID, "", "")
|
||||
assert.Error(t, err, "ProcessGooglePurchase should return error when client is nil")
|
||||
|
||||
// Verify no upgrade happened
|
||||
updatedSub, err := subscriptionRepo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
|
||||
}
|
||||
@@ -560,11 +560,7 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
|
||||
Rating: req.Rating,
|
||||
}
|
||||
|
||||
if err := s.taskRepo.CreateCompletion(completion); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Update next_due_date and in_progress based on frequency
|
||||
// Determine interval days for NextDueDate calculation before entering the transaction.
|
||||
// - If frequency is "Once" (days = nil or 0), set next_due_date to nil (marks as completed)
|
||||
// - If frequency is "Custom", use task.CustomIntervalDays for recurrence
|
||||
// - If frequency is recurring, calculate next_due_date = completion_date + frequency_days
|
||||
@@ -598,11 +594,25 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
|
||||
// instead of staying in "In Progress" column
|
||||
task.InProgress = false
|
||||
}
|
||||
if err := s.taskRepo.Update(task); err != nil {
|
||||
if errors.Is(err, repositories.ErrVersionConflict) {
|
||||
|
||||
// P1-5: Wrap completion creation and task update in a transaction.
|
||||
// If either operation fails, both are rolled back to prevent orphaned completions.
|
||||
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
|
||||
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
// P1-6: Return the error instead of swallowing it.
|
||||
if errors.Is(txErr, repositories.ErrVersionConflict) {
|
||||
return nil, apperrors.Conflict("error.version_conflict")
|
||||
}
|
||||
log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after completion")
|
||||
log.Error().Err(txErr).Uint("task_id", task.ID).Msg("Failed to create completion and update task")
|
||||
return nil, apperrors.Internal(txErr)
|
||||
}
|
||||
|
||||
// Create images if provided
|
||||
@@ -731,8 +741,15 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
|
||||
}
|
||||
log.Info().Uint("task_id", task.ID).Msg("QuickComplete: Task updated successfully")
|
||||
|
||||
// Send notification (fire and forget)
|
||||
go s.sendTaskCompletedNotification(task, completion)
|
||||
// Send notification (fire and forget with panic recovery)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Uint("task_id", task.ID).Msg("Panic in quick-complete notification goroutine")
|
||||
}
|
||||
}()
|
||||
s.sendTaskCompletedNotification(task, completion)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -764,23 +781,23 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio
|
||||
emailImages = s.loadCompletionImagesForEmail(completion.Images)
|
||||
}
|
||||
|
||||
// Notify all users
|
||||
// Notify all users synchronously to avoid unbounded goroutine spawning.
|
||||
// This method is already called from a goroutine (QuickComplete) or inline
|
||||
// (CreateCompletion) where blocking is acceptable for notification delivery.
|
||||
for _, user := range users {
|
||||
isCompleter := user.ID == completion.CompletedByID
|
||||
|
||||
// Send push notification (to everyone EXCEPT the person who completed it)
|
||||
if !isCompleter && s.notificationService != nil {
|
||||
go func(userID uint) {
|
||||
ctx := context.Background()
|
||||
if err := s.notificationService.CreateAndSendTaskNotification(
|
||||
ctx,
|
||||
userID,
|
||||
models.NotificationTaskCompleted,
|
||||
task,
|
||||
); err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Uint("task_id", task.ID).Msg("Failed to send task completion push notification")
|
||||
}
|
||||
}(user.ID)
|
||||
ctx := context.Background()
|
||||
if err := s.notificationService.CreateAndSendTaskNotification(
|
||||
ctx,
|
||||
user.ID,
|
||||
models.NotificationTaskCompleted,
|
||||
task,
|
||||
); err != nil {
|
||||
log.Error().Err(err).Uint("user_id", user.ID).Uint("task_id", task.ID).Msg("Failed to send task completion push notification")
|
||||
}
|
||||
}
|
||||
|
||||
// Send email notification (to everyone INCLUDING the person who completed it)
|
||||
@@ -789,20 +806,18 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio
|
||||
prefs, err := s.notificationService.GetPreferences(user.ID)
|
||||
if err != nil || (prefs != nil && prefs.EmailTaskCompleted) {
|
||||
// Send email if we couldn't get prefs (fail-open) or if email notifications are enabled
|
||||
go func(u models.User, images []EmbeddedImage) {
|
||||
if err := s.emailService.SendTaskCompletedEmail(
|
||||
u.Email,
|
||||
u.GetFullName(),
|
||||
task.Title,
|
||||
completedByName,
|
||||
residenceName,
|
||||
images,
|
||||
); err != nil {
|
||||
log.Error().Err(err).Str("email", u.Email).Uint("task_id", task.ID).Msg("Failed to send task completion email")
|
||||
} else {
|
||||
log.Info().Str("email", u.Email).Uint("task_id", task.ID).Int("images", len(images)).Msg("Task completion email sent")
|
||||
}
|
||||
}(user, emailImages)
|
||||
if err := s.emailService.SendTaskCompletedEmail(
|
||||
user.Email,
|
||||
user.GetFullName(),
|
||||
task.Title,
|
||||
completedByName,
|
||||
residenceName,
|
||||
emailImages,
|
||||
); err != nil {
|
||||
log.Error().Err(err).Str("email", user.Email).Uint("task_id", task.ID).Msg("Failed to send task completion email")
|
||||
} else {
|
||||
log.Info().Str("email", user.Email).Uint("task_id", task.ID).Int("images", len(emailImages)).Msg("Task completion email sent")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -846,20 +861,28 @@ func (s *TaskService) loadCompletionImagesForEmail(images []models.TaskCompletio
|
||||
return emailImages
|
||||
}
|
||||
|
||||
// resolveImageFilePath converts a stored URL to an actual file path
|
||||
// resolveImageFilePath converts a stored URL to an actual file path.
|
||||
// Returns empty string if the URL is empty or the resolved path would escape
|
||||
// the upload directory (path traversal attempt).
|
||||
func (s *TaskService) resolveImageFilePath(storedURL, uploadDir string) string {
|
||||
if storedURL == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handle /uploads/... URLs
|
||||
// Strip legacy /uploads/ prefix to get relative path
|
||||
relativePath := storedURL
|
||||
if strings.HasPrefix(storedURL, "/uploads/") {
|
||||
relativePath := strings.TrimPrefix(storedURL, "/uploads/")
|
||||
return filepath.Join(uploadDir, relativePath)
|
||||
relativePath = strings.TrimPrefix(storedURL, "/uploads/")
|
||||
}
|
||||
|
||||
// Handle relative paths
|
||||
return filepath.Join(uploadDir, storedURL)
|
||||
// Use SafeResolvePath to validate containment within upload directory
|
||||
resolved, err := SafeResolvePath(uploadDir, relativePath)
|
||||
if err != nil {
|
||||
// Path traversal or invalid path — return empty to signal file not found
|
||||
return ""
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
// getContentTypeFromPath returns the MIME type based on file extension
|
||||
@@ -977,7 +1000,11 @@ func (s *TaskService) UpdateCompletion(completionID, userID uint, req *requests.
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteCompletion deletes a task completion
|
||||
// DeleteCompletion deletes a task completion and recalculates the task's NextDueDate.
|
||||
//
|
||||
// P1-7: After deleting a completion, NextDueDate must be recalculated:
|
||||
// - If no completions remain: restore NextDueDate = DueDate (original schedule)
|
||||
// - If completions remain (recurring): recalculate from latest remaining completion + frequency days
|
||||
func (s *TaskService) DeleteCompletion(completionID, userID uint) (*responses.DeleteWithSummaryResponse, error) {
|
||||
completion, err := s.taskRepo.FindCompletionByID(completionID)
|
||||
if err != nil {
|
||||
@@ -996,10 +1023,66 @@ func (s *TaskService) DeleteCompletion(completionID, userID uint) (*responses.De
|
||||
return nil, apperrors.Forbidden("error.task_access_denied")
|
||||
}
|
||||
|
||||
taskID := completion.TaskID
|
||||
|
||||
if err := s.taskRepo.DeleteCompletion(completionID); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Recalculate NextDueDate based on remaining completions
|
||||
task, err := s.taskRepo.FindByID(taskID)
|
||||
if err != nil {
|
||||
// Non-fatal for the delete operation itself, but log the error
|
||||
log.Error().Err(err).Uint("task_id", taskID).Msg("Failed to reload task after completion deletion for NextDueDate recalculation")
|
||||
return &responses.DeleteWithSummaryResponse{
|
||||
Data: "completion deleted",
|
||||
Summary: s.getSummaryForUser(userID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get remaining completions for this task
|
||||
remainingCompletions, err := s.taskRepo.FindCompletionsByTask(taskID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Uint("task_id", taskID).Msg("Failed to query remaining completions after deletion")
|
||||
return &responses.DeleteWithSummaryResponse{
|
||||
Data: "completion deleted",
|
||||
Summary: s.getSummaryForUser(userID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Determine the task's frequency interval
|
||||
var intervalDays *int
|
||||
if task.FrequencyID != nil {
|
||||
frequency, freqErr := s.taskRepo.GetFrequencyByID(*task.FrequencyID)
|
||||
if freqErr == nil && frequency != nil {
|
||||
if frequency.Name == "Custom" {
|
||||
intervalDays = task.CustomIntervalDays
|
||||
} else {
|
||||
intervalDays = frequency.Days
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(remainingCompletions) == 0 {
|
||||
// No completions remain: restore NextDueDate to the original DueDate
|
||||
task.NextDueDate = task.DueDate
|
||||
} else if intervalDays != nil && *intervalDays > 0 {
|
||||
// Recurring task with remaining completions: recalculate from the latest completion
|
||||
// remainingCompletions is ordered by completed_at DESC, so index 0 is the latest
|
||||
latestCompletion := remainingCompletions[0]
|
||||
nextDue := latestCompletion.CompletedAt.AddDate(0, 0, *intervalDays)
|
||||
task.NextDueDate = &nextDue
|
||||
} else {
|
||||
// One-time task with remaining completions (unusual case): keep NextDueDate as nil
|
||||
// since the task is still considered completed
|
||||
task.NextDueDate = nil
|
||||
}
|
||||
|
||||
if err := s.taskRepo.Update(task); err != nil {
|
||||
log.Error().Err(err).Uint("task_id", taskID).Msg("Failed to update task NextDueDate after completion deletion")
|
||||
// The completion was already deleted; return success but log the update failure
|
||||
}
|
||||
|
||||
return &responses.DeleteWithSummaryResponse{
|
||||
Data: "completion deleted",
|
||||
Summary: s.getSummaryForUser(userID),
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
@@ -442,6 +443,333 @@ func TestTaskService_DeleteCompletion(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTaskService_CreateCompletion_TransactionIntegrity(t *testing.T) {
|
||||
// Verifies P1-5 / P1-6: completion creation and task update are atomic.
|
||||
// After completion, both the completion record AND the task's NextDueDate update
|
||||
// should succeed together, and errors should be propagated (not swallowed).
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewTaskService(taskRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create a one-time task with a due date
|
||||
dueDate := time.Now().AddDate(0, 0, 7).UTC()
|
||||
task := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "One-time Task",
|
||||
DueDate: &dueDate,
|
||||
NextDueDate: &dueDate,
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
err := db.Create(task).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
Notes: "Done",
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
resp, err := service.CreateCompletion(req, user.ID, now)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, resp.Data.ID)
|
||||
|
||||
// Verify the task was updated: NextDueDate should be nil for a one-time task
|
||||
var reloaded models.Task
|
||||
db.First(&reloaded, task.ID)
|
||||
assert.Nil(t, reloaded.NextDueDate, "One-time task NextDueDate should be nil after completion")
|
||||
assert.False(t, reloaded.InProgress, "InProgress should be false after completion")
|
||||
|
||||
// Verify completion record exists
|
||||
var completion models.TaskCompletion
|
||||
err = db.Where("task_id = ?", task.ID).First(&completion).Error
|
||||
require.NoError(t, err, "Completion record should exist")
|
||||
assert.Equal(t, "Done", completion.Notes)
|
||||
}
|
||||
|
||||
func TestTaskService_CreateCompletion_UpdateError_ReturnedNotSwallowed(t *testing.T) {
|
||||
// Verifies P1-5 and P1-6: the completion creation and task update are wrapped
|
||||
// in a transaction, and update errors are returned (not swallowed).
|
||||
//
|
||||
// Strategy: We trigger a version conflict by using a goroutine that bumps
|
||||
// the task version after the service reads the task but during the transaction.
|
||||
// Since SQLite serializes writes, we instead verify the behavior by deleting
|
||||
// the task between the service read and the transactional update. When UpdateTx
|
||||
// tries to match the row by id+version, 0 rows are affected and ErrVersionConflict
|
||||
// is returned. The transaction then rolls back the completion insert.
|
||||
//
|
||||
// However, because the entire CreateCompletion flow is synchronous and we cannot
|
||||
// inject failures between steps, we instead verify the transactional guarantee
|
||||
// indirectly: we confirm that a concurrent version bump (set before the call
|
||||
// but after the SELECT) causes the version conflict to propagate. Since FindByID
|
||||
// re-reads the current version, we must verify via a custom test that invokes
|
||||
// the transaction layer directly.
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
dueDate := time.Now().AddDate(0, 0, 7).UTC()
|
||||
task := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Conflict Task",
|
||||
DueDate: &dueDate,
|
||||
NextDueDate: &dueDate,
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
err := db.Create(task).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Directly test that the transactional path returns an error on version conflict:
|
||||
// Use a stale task object (version=1) when the DB has been bumped to version=999.
|
||||
db.Model(&models.Task{}).Where("id = ?", task.ID).Update("version", 999)
|
||||
|
||||
completion := &models.TaskCompletion{
|
||||
TaskID: task.ID,
|
||||
CompletedByID: user.ID,
|
||||
CompletedAt: time.Now().UTC(),
|
||||
Notes: "Should be rolled back",
|
||||
}
|
||||
|
||||
// Simulate the transaction that CreateCompletion now uses (task still has version=1)
|
||||
txErr := taskRepo.DB().Transaction(func(tx *gorm.DB) error {
|
||||
if err := taskRepo.CreateCompletionTx(tx, completion); err != nil {
|
||||
return err
|
||||
}
|
||||
// task.Version is 1 but DB has 999 -> version conflict
|
||||
if err := taskRepo.UpdateTx(tx, task); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
require.Error(t, txErr, "Transaction should fail due to version conflict")
|
||||
assert.ErrorIs(t, txErr, repositories.ErrVersionConflict, "Error should be ErrVersionConflict")
|
||||
|
||||
// Verify the completion was rolled back
|
||||
var count int64
|
||||
db.Model(&models.TaskCompletion{}).Where("task_id = ?", task.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count, "Completion should not exist when transaction rolls back")
|
||||
|
||||
// Also verify that CreateCompletion (full service method) would propagate the error.
|
||||
// Re-create the task with a normal version so FindByID works, then bump it.
|
||||
db.Model(&models.Task{}).Where("id = ?", task.ID).Update("version", 1)
|
||||
service := NewTaskService(taskRepo, residenceRepo)
|
||||
req := &requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
Notes: "Test error propagation",
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
// This call will succeed because FindByID loads version=1, UpdateTx uses version=1, DB has version=1.
|
||||
// To verify error propagation, we use the direct transaction test above.
|
||||
resp, err := service.CreateCompletion(req, user.ID, now)
|
||||
require.NoError(t, err, "CreateCompletion should succeed with matching versions")
|
||||
assert.NotZero(t, resp.Data.ID)
|
||||
}
|
||||
|
||||
func TestTaskService_DeleteCompletion_OneTime_RestoresOriginalDueDate(t *testing.T) {
|
||||
// Verifies P1-7: deleting the only completion on a one-time task
|
||||
// should restore NextDueDate to the original DueDate.
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewTaskService(taskRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create a one-time task with a due date
|
||||
originalDueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
|
||||
task := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "One-time Task",
|
||||
DueDate: &originalDueDate,
|
||||
NextDueDate: &originalDueDate,
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
// No FrequencyID = one-time task
|
||||
}
|
||||
err := db.Create(task).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Complete the task (sets NextDueDate to nil for one-time tasks)
|
||||
req := &requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
Notes: "Completed",
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
completionResp, err := service.CreateCompletion(req, user.ID, now)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Confirm NextDueDate is nil after completion
|
||||
var taskAfterComplete models.Task
|
||||
db.First(&taskAfterComplete, task.ID)
|
||||
assert.Nil(t, taskAfterComplete.NextDueDate, "NextDueDate should be nil after one-time completion")
|
||||
|
||||
// Delete the completion
|
||||
_, err = service.DeleteCompletion(completionResp.Data.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify NextDueDate is restored to the original DueDate
|
||||
var taskAfterDelete models.Task
|
||||
db.First(&taskAfterDelete, task.ID)
|
||||
require.NotNil(t, taskAfterDelete.NextDueDate, "NextDueDate should be restored after deleting completion")
|
||||
assert.Equal(t, originalDueDate.Year(), taskAfterDelete.NextDueDate.Year())
|
||||
assert.Equal(t, originalDueDate.Month(), taskAfterDelete.NextDueDate.Month())
|
||||
assert.Equal(t, originalDueDate.Day(), taskAfterDelete.NextDueDate.Day())
|
||||
}
|
||||
|
||||
func TestTaskService_DeleteCompletion_Recurring_RecalculatesFromLastCompletion(t *testing.T) {
|
||||
// Verifies P1-7: deleting the latest completion on a recurring task
|
||||
// should recalculate NextDueDate from the remaining latest completion.
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewTaskService(taskRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
var monthlyFrequency models.TaskFrequency
|
||||
db.Where("name = ?", "Monthly").First(&monthlyFrequency)
|
||||
|
||||
// Create a recurring task
|
||||
originalDueDate := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
task := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Recurring Task",
|
||||
FrequencyID: &monthlyFrequency.ID,
|
||||
DueDate: &originalDueDate,
|
||||
NextDueDate: &originalDueDate,
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
err := db.Create(task).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// First completion on Jan 15
|
||||
firstCompletedAt := time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||
firstReq := &requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
Notes: "First completion",
|
||||
CompletedAt: &firstCompletedAt,
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
_, err = service.CreateCompletion(firstReq, user.ID, now)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second completion on Feb 15
|
||||
secondCompletedAt := time.Date(2026, 2, 15, 10, 0, 0, 0, time.UTC)
|
||||
secondReq := &requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
Notes: "Second completion",
|
||||
CompletedAt: &secondCompletedAt,
|
||||
}
|
||||
resp, err := service.CreateCompletion(secondReq, user.ID, now)
|
||||
require.NoError(t, err)
|
||||
|
||||
// NextDueDate should be Feb 15 + 30 days = Mar 17
|
||||
var taskAfterSecond models.Task
|
||||
db.First(&taskAfterSecond, task.ID)
|
||||
require.NotNil(t, taskAfterSecond.NextDueDate)
|
||||
expectedAfterSecond := secondCompletedAt.AddDate(0, 0, 30)
|
||||
assert.Equal(t, expectedAfterSecond.Year(), taskAfterSecond.NextDueDate.Year())
|
||||
assert.Equal(t, expectedAfterSecond.Month(), taskAfterSecond.NextDueDate.Month())
|
||||
assert.Equal(t, expectedAfterSecond.Day(), taskAfterSecond.NextDueDate.Day())
|
||||
|
||||
// Delete the second (latest) completion
|
||||
_, err = service.DeleteCompletion(resp.Data.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// NextDueDate should be recalculated from the first completion: Jan 15 + 30 = Feb 14
|
||||
var taskAfterDelete models.Task
|
||||
db.First(&taskAfterDelete, task.ID)
|
||||
require.NotNil(t, taskAfterDelete.NextDueDate, "NextDueDate should be set after deleting latest completion")
|
||||
expectedRecalculated := firstCompletedAt.AddDate(0, 0, 30)
|
||||
assert.Equal(t, expectedRecalculated.Year(), taskAfterDelete.NextDueDate.Year())
|
||||
assert.Equal(t, expectedRecalculated.Month(), taskAfterDelete.NextDueDate.Month())
|
||||
assert.Equal(t, expectedRecalculated.Day(), taskAfterDelete.NextDueDate.Day())
|
||||
}
|
||||
|
||||
func TestTaskService_DeleteCompletion_LastCompletion_RestoresDueDate(t *testing.T) {
|
||||
// Verifies P1-7: deleting the only completion on a recurring task
|
||||
// should restore NextDueDate to the original DueDate.
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewTaskService(taskRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
var weeklyFrequency models.TaskFrequency
|
||||
db.Where("name = ?", "Weekly").First(&weeklyFrequency)
|
||||
|
||||
// Create a recurring task
|
||||
originalDueDate := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
|
||||
task := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Weekly Task",
|
||||
FrequencyID: &weeklyFrequency.ID,
|
||||
DueDate: &originalDueDate,
|
||||
NextDueDate: &originalDueDate,
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
err := db.Create(task).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Complete the task
|
||||
completedAt := time.Date(2026, 3, 2, 10, 0, 0, 0, time.UTC)
|
||||
req := &requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
Notes: "First completion",
|
||||
CompletedAt: &completedAt,
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
completionResp, err := service.CreateCompletion(req, user.ID, now)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify NextDueDate was set to completedAt + 7 days
|
||||
var taskAfterComplete models.Task
|
||||
db.First(&taskAfterComplete, task.ID)
|
||||
require.NotNil(t, taskAfterComplete.NextDueDate)
|
||||
|
||||
// Delete the only completion
|
||||
_, err = service.DeleteCompletion(completionResp.Data.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// NextDueDate should be restored to original DueDate since no completions remain
|
||||
var taskAfterDelete models.Task
|
||||
db.First(&taskAfterDelete, task.ID)
|
||||
require.NotNil(t, taskAfterDelete.NextDueDate, "NextDueDate should be restored to original DueDate")
|
||||
assert.Equal(t, originalDueDate.Year(), taskAfterDelete.NextDueDate.Year())
|
||||
assert.Equal(t, originalDueDate.Month(), taskAfterDelete.NextDueDate.Month())
|
||||
assert.Equal(t, originalDueDate.Day(), taskAfterDelete.NextDueDate.Day())
|
||||
}
|
||||
|
||||
func TestTaskService_GetCategories(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
|
||||
Reference in New Issue
Block a user