Harden API security: input validation, safe auth extraction, new tests, and deploy config

Comprehensive security hardening from audit findings:
- Add validation tags to all DTO request structs (max lengths, ranges, enums)
- Replace unsafe type assertions with MustGetAuthUser helper across all handlers
- Remove query-param token auth from admin middleware (prevents URL token leakage)
- Add request validation calls in handlers that were missing c.Validate()
- Remove goroutines in handlers (timezone update now synchronous)
- Add sanitize middleware and path traversal protection (path_utils)
- Stop resetting admin passwords on migration restart
- Warn on well-known default SECRET_KEY
- Add ~30 new test files covering security regressions, auth safety, repos, and services
- Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-03-02 09:48:01 -06:00
parent 56d6fa4514
commit 7690f07a2b
123 changed files with 8321 additions and 750 deletions

View File

@@ -63,7 +63,7 @@ func (r *ContractorRepository) FindByUser(userID uint, residenceIDs []uint) ([]m
query = query.Where("residence_id IS NULL AND created_by_id = ?", userID)
}
err := query.Order("is_favorite DESC, name ASC").Find(&contractors).Error
err := query.Order("is_favorite DESC, name ASC").Limit(500).Find(&contractors).Error
return contractors, err
}
@@ -85,18 +85,31 @@ func (r *ContractorRepository) Delete(id uint) error {
Update("is_active", false).Error
}
// ToggleFavorite toggles the favorite status of a contractor
// ToggleFavorite toggles the favorite status of a contractor atomically.
// Uses a single UPDATE with NOT to avoid read-then-write race conditions.
// Only toggles active contractors to prevent toggling soft-deleted records.
func (r *ContractorRepository) ToggleFavorite(id uint) (bool, error) {
var contractor models.Contractor
if err := r.db.First(&contractor, id).Error; err != nil {
return false, err
}
newStatus := !contractor.IsFavorite
err := r.db.Model(&models.Contractor{}).
Where("id = ?", id).
Update("is_favorite", newStatus).Error
var newStatus bool
err := r.db.Transaction(func(tx *gorm.DB) error {
// Atomic toggle: SET is_favorite = NOT is_favorite for active contractors only
result := tx.Model(&models.Contractor{}).
Where("id = ? AND is_active = ?", id, true).
Update("is_favorite", gorm.Expr("NOT is_favorite"))
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
// Read back the new value within the same transaction
var contractor models.Contractor
if err := tx.Select("is_favorite").First(&contractor, id).Error; err != nil {
return err
}
newStatus = contractor.IsFavorite
return nil
})
return newStatus, err
}
@@ -145,6 +158,19 @@ func (r *ContractorRepository) CountByResidence(residenceID uint) (int64, error)
return count, err
}
// CountByResidenceIDs counts all active contractors across multiple residences in a single query.
// Returns the total count of active contractors for the given residence IDs.
func (r *ContractorRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) {
if len(residenceIDs) == 0 {
return 0, nil
}
var count int64
err := r.db.Model(&models.Contractor{}).
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
Count(&count).Error
return count, err
}
// === Specialty Operations ===
// GetAllSpecialties returns all contractor specialties

View File

@@ -0,0 +1,96 @@
package repositories
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/testutil"
)
func TestToggleFavorite_Active_Toggles(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
// Initially is_favorite is false
assert.False(t, contractor.IsFavorite, "contractor should start as not favorite")
// First toggle: false -> true
newStatus, err := repo.ToggleFavorite(contractor.ID)
require.NoError(t, err)
assert.True(t, newStatus, "first toggle should set favorite to true")
// Verify in database
var found models.Contractor
err = db.First(&found, contractor.ID).Error
require.NoError(t, err)
assert.True(t, found.IsFavorite, "database should reflect favorite = true")
// Second toggle: true -> false
newStatus, err = repo.ToggleFavorite(contractor.ID)
require.NoError(t, err)
assert.False(t, newStatus, "second toggle should set favorite to false")
// Verify in database
err = db.First(&found, contractor.ID).Error
require.NoError(t, err)
assert.False(t, found.IsFavorite, "database should reflect favorite = false")
}
func TestToggleFavorite_SoftDeleted_ReturnsError(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Deleted Contractor")
// Soft-delete the contractor
err := db.Model(&models.Contractor{}).
Where("id = ?", contractor.ID).
Update("is_active", false).Error
require.NoError(t, err)
// Toggling a soft-deleted contractor should fail (record not found)
_, err = repo.ToggleFavorite(contractor.ID)
assert.Error(t, err, "toggling a soft-deleted contractor should return an error")
}
func TestToggleFavorite_NonExistent_ReturnsError(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
_, err := repo.ToggleFavorite(99999)
assert.Error(t, err, "toggling a non-existent contractor should return an error")
}
func TestContractorRepository_FindByUser_HasDefaultLimit(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewContractorRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create 510 contractors to exceed the default limit of 500
for i := 0; i < 510; i++ {
c := &models.Contractor{
ResidenceID: &residence.ID,
CreatedByID: user.ID,
Name: fmt.Sprintf("Contractor %d", i+1),
IsActive: true,
}
err := db.Create(c).Error
require.NoError(t, err)
}
contractors, err := repo.FindByUser(user.ID, []uint{residence.ID})
require.NoError(t, err)
assert.Equal(t, 500, len(contractors), "FindByUser should return at most 500 contractors by default")
}

