Fix 113 hardening issues across entire Go backend

Security:
- Replace all binding: tags with validate: + c.Validate() in admin handlers
- Add rate limiting to auth endpoints (login, register, password reset)
- Add security headers (HSTS, XSS protection, nosniff, frame options)
- Wire Google Pub/Sub token verification into webhook handler
- Replace ParseUnverified with proper OIDC/JWKS key verification
- Verify inner Apple JWS signatures in webhook handler
- Add io.LimitReader (1MB) to all webhook body reads
- Add ownership verification to file deletion
- Move hardcoded admin credentials to env vars
- Add uniqueIndex to User.Email
- Hide ConfirmationCode from JSON serialization
- Mask confirmation codes in admin responses
- Use http.DetectContentType for upload validation
- Fix path traversal in storage service
- Replace os.Getenv with Viper in stripe service
- Sanitize Redis URLs before logging
- Separate DEBUG_FIXED_CODES from DEBUG flag
- Reject weak SECRET_KEY in production
- Add host check on /_next/* proxy routes
- Use explicit localhost CORS origins in debug mode
- Replace err.Error() with generic messages in all admin error responses

Critical fixes:
- Rewrite FCM to HTTP v1 API with OAuth 2.0 service account auth
- Fix user_customuser -> auth_user table names in raw SQL
- Fix dashboard verified query to use UserProfile model
- Add escapeLikeWildcards() to prevent SQL wildcard injection

Bug fixes:
- Add bounds checks for days/expiring_soon query params (1-3650)
- Add receipt_data/transaction_id empty-check to RestoreSubscription
- Change Active bool -> *bool in device handler
- Check all unchecked GORM/FindByIDWithProfile errors
- Add validation for notification hour fields (0-23)
- Add max=10000 validation on task description updates

Transactions & data integrity:
- Wrap registration flow in transaction
- Wrap QuickComplete in transaction
- Move image creation inside completion transaction
- Wrap SetSpecialties in transaction
- Wrap GetOrCreateToken in transaction
- Wrap completion+image deletion in transaction

Performance:
- Batch completion summaries (2 queries vs 2N)
- Reuse single http.Client in IAP validation
- Cache dashboard counts (30s TTL)
- Batch COUNT queries in admin user list
- Add Limit(500) to document queries
- Add reminder_stage+due_date filters to reminder queries
- Parse AllowedTypes once at init
- In-memory user cache in auth middleware (30s TTL)
- Timezone change detection cache
- Optimize P95 with per-endpoint sorted buffers
- Replace crypto/md5 with hash/fnv for ETags

Code quality:
- Add sync.Once to all monitoring Stop()/Close() methods
- Replace 8 fmt.Printf with zerolog in auth service
- Log previously discarded errors
- Standardize delete response shapes
- Route hardcoded English through i18n
- Remove FileURL from DocumentResponse (keep MediaURL only)
- Thread user timezone through kanban board responses
- Initialize empty slices to prevent null JSON
- Extract shared field map for task Update/UpdateTx
- Delete unused SoftDeleteModel, min(), formatCron, legacy handlers

Worker & jobs:
- Wire Asynq email infrastructure into worker
- Register HandleReminderLogCleanup with daily 3AM cron
- Use per-user timezone in HandleSmartReminder
- Replace direct DB queries with repository calls
- Delete legacy reminder handlers (~200 lines)
- Delete unused task type constants

Dependencies:
- Replace archived jung-kurt/gofpdf with go-pdf/fpdf
- Replace unmaintained gomail.v2 with wneessen/go-mail
- Add TODO for Echo jwt v3 transitive dep removal

Test infrastructure:
- Fix MakeRequest/SeedLookupData error handling
- Replace os.Exit(0) with t.Skip() in scope/consistency tests
- Add 11 new FCM v1 tests

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-03-18 23:14:13 -05:00
parent 3b86d0aae1
commit 42a5533a56
95 changed files with 2892 additions and 1783 deletions

View File

@@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
"github.com/treytartt/honeydue-api/internal/apperrors"
@@ -90,7 +91,7 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
// Update last login
if err := s.userRepo.UpdateLastLogin(user.ID); err != nil {
// Log error but don't fail the login
fmt.Printf("Failed to update last login: %v\n", err)
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login")
}
return &responses.LoginResponse{
@@ -99,7 +100,9 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
}, nil
}
// Register creates a new user account
// Register creates a new user account.
// F-10: User creation, profile creation, notification preferences, and confirmation code
// are wrapped in a transaction for atomicity.
func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
// Check if username exists
exists, err := s.userRepo.ExistsByUsername(req.Username)
@@ -133,43 +136,49 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
return nil, "", apperrors.Internal(err)
}
// Save user
if err := s.userRepo.Create(user); err != nil {
return nil, "", apperrors.Internal(err)
}
// Create user profile
if _, err := s.userRepo.GetOrCreateProfile(user.ID); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create user profile: %v\n", err)
}
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create notification preferences: %v\n", err)
}
}
// Create auth token
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, "", apperrors.Internal(err)
}
// Generate confirmation code - use fixed code in debug mode for easier local testing
// Generate confirmation code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
var code string
if s.cfg.Server.Debug {
if s.cfg.Server.DebugFixedCodes {
code = "123456"
} else {
code = generateSixDigitCode()
}
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
if _, err := s.userRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create confirmation code: %v\n", err)
// Wrap user creation + profile + notification preferences + confirmation code in a transaction
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error {
// Save user
if err := txRepo.Create(user); err != nil {
return err
}
// Create user profile
if _, err := txRepo.GetOrCreateProfile(user.ID); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create user profile during registration")
}
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration")
}
}
// Create confirmation code
if _, err := txRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create confirmation code during registration")
}
return nil
})
if txErr != nil {
return nil, "", apperrors.Internal(txErr)
}
// Create auth token (outside transaction since token generation is idempotent)
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, "", apperrors.Internal(err)
}
return &responses.RegisterResponse{
@@ -248,8 +257,8 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
return apperrors.BadRequest("error.email_already_verified")
}
// Check for test code in debug mode
if s.cfg.Server.Debug && code == "123456" {
// Check for test code when DEBUG_FIXED_CODES is enabled
if s.cfg.Server.DebugFixedCodes && code == "123456" {
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
return apperrors.Internal(err)
}
@@ -294,9 +303,9 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
return "", apperrors.BadRequest("error.email_already_verified")
}
// Generate new code - use fixed code in debug mode for easier local testing
// Generate new code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
var code string
if s.cfg.Server.Debug {
if s.cfg.Server.DebugFixedCodes {
code = "123456"
} else {
code = generateSixDigitCode()
@@ -331,9 +340,9 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
return "", nil, apperrors.TooManyRequests("error.rate_limit_exceeded")
}
// Generate code and reset token - use fixed code in debug mode for easier local testing
// Generate code and reset token - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
var code string
if s.cfg.Server.Debug {
if s.cfg.Server.DebugFixedCodes {
code = "123456"
} else {
code = generateSixDigitCode()
@@ -365,8 +374,8 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
return "", apperrors.Internal(err)
}
// Check for test code in debug mode
if s.cfg.Server.Debug && code == "123456" {
// Check for test code when DEBUG_FIXED_CODES is enabled
if s.cfg.Server.DebugFixedCodes && code == "123456" {
return resetCode.ResetToken, nil
}
@@ -422,13 +431,13 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
// Mark reset code as used
if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
// Log error but don't fail
fmt.Printf("Failed to mark reset code as used: %v\n", err)
log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used")
}
// Invalidate all existing tokens for this user (security measure)
if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil {
// Log error but don't fail
fmt.Printf("Failed to delete user tokens: %v\n", err)
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset")
}
return nil
@@ -482,6 +491,13 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
if email != "" {
existingUser, err := s.userRepo.FindByEmail(email)
if err == nil && existingUser != nil {
// S-06: Log auto-linking of social account to existing user
log.Warn().
Str("email", email).
Str("provider", "apple").
Uint("user_id", existingUser.ID).
Msg("Auto-linking social account to existing user by email match")
// Link Apple ID to existing account
appleAuthRecord := &models.AppleSocialAuth{
UserID: existingUser.ID,
@@ -505,8 +521,11 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
// Update last login
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
// Reload user with profile
existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID)
// B-08: Check error from FindByIDWithProfile
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
return &responses.AppleSignInResponse{
Token: token.Key,
@@ -544,8 +563,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create notification preferences: %v\n", err)
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user")
}
}
@@ -566,8 +584,11 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
return nil, apperrors.Internal(err)
}
// Reload user with profile
user, _ = s.userRepo.FindByIDWithProfile(user.ID)
// B-08: Check error from FindByIDWithProfile
user, err = s.userRepo.FindByIDWithProfile(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
return &responses.AppleSignInResponse{
Token: token.Key,
@@ -623,6 +644,13 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
if email != "" {
existingUser, err := s.userRepo.FindByEmail(email)
if err == nil && existingUser != nil {
// S-06: Log auto-linking of social account to existing user
log.Warn().
Str("email", email).
Str("provider", "google").
Uint("user_id", existingUser.ID).
Msg("Auto-linking social account to existing user by email match")
// Link Google ID to existing account
googleAuthRecord := &models.GoogleSocialAuth{
UserID: existingUser.ID,
@@ -649,8 +677,11 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
// Update last login
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
// Reload user with profile
existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID)
// B-08: Check error from FindByIDWithProfile
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
return &responses.GoogleSignInResponse{
Token: token.Key,
@@ -688,8 +719,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create notification preferences: %v\n", err)
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user")
}
}
@@ -711,8 +741,11 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
return nil, apperrors.Internal(err)
}
// Reload user with profile
user, _ = s.userRepo.FindByIDWithProfile(user.ID)
// B-08: Check error from FindByIDWithProfile
user, err = s.userRepo.FindByIDWithProfile(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
return &responses.GoogleSignInResponse{
Token: token.Key,

View File

@@ -2,9 +2,10 @@ package services
import (
"context"
"crypto/md5"
"encoding/json"
"fmt"
"hash/fnv"
"sync"
"time"
"github.com/redis/go-redis/v9"
@@ -18,38 +19,55 @@ type CacheService struct {
client *redis.Client
}
var cacheInstance *CacheService
var (
cacheInstance *CacheService
cacheOnce sync.Once
)
// NewCacheService creates a new cache service
// NewCacheService creates a new cache service (thread-safe via sync.Once)
func NewCacheService(cfg *config.RedisConfig) (*CacheService, error) {
opt, err := redis.ParseURL(cfg.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse Redis URL: %w", err)
var initErr error
cacheOnce.Do(func() {
opt, err := redis.ParseURL(cfg.URL)
if err != nil {
initErr = fmt.Errorf("failed to parse Redis URL: %w", err)
return
}
if cfg.Password != "" {
opt.Password = cfg.Password
}
if cfg.DB != 0 {
opt.DB = cfg.DB
}
client := redis.NewClient(opt)
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
initErr = fmt.Errorf("failed to connect to Redis: %w", err)
// Reset Once so a retry is possible after transient failures
cacheOnce = sync.Once{}
return
}
// S-14: Mask credentials in Redis URL before logging
log.Info().
Str("url", config.MaskURLCredentials(cfg.URL)).
Int("db", opt.DB).
Msg("Connected to Redis")
cacheInstance = &CacheService{client: client}
})
if initErr != nil {
return nil, initErr
}
if cfg.Password != "" {
opt.Password = cfg.Password
}
if cfg.DB != 0 {
opt.DB = cfg.DB
}
client := redis.NewClient(opt)
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
log.Info().
Str("url", cfg.URL).
Int("db", opt.DB).
Msg("Connected to Redis")
cacheInstance = &CacheService{client: client}
return cacheInstance, nil
}
@@ -311,9 +329,10 @@ func (c *CacheService) CacheSeededData(ctx context.Context, data interface{}) (s
return "", fmt.Errorf("failed to marshal seeded data: %w", err)
}
// Generate MD5 ETag from the JSON data
hash := md5.Sum(jsonData)
etag := fmt.Sprintf("\"%x\"", hash)
// Generate FNV-64a ETag from the JSON data (faster than MD5, non-cryptographic)
h := fnv.New64a()
h.Write(jsonData)
etag := fmt.Sprintf("\"%x\"", h.Sum64())
// Store both the data and the ETag
if err := c.client.Set(ctx, SeededDataKey, jsonData, SeededDataTTL).Err(); err != nil {

View File

@@ -4,11 +4,11 @@ import (
"bytes"
"fmt"
"html/template"
"io"
"time"
mail "github.com/wneessen/go-mail"
"github.com/rs/zerolog/log"
"gopkg.in/gomail.v2"
"github.com/treytartt/honeydue-api/internal/config"
)
@@ -16,17 +16,31 @@ import (
// EmailService handles sending emails
type EmailService struct {
cfg *config.EmailConfig
dialer *gomail.Dialer
client *mail.Client
enabled bool
}
// NewEmailService creates a new email service
func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService {
dialer := gomail.NewDialer(cfg.Host, cfg.Port, cfg.User, cfg.Password)
client, err := mail.NewClient(cfg.Host,
mail.WithPort(cfg.Port),
mail.WithSMTPAuth(mail.SMTPAuthPlain),
mail.WithUsername(cfg.User),
mail.WithPassword(cfg.Password),
mail.WithTLSPortPolicy(mail.TLSOpportunistic),
)
if err != nil {
log.Error().Err(err).Msg("Failed to create mail client - emails will not be sent")
return &EmailService{
cfg: cfg,
client: nil,
enabled: false,
}
}
return &EmailService{
cfg: cfg,
dialer: dialer,
client: client,
enabled: enabled,
}
}
@@ -37,14 +51,18 @@ func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error {
log.Debug().Msg("Email sending disabled by feature flag")
return nil
}
m := gomail.NewMessage()
m.SetHeader("From", s.cfg.From)
m.SetHeader("To", to)
m.SetHeader("Subject", subject)
m.SetBody("text/plain", textBody)
m.AddAlternative("text/html", htmlBody)
m := mail.NewMsg()
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
return fmt.Errorf("failed to set from address: %w", err)
}
if err := m.AddTo(to); err != nil {
return fmt.Errorf("failed to set to address: %w", err)
}
m.Subject(subject)
m.SetBodyString(mail.TypeTextPlain, textBody)
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
if err := s.dialer.DialAndSend(m); err != nil {
if err := s.client.DialAndSend(m); err != nil {
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email")
return fmt.Errorf("failed to send email: %w", err)
}
@@ -74,26 +92,25 @@ func (s *EmailService) SendEmailWithAttachment(to, subject, htmlBody, textBody s
log.Debug().Msg("Email sending disabled by feature flag")
return nil
}
m := gomail.NewMessage()
m.SetHeader("From", s.cfg.From)
m.SetHeader("To", to)
m.SetHeader("Subject", subject)
m.SetBody("text/plain", textBody)
m.AddAlternative("text/html", htmlBody)
m := mail.NewMsg()
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
return fmt.Errorf("failed to set from address: %w", err)
}
if err := m.AddTo(to); err != nil {
return fmt.Errorf("failed to set to address: %w", err)
}
m.Subject(subject)
m.SetBodyString(mail.TypeTextPlain, textBody)
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
if attachment != nil {
m.Attach(attachment.Filename,
gomail.SetCopyFunc(func(w io.Writer) error {
_, err := w.Write(attachment.Data)
return err
}),
gomail.SetHeader(map[string][]string{
"Content-Type": {attachment.ContentType},
}),
m.AttachReader(attachment.Filename,
bytes.NewReader(attachment.Data),
mail.WithFileContentType(mail.ContentType(attachment.ContentType)),
)
}
if err := s.dialer.DialAndSend(m); err != nil {
if err := s.client.DialAndSend(m); err != nil {
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email with attachment")
return fmt.Errorf("failed to send email: %w", err)
}
@@ -108,29 +125,28 @@ func (s *EmailService) SendEmailWithEmbeddedImages(to, subject, htmlBody, textBo
log.Debug().Msg("Email sending disabled by feature flag")
return nil
}
m := gomail.NewMessage()
m.SetHeader("From", s.cfg.From)
m.SetHeader("To", to)
m.SetHeader("Subject", subject)
m.SetBody("text/plain", textBody)
m.AddAlternative("text/html", htmlBody)
m := mail.NewMsg()
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
return fmt.Errorf("failed to set from address: %w", err)
}
if err := m.AddTo(to); err != nil {
return fmt.Errorf("failed to set to address: %w", err)
}
m.Subject(subject)
m.SetBodyString(mail.TypeTextPlain, textBody)
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
// Embed each image with Content-ID for inline display
for _, img := range images {
m.Embed(img.Filename,
gomail.SetCopyFunc(func(w io.Writer) error {
_, err := w.Write(img.Data)
return err
}),
gomail.SetHeader(map[string][]string{
"Content-Type": {img.ContentType},
"Content-ID": {"<" + img.ContentID + ">"},
"Content-Disposition": {"inline; filename=\"" + img.Filename + "\""},
}),
img := img // capture range variable for closure
m.EmbedReader(img.Filename,
bytes.NewReader(img.Data),
mail.WithFileContentType(mail.ContentType(img.ContentType)),
mail.WithFileContentID(img.ContentID),
)
}
if err := s.dialer.DialAndSend(m); err != nil {
if err := s.client.DialAndSend(m); err != nil {
log.Error().Err(err).Str("to", to).Str("subject", subject).Int("images", len(images)).Msg("Failed to send email with embedded images")
return fmt.Errorf("failed to send email: %w", err)
}

View File

@@ -0,0 +1,66 @@
package services
import (
"github.com/treytartt/honeydue-api/internal/models"
"gorm.io/gorm"
)
// FileOwnershipService checks whether a user owns a file referenced by URL.
// It queries task completion images, document files, and document images
// to determine ownership through residence access.
type FileOwnershipService struct {
db *gorm.DB
}
// NewFileOwnershipService creates a new FileOwnershipService
func NewFileOwnershipService(db *gorm.DB) *FileOwnershipService {
return &FileOwnershipService{db: db}
}
// IsFileOwnedByUser checks if the given file URL belongs to a record
// that the user has access to (via residence membership).
func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (bool, error) {
// Check task completion images: image_url -> completion -> task -> residence -> user access
var completionImageCount int64
err := s.db.Model(&models.TaskCompletionImage{}).
Joins("JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id").
Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_task.residence_id").
Where("task_taskcompletionimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
Count(&completionImageCount).Error
if err != nil {
return false, err
}
if completionImageCount > 0 {
return true, nil
}
// Check document files: file_url -> document -> residence -> user access
var documentCount int64
err = s.db.Model(&models.Document{}).
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
Where("task_document.file_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
Count(&documentCount).Error
if err != nil {
return false, err
}
if documentCount > 0 {
return true, nil
}
// Check document images: image_url -> document_image -> document -> residence -> user access
var documentImageCount int64
err = s.db.Model(&models.DocumentImage{}).
Joins("JOIN task_document ON task_document.id = task_documentimage.document_id").
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
Where("task_documentimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
Count(&documentImageCount).Error
if err != nil {
return false, err
}
if documentImageCount > 0 {
return true, nil
}
return false, nil
}

View File

@@ -36,11 +36,12 @@ var (
// AppleIAPClient handles Apple App Store Server API validation
type AppleIAPClient struct {
keyID string
issuerID string
bundleID string
keyID string
issuerID string
bundleID string
privateKey *ecdsa.PrivateKey
sandbox bool
sandbox bool
httpClient *http.Client // P-07: Reused across requests
}
// GoogleIAPClient handles Google Play Developer API validation
@@ -122,6 +123,7 @@ func NewAppleIAPClient(cfg config.AppleIAPConfig) (*AppleIAPClient, error) {
bundleID: cfg.BundleID,
privateKey: ecdsaKey,
sandbox: cfg.Sandbox,
httpClient: &http.Client{Timeout: 30 * time.Second}, // P-07: Single client reused across requests
}, nil
}
@@ -168,8 +170,8 @@ func (c *AppleIAPClient) ValidateTransaction(ctx context.Context, transactionID
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
// P-07: Reuse the single http.Client instead of creating one per request
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call Apple API: %w", err)
}
@@ -276,8 +278,8 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
// P-07: Reuse the single http.Client
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call Apple API: %w", err)
}
@@ -357,8 +359,8 @@ func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, r
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
// P-07: Reuse the single http.Client
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call Apple verifyReceipt: %w", err)
}

View File

@@ -4,9 +4,11 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/apperrors"
@@ -184,8 +186,33 @@ func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferen
return NewNotificationPreferencesResponse(prefs), nil
}
// validateHourField checks that an optional hour value is in the valid range 0-23.
func validateHourField(val *int, fieldName string) error {
if val != nil && (*val < 0 || *val > 23) {
return apperrors.BadRequest("error.invalid_hour").
WithMessage(fmt.Sprintf("%s must be between 0 and 23", fieldName))
}
return nil
}
// UpdatePreferences updates notification preferences
func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
// B-12: Validate hour fields are in range 0-23
hourFields := []struct {
value *int
name string
}{
{req.TaskDueSoonHour, "task_due_soon_hour"},
{req.TaskOverdueHour, "task_overdue_hour"},
{req.WarrantyExpiringHour, "warranty_expiring_hour"},
{req.DailyDigestHour, "daily_digest_hour"},
}
for _, hf := range hourFields {
if err := validateHourField(hf.value, hf.name); err != nil {
return nil, err
}
}
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
if err != nil {
return nil, apperrors.Internal(err)
@@ -256,7 +283,10 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
// Only update if timezone changed (avoid unnecessary DB writes)
if prefs.Timezone == nil || *prefs.Timezone != timezone {
prefs.Timezone = &timezone
_ = s.notificationRepo.UpdatePreferences(prefs)
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil {
log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone).
Msg("Failed to update user timezone in notification preferences")
}
}
}
@@ -430,6 +460,7 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string,
// === Response/Request Types ===
// TODO(hardening): Move to internal/dto/responses/notification.go
// NotificationResponse represents a notification in API response
type NotificationResponse struct {
ID uint `json:"id"`
@@ -473,6 +504,7 @@ func NewNotificationResponse(n *models.Notification) NotificationResponse {
return resp
}
// TODO(hardening): Move to internal/dto/responses/notification.go
// NotificationPreferencesResponse represents notification preferences
type NotificationPreferencesResponse struct {
TaskDueSoon bool `json:"task_due_soon"`
@@ -511,6 +543,7 @@ func NewNotificationPreferencesResponse(p *models.NotificationPreference) *Notif
}
}
// TODO(hardening): Move to internal/dto/requests/notification.go
// UpdatePreferencesRequest represents preferences update request
type UpdatePreferencesRequest struct {
TaskDueSoon *bool `json:"task_due_soon"`
@@ -532,6 +565,7 @@ type UpdatePreferencesRequest struct {
DailyDigestHour *int `json:"daily_digest_hour"`
}
// TODO(hardening): Move to internal/dto/responses/notification.go
// DeviceResponse represents a device in API response
type DeviceResponse struct {
ID uint `json:"id"`
@@ -569,6 +603,7 @@ func NewGCMDeviceResponse(d *models.GCMDevice) *DeviceResponse {
}
}
// TODO(hardening): Move to internal/dto/requests/notification.go
// RegisterDeviceRequest represents device registration request
type RegisterDeviceRequest struct {
Name string `json:"name"`

View File

@@ -39,6 +39,7 @@ func generateTrackingID() string {
// HasSentEmail checks if a specific email type has already been sent to a user
func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.OnboardingEmailType) bool {
// TODO(hardening): Replace with OnboardingEmailRepository.HasSentEmail()
var count int64
if err := s.db.Model(&models.OnboardingEmail{}).
Where("user_id = ? AND email_type = ?", userID, emailType).
@@ -51,6 +52,7 @@ func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.Onbo
// RecordEmailSent records that an email was sent to a user
func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.OnboardingEmailType, trackingID string) error {
// TODO(hardening): Replace with OnboardingEmailRepository.Create()
email := &models.OnboardingEmail{
UserID: userID,
EmailType: emailType,
@@ -66,6 +68,7 @@ func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.O
// RecordEmailOpened records that an email was opened based on tracking ID
func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error {
// TODO(hardening): Replace with OnboardingEmailRepository.MarkOpened()
now := time.Now().UTC()
result := s.db.Model(&models.OnboardingEmail{}).
Where("tracking_id = ? AND opened_at IS NULL", trackingID).
@@ -84,6 +87,7 @@ func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error {
// GetEmailHistory gets all onboarding emails for a specific user
func (s *OnboardingEmailService) GetEmailHistory(userID uint) ([]models.OnboardingEmail, error) {
// TODO(hardening): Replace with OnboardingEmailRepository.FindByUserID()
var emails []models.OnboardingEmail
if err := s.db.Where("user_id = ?", userID).Order("sent_at DESC").Find(&emails).Error; err != nil {
return nil, err
@@ -105,11 +109,13 @@ func (s *OnboardingEmailService) GetAllEmailHistory(page, pageSize int) ([]model
offset := (page - 1) * pageSize
// TODO(hardening): Replace with OnboardingEmailRepository.CountAll()
// Count total
if err := s.db.Model(&models.OnboardingEmail{}).Count(&total).Error; err != nil {
return nil, 0, err
}
// TODO(hardening): Replace with OnboardingEmailRepository.FindAllPaginated()
// Get paginated results with user info
if err := s.db.Preload("User").
Order("sent_at DESC").
@@ -126,6 +132,7 @@ func (s *OnboardingEmailService) GetAllEmailHistory(page, pageSize int) ([]model
func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error) {
stats := &OnboardingEmailStats{}
// TODO(hardening): Replace with OnboardingEmailRepository.GetStats()
// No residence email stats
var noResTotal, noResOpened int64
if err := s.db.Model(&models.OnboardingEmail{}).
@@ -159,6 +166,7 @@ func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error)
return stats, nil
}
// TODO(hardening): Move to internal/dto/responses/onboarding_email.go
// OnboardingEmailStats represents statistics about onboarding emails
type OnboardingEmailStats struct {
NoResidenceTotal int64 `json:"no_residence_total"`
@@ -173,6 +181,7 @@ func (s *OnboardingEmailService) UsersNeedingNoResidenceEmail() ([]models.User,
twoDaysAgo := time.Now().UTC().AddDate(0, 0, -2)
// TODO(hardening): Replace with OnboardingEmailRepository.FindUsersWithoutResidence()
// Find users who:
// 1. Are verified
// 2. Registered 2+ days ago
@@ -201,6 +210,7 @@ func (s *OnboardingEmailService) UsersNeedingNoTasksEmail() ([]models.User, erro
fiveDaysAgo := time.Now().UTC().AddDate(0, 0, -5)
// TODO(hardening): Replace with OnboardingEmailRepository.FindUsersWithoutTasks()
// Find users who:
// 1. Are verified
// 2. Have at least one residence
@@ -325,6 +335,7 @@ func (s *OnboardingEmailService) sendNoTasksEmail(user models.User) error {
// SendOnboardingEmailToUser manually sends an onboarding email to a specific user
// This is used by admin to force-send emails regardless of eligibility criteria
func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailType models.OnboardingEmailType) error {
// TODO(hardening): Replace with UserRepository.FindByID() (inject UserRepository)
// Load the user
var user models.User
if err := s.db.First(&user, userID).Error; err != nil {
@@ -362,6 +373,7 @@ func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailTyp
// If already sent before, delete the old record first to allow re-recording
// This allows admins to "resend" emails while still tracking them
if alreadySent {
// TODO(hardening): Replace with OnboardingEmailRepository.DeleteByUserAndType()
if err := s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{}).Error; err != nil {
log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to delete old onboarding email record before resend")
}

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"time"
"github.com/jung-kurt/gofpdf"
"github.com/go-pdf/fpdf"
)
// PDFService handles PDF generation
@@ -18,7 +18,7 @@ func NewPDFService() *PDFService {
// GenerateTasksReportPDF generates a PDF report from task report data
func (s *PDFService) GenerateTasksReportPDF(report *TasksReportResponse) ([]byte, error) {
pdf := gofpdf.New("P", "mm", "A4", "")
pdf := fpdf.New("P", "mm", "A4", "")
pdf.SetMargins(15, 15, 15)
pdf.AddPage()

View File

@@ -133,14 +133,16 @@ func (s *ResidenceService) GetMyResidences(userID uint, now time.Time) (*respons
}
}
// Attach completion summaries (honeycomb grid data)
for i := range residenceResponses {
summary, err := s.taskRepo.GetCompletionSummary(residenceResponses[i].ID, now, 10)
if err != nil {
log.Warn().Err(err).Uint("residence_id", residenceResponses[i].ID).Msg("Failed to fetch completion summary")
continue
// P-01: Batch fetch completion summaries in 2 queries total instead of 2*N
summaries, err := s.taskRepo.GetBatchCompletionSummaries(residenceIDs, now, 10)
if err != nil {
log.Warn().Err(err).Msg("Failed to fetch batch completion summaries")
} else {
for i := range residenceResponses {
if summary, ok := summaries[residenceResponses[i].ID]; ok {
residenceResponses[i].CompletionSummary = summary
}
}
residenceResponses[i].CompletionSummary = summary
}
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
@@ -17,7 +18,8 @@ import (
// StorageService handles file uploads to local filesystem
type StorageService struct {
cfg *config.StorageConfig
cfg *config.StorageConfig
allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups
}
// UploadResult contains information about an uploaded file
@@ -44,9 +46,18 @@ func NewStorageService(cfg *config.StorageConfig) (*StorageService, error) {
}
}
log.Info().Str("upload_dir", cfg.UploadDir).Msg("Storage service initialized")
// P-12: Parse AllowedTypes once at initialization for O(1) lookups
allowedTypes := make(map[string]struct{})
for _, t := range strings.Split(cfg.AllowedTypes, ",") {
trimmed := strings.TrimSpace(t)
if trimmed != "" {
allowedTypes[trimmed] = struct{}{}
}
}
return &StorageService{cfg: cfg}, nil
log.Info().Str("upload_dir", cfg.UploadDir).Int("allowed_types", len(allowedTypes)).Msg("Storage service initialized")
return &StorageService{cfg: cfg, allowedTypes: allowedTypes}, nil
}
// Upload saves a file to the local filesystem
@@ -56,17 +67,47 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
return nil, fmt.Errorf("file size %d exceeds maximum allowed %d bytes", file.Size, s.cfg.MaxFileSize)
}
// Get MIME type
mimeType := file.Header.Get("Content-Type")
if mimeType == "" {
mimeType = "application/octet-stream"
// Get claimed MIME type from header
claimedMimeType := file.Header.Get("Content-Type")
if claimedMimeType == "" {
claimedMimeType = "application/octet-stream"
}
// Validate MIME type
// S-09: Detect actual content type from file bytes to prevent disguised uploads
src, err := file.Open()
if err != nil {
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
}
defer src.Close()
// Read the first 512 bytes for content type detection
sniffBuf := make([]byte, 512)
n, err := src.Read(sniffBuf)
if err != nil && n == 0 {
return nil, fmt.Errorf("failed to read file for content type detection: %w", err)
}
detectedMimeType := http.DetectContentType(sniffBuf[:n])
// Validate that the detected type matches the claimed type (at the category level)
// Allow application/octet-stream from detection since DetectContentType may not
// recognize all valid types, but the claimed type must still be in our allowed list
if detectedMimeType != "application/octet-stream" && !s.mimeTypesCompatible(claimedMimeType, detectedMimeType) {
return nil, fmt.Errorf("file content type mismatch: claimed %s but detected %s", claimedMimeType, detectedMimeType)
}
// Use the claimed MIME type (which is more specific) if it's allowed
mimeType := claimedMimeType
// Validate MIME type against allowed list
if !s.isAllowedType(mimeType) {
return nil, fmt.Errorf("file type %s is not allowed", mimeType)
}
// Seek back to beginning after sniffing
if _, err := src.Seek(0, io.SeekStart); err != nil {
return nil, fmt.Errorf("failed to seek file: %w", err)
}
// Generate unique filename
ext := filepath.Ext(file.Filename)
if ext == "" {
@@ -83,15 +124,11 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
subdir = "completions"
}
// Full path
destPath := filepath.Join(s.cfg.UploadDir, subdir, newFilename)
// Open source file
src, err := file.Open()
// S-18: Sanitize path to prevent traversal attacks
destPath, err := SafeResolvePath(s.cfg.UploadDir, filepath.Join(subdir, newFilename))
if err != nil {
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
return nil, fmt.Errorf("invalid upload path: %w", err)
}
defer src.Close()
// Create destination file
dst, err := os.Create(destPath)
@@ -131,19 +168,11 @@ func (s *StorageService) Delete(fileURL string) error {
// Convert URL to file path
relativePath := strings.TrimPrefix(fileURL, s.cfg.BaseURL)
relativePath = strings.TrimPrefix(relativePath, "/")
fullPath := filepath.Join(s.cfg.UploadDir, relativePath)
// Security check: ensure path is within upload directory
absUploadDir, err := filepath.Abs(s.cfg.UploadDir)
// S-18: Use SafeResolvePath to prevent path traversal
fullPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath)
if err != nil {
return fmt.Errorf("failed to resolve upload directory: %w", err)
}
absFilePath, err := filepath.Abs(fullPath)
if err != nil {
return fmt.Errorf("failed to resolve file path: %w", err)
}
if !strings.HasPrefix(absFilePath, absUploadDir+string(filepath.Separator)) && absFilePath != absUploadDir {
return fmt.Errorf("invalid file path")
return fmt.Errorf("invalid file path: %w", err)
}
if err := os.Remove(fullPath); err != nil {
@@ -157,15 +186,23 @@ func (s *StorageService) Delete(fileURL string) error {
return nil
}
// isAllowedType checks if the MIME type is in the allowed list
// isAllowedType checks if the MIME type is in the allowed list.
// P-12: Uses the pre-parsed allowedTypes map for O(1) lookups instead of
// splitting the config string on every call.
func (s *StorageService) isAllowedType(mimeType string) bool {
allowed := strings.Split(s.cfg.AllowedTypes, ",")
for _, t := range allowed {
if strings.TrimSpace(t) == mimeType {
return true
}
_, ok := s.allowedTypes[mimeType]
return ok
}
// mimeTypesCompatible checks if the claimed and detected MIME types are compatible.
// Two MIME types are compatible if they share the same primary type (e.g., both "image/*").
func (s *StorageService) mimeTypesCompatible(claimed, detected string) bool {
claimedParts := strings.SplitN(claimed, "/", 2)
detectedParts := strings.SplitN(detected, "/", 2)
if len(claimedParts) < 1 || len(detectedParts) < 1 {
return false
}
return false
return claimedParts[0] == detectedParts[0]
}
// getExtensionFromMimeType returns a file extension for common MIME types
@@ -191,5 +228,12 @@ func (s *StorageService) GetUploadDir() string {
// NewStorageServiceForTest creates a StorageService without creating directories.
// This is intended only for unit tests that need a StorageService with a known config.
func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService {
return &StorageService{cfg: cfg}
allowedTypes := make(map[string]struct{})
for _, t := range strings.Split(cfg.AllowedTypes, ",") {
trimmed := strings.TrimSpace(t)
if trimmed != "" {
allowedTypes[trimmed] = struct{}{}
}
}
return &StorageService{cfg: cfg, allowedTypes: allowedTypes}
}

View File

@@ -3,10 +3,10 @@ package services
import (
"encoding/json"
"fmt"
"os"
"time"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"github.com/stripe/stripe-go/v81"
portalsession "github.com/stripe/stripe-go/v81/billingportal/session"
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
@@ -34,7 +34,8 @@ func NewStripeService(
subscriptionRepo *repositories.SubscriptionRepository,
userRepo *repositories.UserRepository,
) *StripeService {
key := os.Getenv("STRIPE_SECRET_KEY")
// S-21: Use Viper config instead of os.Getenv for consistent configuration management
key := viper.GetString("STRIPE_SECRET_KEY")
if key == "" {
log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work")
} else {
@@ -42,7 +43,7 @@ func NewStripeService(
log.Info().Msg("Stripe API key configured")
}
webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")
webhookSecret := viper.GetString("STRIPE_WEBHOOK_SECRET")
if webhookSecret == "" {
log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail")
}

View File

@@ -202,18 +202,19 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
}
// getUserUsage calculates current usage for a user.
// P-10: Uses CountByOwner for properties count instead of loading all owned residences.
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
residences, err := s.residenceRepo.FindOwnedByUser(userID)
// P-10: Use CountByOwner for an efficient COUNT query instead of loading all records
propertiesCount, err := s.residenceRepo.CountByOwner(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
propertiesCount := int64(len(residences))
// Collect residence IDs for batch queries
residenceIDs := make([]uint, len(residences))
for i, r := range residences {
residenceIDs[i] = r.ID
// Still need residence IDs for batch counting tasks/contractors/documents
residenceIDs, err := s.residenceRepo.FindResidenceIDsByOwner(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
// Count tasks, contractors, and documents across all residences with single queries each

View File

@@ -130,7 +130,7 @@ func (s *TaskService) ListTasks(userID uint, daysThreshold int, now time.Time) (
return nil, apperrors.Internal(err)
}
resp := responses.NewKanbanBoardResponseForAll(board)
resp := responses.NewKanbanBoardResponseForAll(board, now)
// NOTE: Summary statistics are calculated client-side from kanban data
return &resp, nil
}
@@ -157,7 +157,7 @@ func (s *TaskService) GetTasksByResidence(residenceID, userID uint, daysThreshol
return nil, apperrors.Internal(err)
}
resp := responses.NewKanbanBoardResponse(board, residenceID)
resp := responses.NewKanbanBoardResponse(board, residenceID, now)
// NOTE: Summary statistics are calculated client-side from kanban data
return &resp, nil
}
@@ -601,8 +601,8 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
task.InProgress = false
}
// P1-5: Wrap completion creation and task update in a transaction.
// If either operation fails, both are rolled back to prevent orphaned completions.
// P1-5 + B-07: Wrap completion creation, task update, and image creation
// in a single transaction for atomicity. If any operation fails, all are rolled back.
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
return err
@@ -610,6 +610,18 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
return err
}
// B-07: Create images inside the same transaction as completion
for _, imageURL := range req.ImageURLs {
if imageURL != "" {
img := &models.TaskCompletionImage{
CompletionID: completion.ID,
ImageURL: imageURL,
}
if err := tx.Create(img).Error; err != nil {
return fmt.Errorf("failed to create completion image: %w", err)
}
}
}
return nil
})
if txErr != nil {
@@ -621,19 +633,6 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
return nil, apperrors.Internal(txErr)
}
// Create images if provided
for _, imageURL := range req.ImageURLs {
if imageURL != "" {
img := &models.TaskCompletionImage{
CompletionID: completion.ID,
ImageURL: imageURL,
}
if err := s.taskRepo.CreateCompletionImage(img); err != nil {
log.Error().Err(err).Uint("completion_id", completion.ID).Msg("Failed to create completion image")
}
}
}
// Reload completion with user info and images
completion, err = s.taskRepo.FindCompletionByID(completion.ID)
if err != nil {
@@ -663,8 +662,10 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
}, nil
}
// QuickComplete creates a minimal task completion (for widget use)
// Returns only success/error, no response body
// QuickComplete creates a minimal task completion (for widget use).
// LE-01: The entire operation (completion creation + task update) is wrapped in a
// transaction for atomicity.
// Returns only success/error, no response body.
func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
// Get the task
task, err := s.taskRepo.FindByID(taskID)
@@ -697,10 +698,6 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
CompletedFromColumn: completedFromColumn,
}
if err := s.taskRepo.CreateCompletion(completion); err != nil {
return apperrors.Internal(err)
}
// Update next_due_date and in_progress based on frequency
// Determine interval days: Custom frequency uses task.CustomIntervalDays, otherwise use frequency.Days
// Note: Frequency is no longer preloaded for performance, so we load it separately if needed
@@ -729,7 +726,6 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
} else {
// Recurring task - calculate next due date from completion date + interval
nextDue := completedAt.AddDate(0, 0, *quickIntervalDays)
// frequencyName was already set when loading frequency above
log.Info().
Uint("task_id", task.ID).
Str("frequency_name", frequencyName).
@@ -742,12 +738,23 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
// Reset in_progress to false
task.InProgress = false
}
if err := s.taskRepo.Update(task); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
// LE-01: Wrap completion creation and task update in a transaction for atomicity
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
return err
}
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
return err
}
return nil
})
if txErr != nil {
if errors.Is(txErr, repositories.ErrVersionConflict) {
return apperrors.Conflict("error.version_conflict")
}
log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after quick completion")
return apperrors.Internal(err) // Return error so caller knows the update failed
log.Error().Err(txErr).Uint("task_id", task.ID).Msg("Failed to create completion and update task in QuickComplete")
return apperrors.Internal(txErr)
}
log.Info().Uint("task_id", task.ID).Msg("QuickComplete: Task updated successfully")
@@ -813,8 +820,16 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio
// Send email notification (to everyone INCLUDING the person who completed it)
// Check user's email notification preferences first
if s.emailService != nil && user.Email != "" && s.notificationService != nil {
prefs, err := s.notificationService.GetPreferences(user.ID)
if err != nil || (prefs != nil && prefs.EmailTaskCompleted) {
prefs, prefsErr := s.notificationService.GetPreferences(user.ID)
// LE-06: Log fail-open behavior when preferences cannot be loaded
if prefsErr != nil {
log.Warn().
Err(prefsErr).
Uint("user_id", user.ID).
Uint("task_id", task.ID).
Msg("Failed to load notification preferences, falling back to sending email (fail-open)")
}
if prefsErr != nil || (prefs != nil && prefs.EmailTaskCompleted) {
// Send email if we couldn't get prefs (fail-open) or if email notifications are enabled
if err := s.emailService.SendTaskCompletedEmail(
user.Email,

View File

@@ -32,7 +32,8 @@ func (s *UserService) ListUsersInSharedResidences(userID uint) ([]responses.User
return nil, apperrors.Internal(err)
}
var result []responses.UserSummary
// F-23: Initialize as empty slice so JSON serialization produces [] instead of null
result := make([]responses.UserSummary, 0, len(users))
for _, u := range users {
result = append(result, responses.UserSummary{
ID: u.ID,
@@ -72,7 +73,8 @@ func (s *UserService) ListProfilesInSharedResidences(userID uint) ([]responses.U
return nil, apperrors.Internal(err)
}
var result []responses.UserProfileSummary
// F-23: Initialize as empty slice so JSON serialization produces [] instead of null
result := make([]responses.UserProfileSummary, 0, len(profiles))
for _, p := range profiles {
result = append(result, responses.UserProfileSummary{
ID: p.ID,