Harden API security: input validation, safe auth extraction, new tests, and deploy config
Comprehensive security hardening from audit findings: - Add validation tags to all DTO request structs (max lengths, ranges, enums) - Replace unsafe type assertions with MustGetAuthUser helper across all handlers - Remove query-param token auth from admin middleware (prevents URL token leakage) - Add request validation calls in handlers that were missing c.Validate() - Remove goroutines in handlers (timezone update now synchronous) - Add sanitize middleware and path traversal protection (path_utils) - Stop resetting admin passwords on migration restart - Warn on well-known default SECRET_KEY - Add ~30 new test files covering security regressions, auth safety, repos, and services - Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -63,7 +63,7 @@ func (r *ContractorRepository) FindByUser(userID uint, residenceIDs []uint) ([]m
|
||||
query = query.Where("residence_id IS NULL AND created_by_id = ?", userID)
|
||||
}
|
||||
|
||||
err := query.Order("is_favorite DESC, name ASC").Find(&contractors).Error
|
||||
err := query.Order("is_favorite DESC, name ASC").Limit(500).Find(&contractors).Error
|
||||
return contractors, err
|
||||
}
|
||||
|
||||
@@ -85,18 +85,31 @@ func (r *ContractorRepository) Delete(id uint) error {
|
||||
Update("is_active", false).Error
|
||||
}
|
||||
|
||||
// ToggleFavorite toggles the favorite status of a contractor
|
||||
// ToggleFavorite toggles the favorite status of a contractor atomically.
|
||||
// Uses a single UPDATE with NOT to avoid read-then-write race conditions.
|
||||
// Only toggles active contractors to prevent toggling soft-deleted records.
|
||||
func (r *ContractorRepository) ToggleFavorite(id uint) (bool, error) {
|
||||
var contractor models.Contractor
|
||||
if err := r.db.First(&contractor, id).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
newStatus := !contractor.IsFavorite
|
||||
err := r.db.Model(&models.Contractor{}).
|
||||
Where("id = ?", id).
|
||||
Update("is_favorite", newStatus).Error
|
||||
var newStatus bool
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Atomic toggle: SET is_favorite = NOT is_favorite for active contractors only
|
||||
result := tx.Model(&models.Contractor{}).
|
||||
Where("id = ? AND is_active = ?", id, true).
|
||||
Update("is_favorite", gorm.Expr("NOT is_favorite"))
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
// Read back the new value within the same transaction
|
||||
var contractor models.Contractor
|
||||
if err := tx.Select("is_favorite").First(&contractor, id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
newStatus = contractor.IsFavorite
|
||||
return nil
|
||||
})
|
||||
return newStatus, err
|
||||
}
|
||||
|
||||
@@ -145,6 +158,19 @@ func (r *ContractorRepository) CountByResidence(residenceID uint) (int64, error)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountByResidenceIDs counts all active contractors across multiple residences in a single query.
|
||||
// Returns the total count of active contractors for the given residence IDs.
|
||||
func (r *ContractorRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) {
|
||||
if len(residenceIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var count int64
|
||||
err := r.db.Model(&models.Contractor{}).
|
||||
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// === Specialty Operations ===
|
||||
|
||||
// GetAllSpecialties returns all contractor specialties
|
||||
|
||||
96
internal/repositories/contractor_repo_test.go
Normal file
96
internal/repositories/contractor_repo_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestToggleFavorite_Active_Toggles(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
||||
|
||||
// Initially is_favorite is false
|
||||
assert.False(t, contractor.IsFavorite, "contractor should start as not favorite")
|
||||
|
||||
// First toggle: false -> true
|
||||
newStatus, err := repo.ToggleFavorite(contractor.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, newStatus, "first toggle should set favorite to true")
|
||||
|
||||
// Verify in database
|
||||
var found models.Contractor
|
||||
err = db.First(&found, contractor.ID).Error
|
||||
require.NoError(t, err)
|
||||
assert.True(t, found.IsFavorite, "database should reflect favorite = true")
|
||||
|
||||
// Second toggle: true -> false
|
||||
newStatus, err = repo.ToggleFavorite(contractor.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, newStatus, "second toggle should set favorite to false")
|
||||
|
||||
// Verify in database
|
||||
err = db.First(&found, contractor.ID).Error
|
||||
require.NoError(t, err)
|
||||
assert.False(t, found.IsFavorite, "database should reflect favorite = false")
|
||||
}
|
||||
|
||||
func TestToggleFavorite_SoftDeleted_ReturnsError(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Deleted Contractor")
|
||||
|
||||
// Soft-delete the contractor
|
||||
err := db.Model(&models.Contractor{}).
|
||||
Where("id = ?", contractor.ID).
|
||||
Update("is_active", false).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Toggling a soft-deleted contractor should fail (record not found)
|
||||
_, err = repo.ToggleFavorite(contractor.ID)
|
||||
assert.Error(t, err, "toggling a soft-deleted contractor should return an error")
|
||||
}
|
||||
|
||||
func TestToggleFavorite_NonExistent_ReturnsError(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
_, err := repo.ToggleFavorite(99999)
|
||||
assert.Error(t, err, "toggling a non-existent contractor should return an error")
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindByUser_HasDefaultLimit(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create 510 contractors to exceed the default limit of 500
|
||||
for i := 0; i < 510; i++ {
|
||||
c := &models.Contractor{
|
||||
ResidenceID: &residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Name: fmt.Sprintf("Contractor %d", i+1),
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(c).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
contractors, err := repo.FindByUser(user.ID, []uint{residence.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 500, len(contractors), "FindByUser should return at most 500 contractors by default")
|
||||
}
|
||||
@@ -52,7 +52,8 @@ func (r *DocumentRepository) FindByResidence(residenceID uint) ([]models.Documen
|
||||
return documents, err
|
||||
}
|
||||
|
||||
// FindByUser finds all documents accessible to a user
|
||||
// FindByUser finds all documents accessible to a user.
|
||||
// A default limit of 500 is applied to prevent unbounded result sets.
|
||||
func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document, error) {
|
||||
var documents []models.Document
|
||||
err := r.db.Preload("CreatedBy").
|
||||
@@ -60,6 +61,7 @@ func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document,
|
||||
Preload("Images").
|
||||
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
|
||||
Order("created_at DESC").
|
||||
Limit(500).
|
||||
Find(&documents).Error
|
||||
return documents, err
|
||||
}
|
||||
@@ -89,7 +91,8 @@ func (r *DocumentRepository) FindByUserFiltered(residenceIDs []uint, filter *Doc
|
||||
query = query.Where("expiry_date IS NOT NULL AND expiry_date > ? AND expiry_date <= ?", now, threshold)
|
||||
}
|
||||
if filter.Search != "" {
|
||||
searchPattern := "%" + filter.Search + "%"
|
||||
escaped := escapeLikeWildcards(filter.Search)
|
||||
searchPattern := "%" + escaped + "%"
|
||||
query = query.Where("(title ILIKE ? OR description ILIKE ?)", searchPattern, searchPattern)
|
||||
}
|
||||
}
|
||||
@@ -169,6 +172,19 @@ func (r *DocumentRepository) CountByResidence(residenceID uint) (int64, error) {
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountByResidenceIDs counts all active documents across multiple residences in a single query.
|
||||
// Returns the total count of active documents for the given residence IDs.
|
||||
func (r *DocumentRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) {
|
||||
if len(residenceIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var count int64
|
||||
err := r.db.Model(&models.Document{}).
|
||||
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// FindByIDIncludingInactive finds a document by ID including inactive ones
|
||||
func (r *DocumentRepository) FindByIDIncludingInactive(id uint, document *models.Document) error {
|
||||
return r.db.Preload("CreatedBy").Preload("Images").First(document, id).Error
|
||||
|
||||
38
internal/repositories/document_repo_test.go
Normal file
38
internal/repositories/document_repo_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestDocumentRepository_FindByUser_HasDefaultLimit(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create 510 documents to exceed the default limit of 500
|
||||
for i := 0; i < 510; i++ {
|
||||
doc := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: fmt.Sprintf("Doc %d", i+1),
|
||||
DocumentType: models.DocumentTypeGeneral,
|
||||
FileURL: "https://example.com/doc.pdf",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(doc).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
docs, err := repo.FindByUser([]uint{residence.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 500, len(docs), "FindByUser should return at most 500 documents by default")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -130,18 +131,25 @@ func (r *NotificationRepository) CreatePreferences(prefs *models.NotificationPre
|
||||
|
||||
// UpdatePreferences updates notification preferences
|
||||
func (r *NotificationRepository) UpdatePreferences(prefs *models.NotificationPreference) error {
|
||||
return r.db.Save(prefs).Error
|
||||
return r.db.Omit("User").Save(prefs).Error
|
||||
}
|
||||
|
||||
// GetOrCreatePreferences gets or creates notification preferences for a user
|
||||
// GetOrCreatePreferences gets or creates notification preferences for a user.
|
||||
// Uses a transaction to avoid TOCTOU race conditions on concurrent requests.
|
||||
func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.NotificationPreference, error) {
|
||||
prefs, err := r.FindPreferencesByUser(userID)
|
||||
if err == nil {
|
||||
return prefs, nil
|
||||
}
|
||||
var prefs models.NotificationPreference
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
prefs = &models.NotificationPreference{
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Where("user_id = ?", userID).First(&prefs).Error
|
||||
if err == nil {
|
||||
return nil // Found existing preferences
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err // Unexpected error
|
||||
}
|
||||
|
||||
// Record not found -- create with defaults
|
||||
prefs = models.NotificationPreference{
|
||||
UserID: userID,
|
||||
TaskDueSoon: true,
|
||||
TaskOverdue: true,
|
||||
@@ -151,17 +159,36 @@ func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.No
|
||||
WarrantyExpiring: true,
|
||||
EmailTaskCompleted: true,
|
||||
}
|
||||
if err := r.CreatePreferences(prefs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return prefs, nil
|
||||
return tx.Create(&prefs).Error
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
return &prefs, nil
|
||||
}
|
||||
|
||||
// === Device Registration ===
|
||||
|
||||
// FindAPNSDeviceByID finds an APNS device by ID
|
||||
func (r *NotificationRepository) FindAPNSDeviceByID(id uint) (*models.APNSDevice, error) {
|
||||
var device models.APNSDevice
|
||||
err := r.db.First(&device, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &device, nil
|
||||
}
|
||||
|
||||
// FindGCMDeviceByID finds a GCM device by ID
|
||||
func (r *NotificationRepository) FindGCMDeviceByID(id uint) (*models.GCMDevice, error) {
|
||||
var device models.GCMDevice
|
||||
err := r.db.First(&device, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &device, nil
|
||||
}
|
||||
|
||||
// FindAPNSDeviceByToken finds an APNS device by registration token
|
||||
func (r *NotificationRepository) FindAPNSDeviceByToken(token string) (*models.APNSDevice, error) {
|
||||
var device models.APNSDevice
|
||||
@@ -243,12 +270,12 @@ func (r *NotificationRepository) DeactivateGCMDevice(id uint) error {
|
||||
// GetActiveTokensForUser gets all active push tokens for a user
|
||||
func (r *NotificationRepository) GetActiveTokensForUser(userID uint) (iosTokens []string, androidTokens []string, err error) {
|
||||
apnsDevices, err := r.FindAPNSDevicesByUser(userID)
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
gcmDevices, err := r.FindGCMDevicesByUser(userID)
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
||||
96
internal/repositories/notification_repo_test.go
Normal file
96
internal/repositories/notification_repo_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestGetOrCreatePreferences_New_Creates(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// No preferences exist yet for this user
|
||||
prefs, err := repo.GetOrCreatePreferences(user.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, prefs)
|
||||
|
||||
// Verify defaults were set
|
||||
assert.Equal(t, user.ID, prefs.UserID)
|
||||
assert.True(t, prefs.TaskDueSoon)
|
||||
assert.True(t, prefs.TaskOverdue)
|
||||
assert.True(t, prefs.TaskCompleted)
|
||||
assert.True(t, prefs.TaskAssigned)
|
||||
assert.True(t, prefs.ResidenceShared)
|
||||
assert.True(t, prefs.WarrantyExpiring)
|
||||
assert.True(t, prefs.EmailTaskCompleted)
|
||||
|
||||
// Verify it was actually persisted
|
||||
var count int64
|
||||
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should have exactly one preferences record")
|
||||
}
|
||||
|
||||
func TestGetOrCreatePreferences_AlreadyExists_Returns(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// Create preferences manually first
|
||||
existingPrefs := &models.NotificationPreference{
|
||||
UserID: user.ID,
|
||||
TaskDueSoon: true,
|
||||
TaskOverdue: true,
|
||||
TaskCompleted: true,
|
||||
TaskAssigned: true,
|
||||
ResidenceShared: true,
|
||||
WarrantyExpiring: true,
|
||||
EmailTaskCompleted: true,
|
||||
}
|
||||
err := db.Create(existingPrefs).Error
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, existingPrefs.ID)
|
||||
|
||||
// GetOrCreatePreferences should return the existing record, not create a new one
|
||||
prefs, err := repo.GetOrCreatePreferences(user.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, prefs)
|
||||
|
||||
// The returned record should have the same ID as the existing one
|
||||
assert.Equal(t, existingPrefs.ID, prefs.ID, "should return the existing record by ID")
|
||||
assert.Equal(t, user.ID, prefs.UserID, "should have correct user_id")
|
||||
|
||||
// Verify still only one record exists (no duplicate created)
|
||||
var count int64
|
||||
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should still have exactly one preferences record")
|
||||
}
|
||||
|
||||
func TestGetOrCreatePreferences_Idempotent(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// Call twice in succession
|
||||
prefs1, err := repo.GetOrCreatePreferences(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
prefs2, err := repo.GetOrCreatePreferences(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both should return the same record
|
||||
assert.Equal(t, prefs1.ID, prefs2.ID)
|
||||
|
||||
// Should only have one record
|
||||
var count int64
|
||||
db.Model(&models.NotificationPreference{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should have exactly one preferences record after two calls")
|
||||
}
|
||||
@@ -37,6 +37,84 @@ func (r *ReminderRepository) HasSentReminder(taskID, userID uint, dueDate time.T
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ReminderKey uniquely identifies a reminder that may have been sent.
|
||||
type ReminderKey struct {
|
||||
TaskID uint
|
||||
UserID uint
|
||||
DueDate time.Time
|
||||
Stage models.ReminderStage
|
||||
}
|
||||
|
||||
// HasSentReminderBatch checks which reminders from the given list have already been sent.
|
||||
// Returns a set of indices into the input slice that have already been sent.
|
||||
// This replaces N individual HasSentReminder calls with a single query.
|
||||
func (r *ReminderRepository) HasSentReminderBatch(keys []ReminderKey) (map[int]bool, error) {
|
||||
result := make(map[int]bool)
|
||||
if len(keys) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Build a lookup from (task_id, user_id, due_date, stage) -> index
|
||||
type normalizedKey struct {
|
||||
TaskID uint
|
||||
UserID uint
|
||||
DueDate string
|
||||
Stage models.ReminderStage
|
||||
}
|
||||
keyToIdx := make(map[normalizedKey][]int, len(keys))
|
||||
|
||||
// Collect unique task IDs and user IDs for the WHERE clause
|
||||
taskIDSet := make(map[uint]bool)
|
||||
userIDSet := make(map[uint]bool)
|
||||
for i, k := range keys {
|
||||
taskIDSet[k.TaskID] = true
|
||||
userIDSet[k.UserID] = true
|
||||
dueDateOnly := time.Date(k.DueDate.Year(), k.DueDate.Month(), k.DueDate.Day(), 0, 0, 0, 0, time.UTC)
|
||||
nk := normalizedKey{
|
||||
TaskID: k.TaskID,
|
||||
UserID: k.UserID,
|
||||
DueDate: dueDateOnly.Format("2006-01-02"),
|
||||
Stage: k.Stage,
|
||||
}
|
||||
keyToIdx[nk] = append(keyToIdx[nk], i)
|
||||
}
|
||||
|
||||
taskIDs := make([]uint, 0, len(taskIDSet))
|
||||
for id := range taskIDSet {
|
||||
taskIDs = append(taskIDs, id)
|
||||
}
|
||||
userIDs := make([]uint, 0, len(userIDSet))
|
||||
for id := range userIDSet {
|
||||
userIDs = append(userIDs, id)
|
||||
}
|
||||
|
||||
// Query all matching reminder logs in one query
|
||||
var logs []models.TaskReminderLog
|
||||
err := r.db.Where("task_id IN ? AND user_id IN ?", taskIDs, userIDs).
|
||||
Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Match returned logs against our key set
|
||||
for _, l := range logs {
|
||||
dueDateStr := l.DueDate.Format("2006-01-02")
|
||||
nk := normalizedKey{
|
||||
TaskID: l.TaskID,
|
||||
UserID: l.UserID,
|
||||
DueDate: dueDateStr,
|
||||
Stage: l.ReminderStage,
|
||||
}
|
||||
if indices, ok := keyToIdx[nk]; ok {
|
||||
for _, idx := range indices {
|
||||
result[idx] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// LogReminder records that a reminder was sent.
|
||||
// Returns the created log entry or an error if the reminder was already sent
|
||||
// (unique constraint violation).
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
@@ -269,7 +270,9 @@ func (r *ResidenceRepository) GetActiveShareCode(residenceID uint) (*models.Resi
|
||||
// Check if expired
|
||||
if shareCode.ExpiresAt != nil && time.Now().UTC().After(*shareCode.ExpiresAt) {
|
||||
// Auto-deactivate expired code
|
||||
r.DeactivateShareCode(shareCode.ID)
|
||||
if err := r.DeactivateShareCode(shareCode.ID); err != nil {
|
||||
log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate expired share code")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -296,9 +299,11 @@ func (r *ResidenceRepository) generateUniqueCode() (string, error) {
|
||||
|
||||
// Check if code already exists
|
||||
var count int64
|
||||
r.db.Model(&models.ResidenceShareCode{}).
|
||||
if err := r.db.Model(&models.ResidenceShareCode{}).
|
||||
Where("code = ? AND is_active = ?", codeStr, true).
|
||||
Count(&count)
|
||||
Count(&count).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return codeStr, nil
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
)
|
||||
@@ -30,31 +32,37 @@ func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscrip
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// GetOrCreate gets or creates a subscription for a user (defaults to free tier)
|
||||
// GetOrCreate gets or creates a subscription for a user (defaults to free tier).
|
||||
// Uses a transaction to avoid TOCTOU race conditions on concurrent requests.
|
||||
func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) {
|
||||
sub, err := r.FindByUserID(userID)
|
||||
if err == nil {
|
||||
return sub, nil
|
||||
}
|
||||
var sub models.UserSubscription
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
sub = &models.UserSubscription{
|
||||
UserID: userID,
|
||||
Tier: models.TierFree,
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Where("user_id = ?", userID).First(&sub).Error
|
||||
if err == nil {
|
||||
return nil // Found existing subscription
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err // Unexpected error
|
||||
}
|
||||
|
||||
// Record not found -- create with free tier defaults
|
||||
sub = models.UserSubscription{
|
||||
UserID: userID,
|
||||
Tier: models.TierFree,
|
||||
AutoRenew: true,
|
||||
}
|
||||
if err := r.db.Create(sub).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sub, nil
|
||||
return tx.Create(&sub).Error
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// Update updates a subscription
|
||||
func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error {
|
||||
return r.db.Save(sub).Error
|
||||
return r.db.Omit("User").Save(sub).Error
|
||||
}
|
||||
|
||||
// UpgradeToPro upgrades a user to Pro tier using a transaction with row locking
|
||||
@@ -63,7 +71,7 @@ func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time,
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Lock the row for update
|
||||
var sub models.UserSubscription
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
Where("user_id = ?", userID).First(&sub).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -86,7 +94,7 @@ func (r *SubscriptionRepository) DowngradeToFree(userID uint) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Lock the row for update
|
||||
var sub models.UserSubscription
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
Where("user_id = ?", userID).First(&sub).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -165,7 +173,7 @@ func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*m
|
||||
var limits models.TierLimits
|
||||
err := r.db.Where("tier = ?", tier).First(&limits).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Return defaults
|
||||
if tier == models.TierFree {
|
||||
defaults := models.GetDefaultFreeLimits()
|
||||
@@ -193,7 +201,7 @@ func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, er
|
||||
var settings models.SubscriptionSettings
|
||||
err := r.db.First(&settings).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Return default settings (limitations disabled)
|
||||
return &models.SubscriptionSettings{
|
||||
EnableLimitations: false,
|
||||
|
||||
79
internal/repositories/subscription_repo_test.go
Normal file
79
internal/repositories/subscription_repo_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestGetOrCreate_New_CreatesFreeTier(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
sub, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sub)
|
||||
|
||||
assert.Equal(t, user.ID, sub.UserID)
|
||||
assert.Equal(t, models.TierFree, sub.Tier)
|
||||
assert.True(t, sub.AutoRenew)
|
||||
|
||||
// Verify persisted
|
||||
var count int64
|
||||
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should have exactly one subscription record")
|
||||
}
|
||||
|
||||
func TestGetOrCreate_AlreadyExists_Returns(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// Create a pro subscription manually
|
||||
existing := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierPro,
|
||||
AutoRenew: true,
|
||||
}
|
||||
err := db.Create(existing).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// GetOrCreate should return existing, not overwrite with free defaults
|
||||
sub, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sub)
|
||||
|
||||
assert.Equal(t, existing.ID, sub.ID, "should return the existing record by ID")
|
||||
assert.Equal(t, models.TierPro, sub.Tier, "should preserve existing pro tier, not overwrite with free")
|
||||
|
||||
// Verify still only one record
|
||||
var count int64
|
||||
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should still have exactly one subscription record")
|
||||
}
|
||||
|
||||
func TestGetOrCreate_Idempotent(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
sub1, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
sub2, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, sub1.ID, sub2.ID)
|
||||
|
||||
var count int64
|
||||
db.Model(&models.UserSubscription{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "should have exactly one subscription record after two calls")
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
@@ -25,6 +26,50 @@ func NewTaskRepository(db *gorm.DB) *TaskRepository {
|
||||
return &TaskRepository{db: db}
|
||||
}
|
||||
|
||||
// DB returns the underlying database connection.
|
||||
// Used by services that need to run transactions spanning multiple operations.
|
||||
func (r *TaskRepository) DB() *gorm.DB {
|
||||
return r.db
|
||||
}
|
||||
|
||||
// CreateCompletionTx creates a new task completion within an existing transaction.
|
||||
func (r *TaskRepository) CreateCompletionTx(tx *gorm.DB, completion *models.TaskCompletion) error {
|
||||
return tx.Create(completion).Error
|
||||
}
|
||||
|
||||
// UpdateTx updates a task with optimistic locking within an existing transaction.
|
||||
func (r *TaskRepository) UpdateTx(tx *gorm.DB, task *models.Task) error {
|
||||
result := tx.Model(task).
|
||||
Where("id = ? AND version = ?", task.ID, task.Version).
|
||||
Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions").
|
||||
Updates(map[string]interface{}{
|
||||
"title": task.Title,
|
||||
"description": task.Description,
|
||||
"category_id": task.CategoryID,
|
||||
"priority_id": task.PriorityID,
|
||||
"frequency_id": task.FrequencyID,
|
||||
"custom_interval_days": task.CustomIntervalDays,
|
||||
"in_progress": task.InProgress,
|
||||
"assigned_to_id": task.AssignedToID,
|
||||
"due_date": task.DueDate,
|
||||
"next_due_date": task.NextDueDate,
|
||||
"estimated_cost": task.EstimatedCost,
|
||||
"actual_cost": task.ActualCost,
|
||||
"contractor_id": task.ContractorID,
|
||||
"is_cancelled": task.IsCancelled,
|
||||
"is_archived": task.IsArchived,
|
||||
"version": gorm.Expr("version + 1"),
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return ErrVersionConflict
|
||||
}
|
||||
task.Version++ // Update local copy
|
||||
return nil
|
||||
}
|
||||
|
||||
// === Task Filter Options ===
|
||||
|
||||
// TaskFilterOptions provides flexible filtering for task queries.
|
||||
@@ -495,55 +540,39 @@ func buildKanbanColumns(
|
||||
}
|
||||
|
||||
// GetKanbanData retrieves tasks organized for kanban display.
|
||||
// Uses single-purpose query functions for each column type, ensuring consistency
|
||||
// with notification handlers that use the same functions.
|
||||
// Fetches all non-cancelled, non-archived tasks for the residence in a single query,
|
||||
// then categorizes them in-memory using the task categorization chain for consistency
|
||||
// with the predicate-based logic used throughout the application.
|
||||
// The `now` parameter should be the start of day in the user's timezone for accurate overdue detection.
|
||||
//
|
||||
// Optimization: Preloads only minimal completion data (id, task_id, completed_at) for count/detection.
|
||||
// Optimization: Single query with preloads, then in-memory categorization.
|
||||
// Images and CompletedBy are NOT preloaded - fetch separately when viewing completion details.
|
||||
func (r *TaskRepository) GetKanbanData(residenceID uint, daysThreshold int, now time.Time) (*models.KanbanBoard, error) {
|
||||
opts := TaskFilterOptions{
|
||||
ResidenceID: residenceID,
|
||||
PreloadCreatedBy: true,
|
||||
PreloadAssignedTo: true,
|
||||
PreloadCompletions: true,
|
||||
// Fetch all tasks for this residence in a single query (excluding cancelled/archived)
|
||||
var allTasks []models.Task
|
||||
query := r.db.Model(&models.Task{}).
|
||||
Where("task_task.residence_id = ?", residenceID).
|
||||
Preload("CreatedBy").
|
||||
Preload("AssignedTo").
|
||||
Preload("Completions", func(db *gorm.DB) *gorm.DB {
|
||||
return db.Select("id", "task_id", "completed_at")
|
||||
}).
|
||||
Scopes(task.ScopeKanbanOrder)
|
||||
|
||||
if err := query.Find(&allTasks).Error; err != nil {
|
||||
return nil, fmt.Errorf("get tasks for kanban: %w", err)
|
||||
}
|
||||
|
||||
// Query each column using single-purpose functions
|
||||
// These functions use the same scopes as notification handlers for consistency
|
||||
overdue, err := r.GetOverdueTasks(now, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get overdue tasks: %w", err)
|
||||
}
|
||||
// Categorize all tasks in-memory using the categorization chain
|
||||
columnMap := categorization.CategorizeTasksIntoColumnsWithTime(allTasks, daysThreshold, now)
|
||||
|
||||
inProgress, err := r.GetInProgressTasks(opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get in-progress tasks: %w", err)
|
||||
}
|
||||
|
||||
dueSoon, err := r.GetDueSoonTasks(now, daysThreshold, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get due-soon tasks: %w", err)
|
||||
}
|
||||
|
||||
upcoming, err := r.GetUpcomingTasks(now, daysThreshold, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get upcoming tasks: %w", err)
|
||||
}
|
||||
|
||||
completed, err := r.GetCompletedTasks(opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get completed tasks: %w", err)
|
||||
}
|
||||
|
||||
// Intentionally hidden from board:
|
||||
// cancelled/archived tasks are not returned as a kanban column.
|
||||
// cancelled, err := r.GetCancelledTasks(opts)
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("get cancelled tasks: %w", err)
|
||||
// }
|
||||
|
||||
columns := buildKanbanColumns(overdue, inProgress, dueSoon, upcoming, completed)
|
||||
columns := buildKanbanColumns(
|
||||
columnMap[categorization.ColumnOverdue],
|
||||
columnMap[categorization.ColumnInProgress],
|
||||
columnMap[categorization.ColumnDueSoon],
|
||||
columnMap[categorization.ColumnUpcoming],
|
||||
columnMap[categorization.ColumnCompleted],
|
||||
)
|
||||
|
||||
return &models.KanbanBoard{
|
||||
Columns: columns,
|
||||
@@ -553,56 +582,39 @@ func (r *TaskRepository) GetKanbanData(residenceID uint, daysThreshold int, now
|
||||
}
|
||||
|
||||
// GetKanbanDataForMultipleResidences retrieves tasks from multiple residences organized for kanban display.
|
||||
// Uses single-purpose query functions for each column type, ensuring consistency
|
||||
// with notification handlers that use the same functions.
|
||||
// Fetches all tasks in a single query, then categorizes them in-memory using the
|
||||
// task categorization chain for consistency with predicate-based logic.
|
||||
// The `now` parameter should be the start of day in the user's timezone for accurate overdue detection.
|
||||
//
|
||||
// Optimization: Preloads only minimal completion data (id, task_id, completed_at) for count/detection.
|
||||
// Optimization: Single query with preloads, then in-memory categorization.
|
||||
// Images and CompletedBy are NOT preloaded - fetch separately when viewing completion details.
|
||||
func (r *TaskRepository) GetKanbanDataForMultipleResidences(residenceIDs []uint, daysThreshold int, now time.Time) (*models.KanbanBoard, error) {
|
||||
opts := TaskFilterOptions{
|
||||
ResidenceIDs: residenceIDs,
|
||||
PreloadCreatedBy: true,
|
||||
PreloadAssignedTo: true,
|
||||
PreloadResidence: true,
|
||||
PreloadCompletions: true,
|
||||
// Fetch all tasks for these residences in a single query (excluding cancelled/archived)
|
||||
var allTasks []models.Task
|
||||
query := r.db.Model(&models.Task{}).
|
||||
Where("task_task.residence_id IN ?", residenceIDs).
|
||||
Preload("CreatedBy").
|
||||
Preload("AssignedTo").
|
||||
Preload("Residence").
|
||||
Preload("Completions", func(db *gorm.DB) *gorm.DB {
|
||||
return db.Select("id", "task_id", "completed_at")
|
||||
}).
|
||||
Scopes(task.ScopeKanbanOrder)
|
||||
|
||||
if err := query.Find(&allTasks).Error; err != nil {
|
||||
return nil, fmt.Errorf("get tasks for kanban: %w", err)
|
||||
}
|
||||
|
||||
// Query each column using single-purpose functions
|
||||
// These functions use the same scopes as notification handlers for consistency
|
||||
overdue, err := r.GetOverdueTasks(now, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get overdue tasks: %w", err)
|
||||
}
|
||||
// Categorize all tasks in-memory using the categorization chain
|
||||
columnMap := categorization.CategorizeTasksIntoColumnsWithTime(allTasks, daysThreshold, now)
|
||||
|
||||
inProgress, err := r.GetInProgressTasks(opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get in-progress tasks: %w", err)
|
||||
}
|
||||
|
||||
dueSoon, err := r.GetDueSoonTasks(now, daysThreshold, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get due-soon tasks: %w", err)
|
||||
}
|
||||
|
||||
upcoming, err := r.GetUpcomingTasks(now, daysThreshold, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get upcoming tasks: %w", err)
|
||||
}
|
||||
|
||||
completed, err := r.GetCompletedTasks(opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get completed tasks: %w", err)
|
||||
}
|
||||
|
||||
// Intentionally hidden from board:
|
||||
// cancelled/archived tasks are not returned as a kanban column.
|
||||
// cancelled, err := r.GetCancelledTasks(opts)
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("get cancelled tasks: %w", err)
|
||||
// }
|
||||
|
||||
columns := buildKanbanColumns(overdue, inProgress, dueSoon, upcoming, completed)
|
||||
columns := buildKanbanColumns(
|
||||
columnMap[categorization.ColumnOverdue],
|
||||
columnMap[categorization.ColumnInProgress],
|
||||
columnMap[categorization.ColumnDueSoon],
|
||||
columnMap[categorization.ColumnUpcoming],
|
||||
columnMap[categorization.ColumnCompleted],
|
||||
)
|
||||
|
||||
return &models.KanbanBoard{
|
||||
Columns: columns,
|
||||
@@ -653,6 +665,19 @@ func (r *TaskRepository) CountByResidence(residenceID uint) (int64, error) {
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountByResidenceIDs counts all active tasks across multiple residences in a single query.
|
||||
// Returns the total count of non-cancelled, non-archived tasks for the given residence IDs.
|
||||
func (r *TaskRepository) CountByResidenceIDs(residenceIDs []uint) (int64, error) {
|
||||
if len(residenceIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var count int64
|
||||
err := r.db.Model(&models.Task{}).
|
||||
Where("residence_id IN ? AND is_cancelled = ? AND is_archived = ?", residenceIDs, false, false).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// === Task Completion Operations ===
|
||||
|
||||
// CreateCompletion creates a new task completion
|
||||
@@ -705,7 +730,9 @@ func (r *TaskRepository) UpdateCompletion(completion *models.TaskCompletion) err
|
||||
// DeleteCompletion deletes a task completion
|
||||
func (r *TaskRepository) DeleteCompletion(id uint) error {
|
||||
// Delete images first
|
||||
r.db.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{})
|
||||
if err := r.db.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{}).Error; err != nil {
|
||||
log.Error().Err(err).Uint("completion_id", id).Msg("Failed to delete completion images")
|
||||
}
|
||||
return r.db.Delete(&models.TaskCompletion{}, id).Error
|
||||
}
|
||||
|
||||
|
||||
@@ -2097,3 +2097,170 @@ func TestConsistency_OverduePredicateVsScopeVsRepo(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, expectedCount, len(repoTasks), "Overdue task count mismatch")
|
||||
}
|
||||
|
||||
// TestGetKanbanData_CategorizesCorrectly verifies the single-query kanban approach
|
||||
// produces correct column assignments for various task states.
|
||||
func TestGetKanbanData_CategorizesCorrectly(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
now := time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC)
|
||||
yesterday := now.AddDate(0, 0, -1)
|
||||
tomorrow := now.AddDate(0, 0, 1)
|
||||
nextMonth := now.AddDate(0, 1, 0)
|
||||
|
||||
// Create overdue task (due yesterday)
|
||||
overdueTask := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Overdue Task",
|
||||
DueDate: &yesterday,
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(overdueTask).Error)
|
||||
|
||||
// Create due-soon task (due tomorrow, within 30-day threshold)
|
||||
dueSoonTask := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Due Soon Task",
|
||||
DueDate: &tomorrow,
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(dueSoonTask).Error)
|
||||
|
||||
// Create upcoming task (due next month, outside 30-day threshold)
|
||||
upcomingTask := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Upcoming Task",
|
||||
DueDate: &nextMonth,
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(upcomingTask).Error)
|
||||
|
||||
// Create in-progress task
|
||||
inProgressTask := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "In Progress Task",
|
||||
DueDate: &tomorrow,
|
||||
InProgress: true,
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(inProgressTask).Error)
|
||||
|
||||
// Create completed task (no next due date, has completion)
|
||||
completedTask := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Completed Task",
|
||||
DueDate: &yesterday,
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(completedTask).Error)
|
||||
completion := &models.TaskCompletion{
|
||||
TaskID: completedTask.ID,
|
||||
CompletedByID: user.ID,
|
||||
CompletedAt: now,
|
||||
}
|
||||
require.NoError(t, db.Create(completion).Error)
|
||||
|
||||
// Create cancelled task (should NOT appear in kanban columns)
|
||||
cancelledTask := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Cancelled Task",
|
||||
DueDate: &yesterday,
|
||||
IsCancelled: true,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(cancelledTask).Error)
|
||||
|
||||
// Create archived task (should NOT appear in active kanban columns)
|
||||
archivedTask := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Archived Task",
|
||||
DueDate: &yesterday,
|
||||
IsCancelled: false,
|
||||
IsArchived: true,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(archivedTask).Error)
|
||||
|
||||
// Create no-due-date task (should go to upcoming)
|
||||
noDueDateTask := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "No Due Date Task",
|
||||
IsCancelled: false,
|
||||
IsArchived: false,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(noDueDateTask).Error)
|
||||
|
||||
// Execute kanban data retrieval
|
||||
board, err := repo.GetKanbanData(residence.ID, 30, now)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, board)
|
||||
require.Len(t, board.Columns, 5, "Should have 5 visible columns")
|
||||
|
||||
// Build a map of column name -> task titles for easy assertion
|
||||
columnTasks := make(map[string][]string)
|
||||
for _, col := range board.Columns {
|
||||
var titles []string
|
||||
for _, task := range col.Tasks {
|
||||
titles = append(titles, task.Title)
|
||||
}
|
||||
columnTasks[col.Name] = titles
|
||||
}
|
||||
|
||||
// Verify overdue column
|
||||
assert.Contains(t, columnTasks["overdue_tasks"], "Overdue Task",
|
||||
"Overdue task should be in overdue column")
|
||||
|
||||
// Verify in-progress column
|
||||
assert.Contains(t, columnTasks["in_progress_tasks"], "In Progress Task",
|
||||
"In-progress task should be in in-progress column")
|
||||
|
||||
// Verify due-soon column
|
||||
assert.Contains(t, columnTasks["due_soon_tasks"], "Due Soon Task",
|
||||
"Due-soon task should be in due-soon column")
|
||||
|
||||
// Verify upcoming column contains both upcoming and no-due-date tasks
|
||||
assert.Contains(t, columnTasks["upcoming_tasks"], "No Due Date Task",
|
||||
"No-due-date task should be in upcoming column")
|
||||
|
||||
// Verify completed column
|
||||
assert.Contains(t, columnTasks["completed_tasks"], "Completed Task",
|
||||
"Completed task should be in completed column")
|
||||
|
||||
// Verify cancelled and archived tasks are categorized to the cancelled column
|
||||
// (which is present in categorization but hidden from visible kanban columns)
|
||||
// The cancelled/archived tasks should NOT appear in any of the 5 visible columns
|
||||
allVisibleTitles := make(map[string]bool)
|
||||
for _, col := range board.Columns {
|
||||
for _, task := range col.Tasks {
|
||||
allVisibleTitles[task.Title] = true
|
||||
}
|
||||
}
|
||||
assert.False(t, allVisibleTitles["Cancelled Task"],
|
||||
"Cancelled task should not appear in visible kanban columns")
|
||||
assert.False(t, allVisibleTitles["Archived Task"],
|
||||
"Archived task should not appear in visible kanban columns")
|
||||
}
|
||||
|
||||
@@ -45,7 +45,8 @@ func (r *TaskTemplateRepository) GetByCategory(categoryID uint) ([]models.TaskTe
|
||||
// Search searches templates by title and tags
|
||||
func (r *TaskTemplateRepository) Search(query string) ([]models.TaskTemplate, error) {
|
||||
var templates []models.TaskTemplate
|
||||
searchTerm := "%" + strings.ToLower(query) + "%"
|
||||
escaped := escapeLikeWildcards(strings.ToLower(query))
|
||||
searchTerm := "%" + escaped + "%"
|
||||
|
||||
err := r.db.
|
||||
Preload("Category").
|
||||
@@ -77,7 +78,7 @@ func (r *TaskTemplateRepository) Create(template *models.TaskTemplate) error {
|
||||
|
||||
// Update updates an existing task template
|
||||
func (r *TaskTemplateRepository) Update(template *models.TaskTemplate) error {
|
||||
return r.db.Save(template).Error
|
||||
return r.db.Omit("Category", "Frequency").Save(template).Error
|
||||
}
|
||||
|
||||
// Delete hard deletes a task template
|
||||
|
||||
11
internal/repositories/util.go
Normal file
11
internal/repositories/util.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package repositories
|
||||
|
||||
import "strings"
|
||||
|
||||
// escapeLikeWildcards escapes SQL LIKE wildcard characters in user input
|
||||
// to prevent users from injecting wildcards like % or _ into search queries.
|
||||
func escapeLikeWildcards(s string) string {
|
||||
s = strings.ReplaceAll(s, "%", "\\%")
|
||||
s = strings.ReplaceAll(s, "_", "\\_")
|
||||
return s
|
||||
}
|
||||
Reference in New Issue
Block a user