View File

@@ -52,7 +52,8 @@ func (r *DocumentRepository) FindByResidence(residenceID uint) ([]models.Documen
return documents, err
}
// FindByUser finds all documents accessible to a user
// FindByUser finds all documents accessible to a user.
// A default limit of 500 is applied to prevent unbounded result sets.
func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document, error) {
var documents []models.Document
err := r.db.Preload("CreatedBy").
@@ -60,6 +61,7 @@ func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document,
Preload("Images").
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
Order("created_at DESC").
Limit(500).
Find(&documents).Error
return documents, err
}
@@ -89,7 +91,8 @@ func (r *DocumentRepository) FindByUserFiltered(residenceIDs []uint, filter *Doc
query = query.Where("expiry_date IS NOT NULL AND expiry_date > ? AND expiry_date <= ?", now, threshold)
}
if filter.Search != "" {
searchPattern := "%" + filter.Search + "%"
escaped := escapeLikeWildcards(filter.Search)
searchPattern := "%" + escaped + "%"
query = query.Where("(title ILIKE ? OR description ILIKE ?)", searchPattern, searchPattern)
}
}
@@ -169,6 +172,19 @@ func (r *DocumentRepository) CountByResidence(residenceID uint) (int64, error) {
return count, err
}
// CountByResidenceIDs counts all active documents across multiple residences in a single query.
// Returns the total count of active documents for the given residence IDs.
func (r *DocumentRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) {
if len(residenceIDs) == 0 {
return 0, nil
}
var count int64
err := r.db.Model(&models.Document{}).
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
Count(&count).Error
return count, err
}
// FindByIDIncludingInactive finds a document by ID including inactive ones
func (r *DocumentRepository) FindByIDIncludingInactive(id uint, document *models.Document) error {
return r.db.Preload("CreatedBy").Preload("Images").First(document, id).Error

View File

@@ -0,0 +1,38 @@
package repositories
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/testutil"
)
func TestDocumentRepository_FindByUser_HasDefaultLimit(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewDocumentRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create 510 documents to exceed the default limit of 500
for i := 0; i < 510; i++ {
doc := &models.Document{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: fmt.Sprintf("Doc %d", i+1),
DocumentType: models.DocumentTypeGeneral,
FileURL: "https://example.com/doc.pdf",
IsActive: true,
}
err := db.Create(doc).Error
require.NoError(t, err)
}
docs, err := repo.FindByUser([]uint{residence.ID})
require.NoError(t, err)
assert.Equal(t, 500, len(docs), "FindByUser should return at most 500 documents by default")
}

View File

@@ -1,6 +1,7 @@
package repositories
import (
"errors"
"time"
"gorm.io/gorm"
@@ -130,18 +131,25 @@ func (r *NotificationRepository) CreatePreferences(prefs *models.NotificationPre
// UpdatePreferences updates notification preferences
func (r *NotificationRepository) UpdatePreferences(prefs *models.NotificationPreference) error {
return r.db.Save(prefs).Error
return r.db.Omit("User").Save(prefs).Error
}
// GetOrCreatePreferences gets or creates notification preferences for a user
// GetOrCreatePreferences gets or creates notification preferences for a user.
// Uses a transaction to avoid TOCTOU race conditions on concurrent requests.
func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.NotificationPreference, error) {
prefs, err := r.FindPreferencesByUser(userID)
if err == nil {
return prefs, nil
}
var prefs models.NotificationPreference
if err == gorm.ErrRecordNotFound {
prefs = &models.NotificationPreference{
err := r.db.Transaction(func(tx *gorm.DB) error {
err := tx.Where("user_id = ?", userID).First(&prefs).Error
if err == nil {
return nil // Found existing preferences
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err // Unexpected error
}
// Record not found -- create with defaults
prefs = models.NotificationPreference{
UserID: userID,
TaskDueSoon: true,
TaskOverdue: true,
@@ -151,17 +159,36 @@ func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.No
WarrantyExpiring: true,
EmailTaskCompleted: true,
}
if err := r.CreatePreferences(prefs); err != nil {
return nil, err
}
return prefs, nil
return tx.Create(&prefs).Error
})
if err != nil {
return nil, err
}
return nil, err
return &prefs, nil
}
// === Device Registration ===
// FindAPNSDeviceByID finds an APNS device by ID
func (r *NotificationRepository) FindAPNSDeviceByID(id uint) (*models.APNSDevice, error) {
var device models.APNSDevice
err := r.db.First(&device, id).Error
if err != nil {
return nil, err
}
return &device, nil
}
// FindGCMDeviceByID finds a GCM device by ID
func (r *NotificationRepository) FindGCMDeviceByID(id uint) (*models.GCMDevice, error) {
var device models.GCMDevice
err := r.db.First(&device, id).Error
if err != nil {
return nil, err
}
return &device, nil
}
// FindAPNSDeviceByToken finds an APNS device by registration token
func (r *NotificationRepository) FindAPNSDeviceByToken(token string) (*models.APNSDevice, error) {
var device models.APNSDevice
@@ -243,12 +270,12 @@ func (r *NotificationRepository) DeactivateGCMDevice(id uint) error {
// GetActiveTokensForUser gets all active push tokens for a user
func (r *NotificationRepository) GetActiveTokensForUser(userID uint) (iosTokens []string, androidTokens []string, err error) {
apnsDevices, err := r.FindAPNSDevicesByUser(userID)
if err != nil && err != gorm.ErrRecordNotFound {
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, err
}
gcmDevices, err := r.FindGCMDevicesByUser(userID)
if err != nil && err != gorm.ErrRecordNotFound {
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, err
}

View File

@@ -0,0 +1,96 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/testutil"
)
func TestGetOrCreatePreferences_New_Creates(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// No preferences exist yet for this user
prefs, err := repo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
require.NotNil(t, prefs)
// Verify defaults were set
assert.Equal(t, user.ID, prefs.UserID)
assert.True(t, prefs.TaskDueSoon)
assert.True(t, prefs.TaskOverdue)
assert.True(t, prefs.TaskCompleted)
assert.True(t, prefs.TaskAssigned)
assert.True(t, prefs.ResidenceShared)
assert.True(t, prefs.WarrantyExpiring)
assert.True(t, prefs.EmailTaskCompleted)
// Verify it was actually persisted
var count int64
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should have exactly one preferences record")
}
func TestGetOrCreatePreferences_AlreadyExists_Returns(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// Create preferences manually first
existingPrefs := &models.NotificationPreference{
UserID: user.ID,
TaskDueSoon: true,
TaskOverdue: true,
TaskCompleted: true,
TaskAssigned: true,
ResidenceShared: true,
WarrantyExpiring: true,
EmailTaskCompleted: true,
}
err := db.Create(existingPrefs).Error
require.NoError(t, err)
require.NotZero(t, existingPrefs.ID)
// GetOrCreatePreferences should return the existing record, not create a new one
prefs, err := repo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
require.NotNil(t, prefs)
// The returned record should have the same ID as the existing one
assert.Equal(t, existingPrefs.ID, prefs.ID, "should return the existing record by ID")
assert.Equal(t, user.ID, prefs.UserID, "should have correct user_id")
// Verify still only one record exists (no duplicate created)
var count int64
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should still have exactly one preferences record")
}
func TestGetOrCreatePreferences_Idempotent(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewNotificationRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// Call twice in succession
prefs1, err := repo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
prefs2, err := repo.GetOrCreatePreferences(user.ID)
require.NoError(t, err)
// Both should return the same record
assert.Equal(t, prefs1.ID, prefs2.ID)
// Should only have one record
var count int64
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should have exactly one preferences record after two calls")
}

View File

@@ -37,6 +37,84 @@ func (r *ReminderRepository) HasSentReminder(taskID, userID uint, dueDate time.T
return count > 0, nil
}
// ReminderKey uniquely identifies a reminder that may have been sent.
type ReminderKey struct {
TaskID uint
UserID uint
DueDate time.Time
Stage models.ReminderStage
}
// HasSentReminderBatch checks which reminders from the given list have already been sent.
// Returns a set of indices into the input slice that have already been sent.
// This replaces N individual HasSentReminder calls with a single query.
func (r *ReminderRepository) HasSentReminderBatch(keys []ReminderKey) (map[int]bool, error) {
result := make(map[int]bool)
if len(keys) == 0 {
return result, nil
}
// Build a lookup from (task_id, user_id, due_date, stage) -> index
type normalizedKey struct {
TaskID uint
UserID uint
DueDate string
Stage models.ReminderStage
}
keyToIdx := make(map[normalizedKey][]int, len(keys))
// Collect unique task IDs and user IDs for the WHERE clause
taskIDSet := make(map[uint]bool)
userIDSet := make(map[uint]bool)
for i, k := range keys {
taskIDSet[k.TaskID] = true
userIDSet[k.UserID] = true
dueDateOnly := time.Date(k.DueDate.Year(), k.DueDate.Month(), k.DueDate.Day(), 0, 0, 0, 0, time.UTC)
nk := normalizedKey{
TaskID: k.TaskID,
UserID: k.UserID,
DueDate: dueDateOnly.Format("2006-01-02"),
Stage: k.Stage,
}
keyToIdx[nk] = append(keyToIdx[nk], i)
}
taskIDs := make([]uint, 0, len(taskIDSet))
for id := range taskIDSet {
taskIDs = append(taskIDs, id)
}
userIDs := make([]uint, 0, len(userIDSet))
for id := range userIDSet {
userIDs = append(userIDs, id)
}
// Query all matching reminder logs in one query
var logs []models.TaskReminderLog
err := r.db.Where("task_id IN ? AND user_id IN ?", taskIDs, userIDs).
Find(&logs).Error
if err != nil {
return nil, err
}
// Match returned logs against our key set
for _, l := range logs {
dueDateStr := l.DueDate.Format("2006-01-02")
nk := normalizedKey{
TaskID: l.TaskID,
UserID: l.UserID,
DueDate: dueDateStr,
Stage: l.ReminderStage,
}
if indices, ok := keyToIdx[nk]; ok {
for _, idx := range indices {
result[idx] = true
}
}
}
return result, nil
}
// LogReminder records that a reminder was sent.
// Returns the created log entry or an error if the reminder was already sent
// (unique constraint violation).

View File

@@ -6,6 +6,7 @@ import (
"math/big"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/models"
@@ -269,7 +270,9 @@ func (r *ResidenceRepository) GetActiveShareCode(residenceID uint) (*models.Resi
// Check if expired
if shareCode.ExpiresAt != nil && time.Now().UTC().After(*shareCode.ExpiresAt) {
// Auto-deactivate expired code
r.DeactivateShareCode(shareCode.ID)
if err := r.DeactivateShareCode(shareCode.ID); err != nil {
log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate expired share code")
}
return nil, nil
}
@@ -296,9 +299,11 @@ func (r *ResidenceRepository) generateUniqueCode() (string, error) {
// Check if code already exists
var count int64
r.db.Model(&models.ResidenceShareCode{}).
if err := r.db.Model(&models.ResidenceShareCode{}).
Where("code = ? AND is_active = ?", codeStr, true).
Count(&count)
Count(&count).Error; err != nil {
return "", err
}
if count == 0 {
return codeStr, nil

View File

@@ -1,9 +1,11 @@
package repositories
import (
"errors"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/treytartt/casera-api/internal/models"
)
@@ -30,31 +32,37 @@ func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscrip
return &sub, nil
}
// GetOrCreate gets or creates a subscription for a user (defaults to free tier)
// GetOrCreate gets or creates a subscription for a user (defaults to free tier).
// Uses a transaction to avoid TOCTOU race conditions on concurrent requests.
func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) {
sub, err := r.FindByUserID(userID)
if err == nil {
return sub, nil
}
var sub models.UserSubscription
if err == gorm.ErrRecordNotFound {
sub = &models.UserSubscription{
UserID: userID,
Tier: models.TierFree,
err := r.db.Transaction(func(tx *gorm.DB) error {
err := tx.Where("user_id = ?", userID).First(&sub).Error
if err == nil {
return nil // Found existing subscription
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err // Unexpected error
}
// Record not found -- create with free tier defaults
sub = models.UserSubscription{
UserID: userID,
Tier: models.TierFree,
AutoRenew: true,
}
if err := r.db.Create(sub).Error; err != nil {
return nil, err
}
return sub, nil
return tx.Create(&sub).Error
})
if err != nil {
return nil, err
}
return nil, err
return &sub, nil
}
// Update updates a subscription
func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error {
return r.db.Save(sub).Error
return r.db.Omit("User").Save(sub).Error
}
// UpgradeToPro upgrades a user to Pro tier using a transaction with row locking
@@ -63,7 +71,7 @@ func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time,
return r.db.Transaction(func(tx *gorm.DB) error {
// Lock the row for update
var sub models.UserSubscription
if err := tx.Set("gorm:query_option", "FOR UPDATE").
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Where("user_id = ?", userID).First(&sub).Error; err != nil {
return err
}
@@ -86,7 +94,7 @@ func (r *SubscriptionRepository) DowngradeToFree(userID uint) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// Lock the row for update
var sub models.UserSubscription
if err := tx.Set("gorm:query_option", "FOR UPDATE").
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Where("user_id = ?", userID).First(&sub).Error; err != nil {
return err
}
@@ -165,7 +173,7 @@ func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*m
var limits models.TierLimits
err := r.db.Where("tier = ?", tier).First(&limits).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Return defaults
if tier == models.TierFree {
defaults := models.GetDefaultFreeLimits()
@@ -193,7 +201,7 @@ func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, er
var settings models.SubscriptionSettings
err := r.db.First(&settings).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Return default settings (limitations disabled)
return &models.SubscriptionSettings{
EnableLimitations: false,

View File

@@ -0,0 +1,79 @@
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/testutil"
)
func TestGetOrCreate_New_CreatesFreeTier(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
require.NotNil(t, sub)
assert.Equal(t, user.ID, sub.UserID)
assert.Equal(t, models.TierFree, sub.Tier)
assert.True(t, sub.AutoRenew)
// Verify persisted
var count int64
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should have exactly one subscription record")
}
func TestGetOrCreate_AlreadyExists_Returns(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// Create a pro subscription manually
existing := &models.UserSubscription{
UserID: user.ID,
Tier: models.TierPro,
AutoRenew: true,
}
err := db.Create(existing).Error
require.NoError(t, err)
// GetOrCreate should return existing, not overwrite with free defaults
sub, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
require.NotNil(t, sub)
assert.Equal(t, existing.ID, sub.ID, "should return the existing record by ID")
assert.Equal(t, models.TierPro, sub.Tier, "should preserve existing pro tier, not overwrite with free")
// Verify still only one record
var count int64
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should still have exactly one subscription record")
}
func TestGetOrCreate_Idempotent(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewSubscriptionRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
sub1, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
sub2, err := repo.GetOrCreate(user.ID)
require.NoError(t, err)
assert.Equal(t, sub1.ID, sub2.ID)
var count int64
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(1), count, "should have exactly one subscription record after two calls")
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/models"
@@ -25,6 +26,50 @@ func NewTaskRepository(db *gorm.DB) *TaskRepository {
return &TaskRepository{db: db}
}
// DB returns the underlying database connection.
// Used by services that need to run transactions spanning multiple operations.
func (r *TaskRepository) DB() *gorm.DB {
return r.db
}
// CreateCompletionTx creates a new task completion within an existing transaction.
func (r *TaskRepository) CreateCompletionTx(tx *gorm.DB, completion *models.TaskCompletion) error {
return tx.Create(completion).Error
}
// UpdateTx updates a task with optimistic locking within an existing transaction.
func (r *TaskRepository) UpdateTx(tx *gorm.DB, task *models.Task) error {
result := tx.Model(task).
Where("id = ? AND version = ?", task.ID, task.Version).
Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions").
Updates(map[string]interface{}{
"title": task.Title,
"description": task.Description,
"category_id": task.CategoryID,
"priority_id": task.PriorityID,
"frequency_id": task.FrequencyID,
"custom_interval_days": task.CustomIntervalDays,
"in_progress": task.InProgress,
"assigned_to_id": task.AssignedToID,
"due_date": task.DueDate,
"next_due_date": task.NextDueDate,
"estimated_cost": task.EstimatedCost,
"actual_cost": task.ActualCost,
"contractor_id": task.ContractorID,
"is_cancelled": task.IsCancelled,
"is_archived": task.IsArchived,
"version": gorm.Expr("version + 1"),
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrVersionConflict
}
task.Version++ // Update local copy
return nil
}
// === Task Filter Options ===
// TaskFilterOptions provides flexible filtering for task queries.
@@ -495,55 +540,39 @@ func buildKanbanColumns(
}
// GetKanbanData retrieves tasks organized for kanban display.
// Uses single-purpose query functions for each column type, ensuring consistency
// with notification handlers that use the same functions.
// Fetches all non-cancelled, non-archived tasks for the residence in a single query,
// then categorizes them in-memory using the task categorization chain for consistency
// with the predicate-based logic used throughout the application.
// The `now` parameter should be the start of day in the user's timezone for accurate overdue detection.
//
// Optimization: Preloads only minimal completion data (id, task_id, completed_at) for count/detection.
// Optimization: Single query with preloads, then in-memory categorization.
// Images and CompletedBy are NOT preloaded - fetch separately when viewing completion details.
func (r *TaskRepository) GetKanbanData(residenceID uint, daysThreshold int, now time.Time) (*models.KanbanBoard, error) {
opts := TaskFilterOptions{
ResidenceID: residenceID,
PreloadCreatedBy: true,
PreloadAssignedTo: true,
PreloadCompletions: true,
// Fetch all tasks for this residence in a single query (excluding cancelled/archived)
var allTasks []models.Task
query := r.db.Model(&models.Task{}).
Where("task_task.residence_id = ?", residenceID).
Preload("CreatedBy").
Preload("AssignedTo").
Preload("Completions", func(db *gorm.DB) *gorm.DB {
return db.Select("id", "task_id", "completed_at")
}).
Scopes(task.ScopeKanbanOrder)
if err := query.Find(&allTasks).Error; err != nil {
return nil, fmt.Errorf("get tasks for kanban: %w", err)
}
// Query each column using single-purpose functions
// These functions use the same scopes as notification handlers for consistency
overdue, err := r.GetOverdueTasks(now, opts)
if err != nil {
return nil, fmt.Errorf("get overdue tasks: %w", err)
}
// Categorize all tasks in-memory using the categorization chain
columnMap := categorization.CategorizeTasksIntoColumnsWithTime(allTasks, daysThreshold, now)
inProgress, err := r.GetInProgressTasks(opts)
if err != nil {
return nil, fmt.Errorf("get in-progress tasks: %w", err)
}
dueSoon, err := r.GetDueSoonTasks(now, daysThreshold, opts)
if err != nil {
return nil, fmt.Errorf("get due-soon tasks: %w", err)
}
upcoming, err := r.GetUpcomingTasks(now, daysThreshold, opts)
if err != nil {
return nil, fmt.Errorf("get upcoming tasks: %w", err)
}
completed, err := r.GetCompletedTasks(opts)
if err != nil {
return nil, fmt.Errorf("get completed tasks: %w", err)
}
// Intentionally hidden from board:
// cancelled/archived tasks are not returned as a kanban column.
// cancelled, err := r.GetCancelledTasks(opts)
// if err != nil {
// return nil, fmt.Errorf("get cancelled tasks: %w", err)
// }
columns := buildKanbanColumns(overdue, inProgress, dueSoon, upcoming, completed)
columns := buildKanbanColumns(
columnMap[categorization.ColumnOverdue],
columnMap[categorization.ColumnInProgress],
columnMap[categorization.ColumnDueSoon],
columnMap[categorization.ColumnUpcoming],
columnMap[categorization.ColumnCompleted],
)
return &models.KanbanBoard{
Columns: columns,
@@ -553,56 +582,39 @@ func (r *TaskRepository) GetKanbanData(residenceID uint, daysThreshold int, now
}
// GetKanbanDataForMultipleResidences retrieves tasks from multiple residences organized for kanban display.
// Uses single-purpose query functions for each column type, ensuring consistency
// with notification handlers that use the same functions.
// Fetches all tasks in a single query, then categorizes them in-memory using the
// task categorization chain for consistency with predicate-based logic.
// The `now` parameter should be the start of day in the user's timezone for accurate overdue detection.
//
// Optimization: Preloads only minimal completion data (id, task_id, completed_at) for count/detection.
// Optimization: Single query with preloads, then in-memory categorization.
// Images and CompletedBy are NOT preloaded - fetch separately when viewing completion details.
func (r *TaskRepository) GetKanbanDataForMultipleResidences(residenceIDs []uint, daysThreshold int, now time.Time) (*models.KanbanBoard, error) {
opts := TaskFilterOptions{
ResidenceIDs: residenceIDs,
PreloadCreatedBy: true,
PreloadAssignedTo: true,
PreloadResidence: true,
PreloadCompletions: true,
// Fetch all tasks for these residences in a single query (excluding cancelled/archived)
var allTasks []models.Task
query := r.db.Model(&models.Task{}).
Where("task_task.residence_id IN ?", residenceIDs).
Preload("CreatedBy").
Preload("AssignedTo").
Preload("Residence").
Preload("Completions", func(db *gorm.DB) *gorm.DB {
return db.Select("id", "task_id", "completed_at")
}).
Scopes(task.ScopeKanbanOrder)
if err := query.Find(&allTasks).Error; err != nil {
return nil, fmt.Errorf("get tasks for kanban: %w", err)
}
// Query each column using single-purpose functions
// These functions use the same scopes as notification handlers for consistency
overdue, err := r.GetOverdueTasks(now, opts)
if err != nil {
return nil, fmt.Errorf("get overdue tasks: %w", err)
}
// Categorize all tasks in-memory using the categorization chain
columnMap := categorization.CategorizeTasksIntoColumnsWithTime(allTasks, daysThreshold, now)
inProgress, err := r.GetInProgressTasks(opts)
if err != nil {
return nil, fmt.Errorf("get in-progress tasks: %w", err)
}
dueSoon, err := r.GetDueSoonTasks(now, daysThreshold, opts)
if err != nil {
return nil, fmt.Errorf("get due-soon tasks: %w", err)
}
upcoming, err := r.GetUpcomingTasks(now, daysThreshold, opts)
if err != nil {
return nil, fmt.Errorf("get upcoming tasks: %w", err)
}
completed, err := r.GetCompletedTasks(opts)
if err != nil {
return nil, fmt.Errorf("get completed tasks: %w", err)
}
// Intentionally hidden from board:
// cancelled/archived tasks are not returned as a kanban column.
// cancelled, err := r.GetCancelledTasks(opts)
// if err != nil {
// return nil, fmt.Errorf("get cancelled tasks: %w", err)
// }
columns := buildKanbanColumns(overdue, inProgress, dueSoon, upcoming, completed)
columns := buildKanbanColumns(
columnMap[categorization.ColumnOverdue],
columnMap[categorization.ColumnInProgress],
columnMap[categorization.ColumnDueSoon],
columnMap[categorization.ColumnUpcoming],
columnMap[categorization.ColumnCompleted],
)
return &models.KanbanBoard{
Columns: columns,
@@ -653,6 +665,19 @@ func (r *TaskRepository) CountByResidence(residenceID uint) (int64, error) {
return count, err
}
// CountByResidenceIDs counts all active tasks across multiple residences in a single query.
// Returns the total count of non-cancelled, non-archived tasks for the given residence IDs.
func (r *TaskRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) {
if len(residenceIDs) == 0 {
return 0, nil
}
var count int64
err := r.db.Model(&models.Task{}).
Where("residence_id IN ? AND is_cancelled = ? AND is_archived = ?", residenceIDs, false, false).
Count(&count).Error
return count, err
}
// === Task Completion Operations ===
// CreateCompletion creates a new task completion
@@ -705,7 +730,9 @@ func (r *TaskRepository) UpdateCompletion(completion *models.TaskCompletion) err
// DeleteCompletion deletes a task completion
func (r *TaskRepository) DeleteCompletion(id uint) error {
// Delete images first
r.db.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{})
if err := r.db.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{}).Error; err != nil {
log.Error().Err(err).Uint("completion_id", id).Msg("Failed to delete completion images")
}
return r.db.Delete(&models.TaskCompletion{}, id).Error
}

View File

@@ -2097,3 +2097,170 @@ func TestConsistency_OverduePredicateVsScopeVsRepo(t *testing.T) {
}
assert.Equal(t, expectedCount, len(repoTasks), "Overdue task count mismatch")
}
// TestGetKanbanData_CategorizesCorrectly verifies the single-query kanban approach
// produces correct column assignments for various task states.
func TestGetKanbanData_CategorizesCorrectly(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewTaskRepository(db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
now := time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC)
yesterday := now.AddDate(0, 0, -1)
tomorrow := now.AddDate(0, 0, 1)
nextMonth := now.AddDate(0, 1, 0)
// Create overdue task (due yesterday)
overdueTask := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Overdue Task",
DueDate: &yesterday,
IsCancelled: false,
IsArchived: false,
Version: 1,
}
require.NoError(t, db.Create(overdueTask).Error)
// Create due-soon task (due tomorrow, within 30-day threshold)
dueSoonTask := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Due Soon Task",
DueDate: &tomorrow,
IsCancelled: false,
IsArchived: false,
Version: 1,
}
require.NoError(t, db.Create(dueSoonTask).Error)
// Create upcoming task (due next month, outside 30-day threshold)
upcomingTask := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Upcoming Task",
DueDate: &nextMonth,
IsCancelled: false,
IsArchived: false,
Version: 1,
}
require.NoError(t, db.Create(upcomingTask).Error)
// Create in-progress task
inProgressTask := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "In Progress Task",
DueDate: &tomorrow,
InProgress: true,
IsCancelled: false,
IsArchived: false,
Version: 1,
}
require.NoError(t, db.Create(inProgressTask).Error)
// Create completed task (no next due date, has completion)
completedTask := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Completed Task",
DueDate: &yesterday,
IsCancelled: false,
IsArchived: false,
Version: 1,
}
require.NoError(t, db.Create(completedTask).Error)
completion := &models.TaskCompletion{
TaskID: completedTask.ID,
CompletedByID: user.ID,
CompletedAt: now,
}
require.NoError(t, db.Create(completion).Error)
// Create cancelled task (should NOT appear in kanban columns)
cancelledTask := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Cancelled Task",
DueDate: &yesterday,
IsCancelled: true,
IsArchived: false,
Version: 1,
}
require.NoError(t, db.Create(cancelledTask).Error)
// Create archived task (should NOT appear in active kanban columns)
archivedTask := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "Archived Task",
DueDate: &yesterday,
IsCancelled: false,
IsArchived: true,
Version: 1,
}
require.NoError(t, db.Create(archivedTask).Error)
// Create no-due-date task (should go to upcoming)
noDueDateTask := &models.Task{
ResidenceID: residence.ID,
CreatedByID: user.ID,
Title: "No Due Date Task",
IsCancelled: false,
IsArchived: false,
Version: 1,
}
require.NoError(t, db.Create(noDueDateTask).Error)
// Execute kanban data retrieval
board, err := repo.GetKanbanData(residence.ID, 30, now)
require.NoError(t, err)
require.NotNil(t, board)
require.Len(t, board.Columns, 5, "Should have 5 visible columns")
// Build a map of column name -> task titles for easy assertion
columnTasks := make(map[string][]string)
for _, col := range board.Columns {
var titles []string
for _, task := range col.Tasks {
titles = append(titles, task.Title)
}
columnTasks[col.Name] = titles
}
// Verify overdue column
assert.Contains(t, columnTasks["overdue_tasks"], "Overdue Task",
"Overdue task should be in overdue column")
// Verify in-progress column
assert.Contains(t, columnTasks["in_progress_tasks"], "In Progress Task",
"In-progress task should be in in-progress column")
// Verify due-soon column
assert.Contains(t, columnTasks["due_soon_tasks"], "Due Soon Task",
"Due-soon task should be in due-soon column")
// Verify upcoming column contains both upcoming and no-due-date tasks
assert.Contains(t, columnTasks["upcoming_tasks"], "No Due Date Task",
"No-due-date task should be in upcoming column")
// Verify completed column
assert.Contains(t, columnTasks["completed_tasks"], "Completed Task",
"Completed task should be in completed column")
// Verify cancelled and archived tasks are categorized to the cancelled column
// (which is present in categorization but hidden from visible kanban columns)
// The cancelled/archived tasks should NOT appear in any of the 5 visible columns
allVisibleTitles := make(map[string]bool)
for _, col := range board.Columns {
for _, task := range col.Tasks {
allVisibleTitles[task.Title] = true
}
}
assert.False(t, allVisibleTitles["Cancelled Task"],
"Cancelled task should not appear in visible kanban columns")
assert.False(t, allVisibleTitles["Archived Task"],
"Archived task should not appear in visible kanban columns")
}

View File

@@ -45,7 +45,8 @@ func (r *TaskTemplateRepository) GetByCategory(categoryID uint) ([]models.TaskTe
// Search searches templates by title and tags
func (r *TaskTemplateRepository) Search(query string) ([]models.TaskTemplate, error) {
var templates []models.TaskTemplate
searchTerm := "%" + strings.ToLower(query) + "%"
escaped := escapeLikeWildcards(strings.ToLower(query))
searchTerm := "%" + escaped + "%"
err := r.db.
Preload("Category").
@@ -77,7 +78,7 @@ func (r *TaskTemplateRepository) Create(template *models.TaskTemplate) error {
// Update updates an existing task template
func (r *TaskTemplateRepository) Update(template *models.TaskTemplate) error {
return r.db.Save(template).Error
return r.db.Omit("Category", "Frequency").Save(template).Error
}
// Delete hard deletes a task template

View File

@@ -0,0 +1,11 @@
package repositories
import "strings"
// escapeLikeWildcards escapes SQL LIKE wildcard characters in user input
// to prevent users from injecting wildcards like % or _ into search queries.
func escapeLikeWildcards(s string) string {
s = strings.ReplaceAll(s, "%", "\\%")
s = strings.ReplaceAll(s, "_", "\\_")
return s
}