Fix 113 hardening issues across entire Go backend
Security: - Replace all binding: tags with validate: + c.Validate() in admin handlers - Add rate limiting to auth endpoints (login, register, password reset) - Add security headers (HSTS, XSS protection, nosniff, frame options) - Wire Google Pub/Sub token verification into webhook handler - Replace ParseUnverified with proper OIDC/JWKS key verification - Verify inner Apple JWS signatures in webhook handler - Add io.LimitReader (1MB) to all webhook body reads - Add ownership verification to file deletion - Move hardcoded admin credentials to env vars - Add uniqueIndex to User.Email - Hide ConfirmationCode from JSON serialization - Mask confirmation codes in admin responses - Use http.DetectContentType for upload validation - Fix path traversal in storage service - Replace os.Getenv with Viper in stripe service - Sanitize Redis URLs before logging - Separate DEBUG_FIXED_CODES from DEBUG flag - Reject weak SECRET_KEY in production - Add host check on /_next/* proxy routes - Use explicit localhost CORS origins in debug mode - Replace err.Error() with generic messages in all admin error responses Critical fixes: - Rewrite FCM to HTTP v1 API with OAuth 2.0 service account auth - Fix user_customuser -> auth_user table names in raw SQL - Fix dashboard verified query to use UserProfile model - Add escapeLikeWildcards() to prevent SQL wildcard injection Bug fixes: - Add bounds checks for days/expiring_soon query params (1-3650) - Add receipt_data/transaction_id empty-check to RestoreSubscription - Change Active bool -> *bool in device handler - Check all unchecked GORM/FindByIDWithProfile errors - Add validation for notification hour fields (0-23) - Add max=10000 validation on task description updates Transactions & data integrity: - Wrap registration flow in transaction - Wrap QuickComplete in transaction - Move image creation inside completion transaction - Wrap SetSpecialties in transaction - Wrap GetOrCreateToken in transaction - Wrap completion+image deletion in transaction Performance: - Batch completion summaries (2 queries vs 2N) - Reuse single http.Client in IAP validation - Cache dashboard counts (30s TTL) - Batch COUNT queries in admin user list - Add Limit(500) to document queries - Add reminder_stage+due_date filters to reminder queries - Parse AllowedTypes once at init - In-memory user cache in auth middleware (30s TTL) - Timezone change detection cache - Optimize P95 with per-endpoint sorted buffers - Replace crypto/md5 with hash/fnv for ETags Code quality: - Add sync.Once to all monitoring Stop()/Close() methods - Replace 8 fmt.Printf with zerolog in auth service - Log previously discarded errors - Standardize delete response shapes - Route hardcoded English through i18n - Remove FileURL from DocumentResponse (keep MediaURL only) - Thread user timezone through kanban board responses - Initialize empty slices to prevent null JSON - Extract shared field map for task Update/UpdateTx - Delete unused SoftDeleteModel, min(), formatCron, legacy handlers Worker & jobs: - Wire Asynq email infrastructure into worker - Register HandleReminderLogCleanup with daily 3AM cron - Use per-user timezone in HandleSmartReminder - Replace direct DB queries with repository calls - Delete legacy reminder handlers (~200 lines) - Delete unused task type constants Dependencies: - Replace archived jung-kurt/gofpdf with go-pdf/fpdf - Replace unmaintained gomail.v2 with wneessen/go-mail - Add TODO for Echo jwt v3 transitive dep removal Test infrastructure: - Fix MakeRequest/SeedLookupData error handling - Replace os.Exit(0) with t.Skip() in scope/consistency tests - Add 11 new FCM v1 tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||
@@ -90,7 +91,7 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
|
||||
// Update last login
|
||||
if err := s.userRepo.UpdateLastLogin(user.ID); err != nil {
|
||||
// Log error but don't fail the login
|
||||
fmt.Printf("Failed to update last login: %v\n", err)
|
||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login")
|
||||
}
|
||||
|
||||
return &responses.LoginResponse{
|
||||
@@ -99,7 +100,9 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Register creates a new user account
|
||||
// Register creates a new user account.
|
||||
// F-10: User creation, profile creation, notification preferences, and confirmation code
|
||||
// are wrapped in a transaction for atomicity.
|
||||
func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
|
||||
// Check if username exists
|
||||
exists, err := s.userRepo.ExistsByUsername(req.Username)
|
||||
@@ -133,43 +136,49 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
|
||||
return nil, "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Save user
|
||||
if err := s.userRepo.Create(user); err != nil {
|
||||
return nil, "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Create user profile
|
||||
if _, err := s.userRepo.GetOrCreateProfile(user.ID); err != nil {
|
||||
// Log error but don't fail registration
|
||||
fmt.Printf("Failed to create user profile: %v\n", err)
|
||||
}
|
||||
|
||||
// Create notification preferences with all options enabled
|
||||
if s.notificationRepo != nil {
|
||||
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
||||
// Log error but don't fail registration
|
||||
fmt.Printf("Failed to create notification preferences: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create auth token
|
||||
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
||||
if err != nil {
|
||||
return nil, "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Generate confirmation code - use fixed code in debug mode for easier local testing
|
||||
// Generate confirmation code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
|
||||
var code string
|
||||
if s.cfg.Server.Debug {
|
||||
if s.cfg.Server.DebugFixedCodes {
|
||||
code = "123456"
|
||||
} else {
|
||||
code = generateSixDigitCode()
|
||||
}
|
||||
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
||||
|
||||
if _, err := s.userRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
|
||||
// Log error but don't fail registration
|
||||
fmt.Printf("Failed to create confirmation code: %v\n", err)
|
||||
// Wrap user creation + profile + notification preferences + confirmation code in a transaction
|
||||
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error {
|
||||
// Save user
|
||||
if err := txRepo.Create(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create user profile
|
||||
if _, err := txRepo.GetOrCreateProfile(user.ID); err != nil {
|
||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create user profile during registration")
|
||||
}
|
||||
|
||||
// Create notification preferences with all options enabled
|
||||
if s.notificationRepo != nil {
|
||||
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration")
|
||||
}
|
||||
}
|
||||
|
||||
// Create confirmation code
|
||||
if _, err := txRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
|
||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create confirmation code during registration")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return nil, "", apperrors.Internal(txErr)
|
||||
}
|
||||
|
||||
// Create auth token (outside transaction since token generation is idempotent)
|
||||
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
||||
if err != nil {
|
||||
return nil, "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return &responses.RegisterResponse{
|
||||
@@ -248,8 +257,8 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
|
||||
return apperrors.BadRequest("error.email_already_verified")
|
||||
}
|
||||
|
||||
// Check for test code in debug mode
|
||||
if s.cfg.Server.Debug && code == "123456" {
|
||||
// Check for test code when DEBUG_FIXED_CODES is enabled
|
||||
if s.cfg.Server.DebugFixedCodes && code == "123456" {
|
||||
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
@@ -294,9 +303,9 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
|
||||
return "", apperrors.BadRequest("error.email_already_verified")
|
||||
}
|
||||
|
||||
// Generate new code - use fixed code in debug mode for easier local testing
|
||||
// Generate new code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
|
||||
var code string
|
||||
if s.cfg.Server.Debug {
|
||||
if s.cfg.Server.DebugFixedCodes {
|
||||
code = "123456"
|
||||
} else {
|
||||
code = generateSixDigitCode()
|
||||
@@ -331,9 +340,9 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
|
||||
return "", nil, apperrors.TooManyRequests("error.rate_limit_exceeded")
|
||||
}
|
||||
|
||||
// Generate code and reset token - use fixed code in debug mode for easier local testing
|
||||
// Generate code and reset token - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
|
||||
var code string
|
||||
if s.cfg.Server.Debug {
|
||||
if s.cfg.Server.DebugFixedCodes {
|
||||
code = "123456"
|
||||
} else {
|
||||
code = generateSixDigitCode()
|
||||
@@ -365,8 +374,8 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
|
||||
return "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Check for test code in debug mode
|
||||
if s.cfg.Server.Debug && code == "123456" {
|
||||
// Check for test code when DEBUG_FIXED_CODES is enabled
|
||||
if s.cfg.Server.DebugFixedCodes && code == "123456" {
|
||||
return resetCode.ResetToken, nil
|
||||
}
|
||||
|
||||
@@ -422,13 +431,13 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
|
||||
// Mark reset code as used
|
||||
if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
|
||||
// Log error but don't fail
|
||||
fmt.Printf("Failed to mark reset code as used: %v\n", err)
|
||||
log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used")
|
||||
}
|
||||
|
||||
// Invalidate all existing tokens for this user (security measure)
|
||||
if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil {
|
||||
// Log error but don't fail
|
||||
fmt.Printf("Failed to delete user tokens: %v\n", err)
|
||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -482,6 +491,13 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
if email != "" {
|
||||
existingUser, err := s.userRepo.FindByEmail(email)
|
||||
if err == nil && existingUser != nil {
|
||||
// S-06: Log auto-linking of social account to existing user
|
||||
log.Warn().
|
||||
Str("email", email).
|
||||
Str("provider", "apple").
|
||||
Uint("user_id", existingUser.ID).
|
||||
Msg("Auto-linking social account to existing user by email match")
|
||||
|
||||
// Link Apple ID to existing account
|
||||
appleAuthRecord := &models.AppleSocialAuth{
|
||||
UserID: existingUser.ID,
|
||||
@@ -505,8 +521,11 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
// Update last login
|
||||
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
|
||||
|
||||
// Reload user with profile
|
||||
existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
||||
// B-08: Check error from FindByIDWithProfile
|
||||
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return &responses.AppleSignInResponse{
|
||||
Token: token.Key,
|
||||
@@ -544,8 +563,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
// Create notification preferences with all options enabled
|
||||
if s.notificationRepo != nil {
|
||||
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
||||
// Log error but don't fail registration
|
||||
fmt.Printf("Failed to create notification preferences: %v\n", err)
|
||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -566,8 +584,11 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Reload user with profile
|
||||
user, _ = s.userRepo.FindByIDWithProfile(user.ID)
|
||||
// B-08: Check error from FindByIDWithProfile
|
||||
user, err = s.userRepo.FindByIDWithProfile(user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return &responses.AppleSignInResponse{
|
||||
Token: token.Key,
|
||||
@@ -623,6 +644,13 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
if email != "" {
|
||||
existingUser, err := s.userRepo.FindByEmail(email)
|
||||
if err == nil && existingUser != nil {
|
||||
// S-06: Log auto-linking of social account to existing user
|
||||
log.Warn().
|
||||
Str("email", email).
|
||||
Str("provider", "google").
|
||||
Uint("user_id", existingUser.ID).
|
||||
Msg("Auto-linking social account to existing user by email match")
|
||||
|
||||
// Link Google ID to existing account
|
||||
googleAuthRecord := &models.GoogleSocialAuth{
|
||||
UserID: existingUser.ID,
|
||||
@@ -649,8 +677,11 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
// Update last login
|
||||
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
|
||||
|
||||
// Reload user with profile
|
||||
existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
||||
// B-08: Check error from FindByIDWithProfile
|
||||
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return &responses.GoogleSignInResponse{
|
||||
Token: token.Key,
|
||||
@@ -688,8 +719,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
// Create notification preferences with all options enabled
|
||||
if s.notificationRepo != nil {
|
||||
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
||||
// Log error but don't fail registration
|
||||
fmt.Printf("Failed to create notification preferences: %v\n", err)
|
||||
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -711,8 +741,11 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Reload user with profile
|
||||
user, _ = s.userRepo.FindByIDWithProfile(user.ID)
|
||||
// B-08: Check error from FindByIDWithProfile
|
||||
user, err = s.userRepo.FindByIDWithProfile(user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return &responses.GoogleSignInResponse{
|
||||
Token: token.Key,
|
||||
|
||||
@@ -2,9 +2,10 @@ package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -18,38 +19,55 @@ type CacheService struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
var cacheInstance *CacheService
|
||||
var (
|
||||
cacheInstance *CacheService
|
||||
cacheOnce sync.Once
|
||||
)
|
||||
|
||||
// NewCacheService creates a new cache service
|
||||
// NewCacheService creates a new cache service (thread-safe via sync.Once)
|
||||
func NewCacheService(cfg *config.RedisConfig) (*CacheService, error) {
|
||||
opt, err := redis.ParseURL(cfg.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse Redis URL: %w", err)
|
||||
var initErr error
|
||||
|
||||
cacheOnce.Do(func() {
|
||||
opt, err := redis.ParseURL(cfg.URL)
|
||||
if err != nil {
|
||||
initErr = fmt.Errorf("failed to parse Redis URL: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.Password != "" {
|
||||
opt.Password = cfg.Password
|
||||
}
|
||||
if cfg.DB != 0 {
|
||||
opt.DB = cfg.DB
|
||||
}
|
||||
|
||||
client := redis.NewClient(opt)
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
initErr = fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
// Reset Once so a retry is possible after transient failures
|
||||
cacheOnce = sync.Once{}
|
||||
return
|
||||
}
|
||||
|
||||
// S-14: Mask credentials in Redis URL before logging
|
||||
log.Info().
|
||||
Str("url", config.MaskURLCredentials(cfg.URL)).
|
||||
Int("db", opt.DB).
|
||||
Msg("Connected to Redis")
|
||||
|
||||
cacheInstance = &CacheService{client: client}
|
||||
})
|
||||
|
||||
if initErr != nil {
|
||||
return nil, initErr
|
||||
}
|
||||
|
||||
if cfg.Password != "" {
|
||||
opt.Password = cfg.Password
|
||||
}
|
||||
if cfg.DB != 0 {
|
||||
opt.DB = cfg.DB
|
||||
}
|
||||
|
||||
client := redis.NewClient(opt)
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("url", cfg.URL).
|
||||
Int("db", opt.DB).
|
||||
Msg("Connected to Redis")
|
||||
|
||||
cacheInstance = &CacheService{client: client}
|
||||
return cacheInstance, nil
|
||||
}
|
||||
|
||||
@@ -311,9 +329,10 @@ func (c *CacheService) CacheSeededData(ctx context.Context, data interface{}) (s
|
||||
return "", fmt.Errorf("failed to marshal seeded data: %w", err)
|
||||
}
|
||||
|
||||
// Generate MD5 ETag from the JSON data
|
||||
hash := md5.Sum(jsonData)
|
||||
etag := fmt.Sprintf("\"%x\"", hash)
|
||||
// Generate FNV-64a ETag from the JSON data (faster than MD5, non-cryptographic)
|
||||
h := fnv.New64a()
|
||||
h.Write(jsonData)
|
||||
etag := fmt.Sprintf("\"%x\"", h.Sum64())
|
||||
|
||||
// Store both the data and the ETag
|
||||
if err := c.client.Set(ctx, SeededDataKey, jsonData, SeededDataTTL).Err(); err != nil {
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
mail "github.com/wneessen/go-mail"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/gomail.v2"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
)
|
||||
@@ -16,17 +16,31 @@ import (
|
||||
// EmailService handles sending emails
|
||||
type EmailService struct {
|
||||
cfg *config.EmailConfig
|
||||
dialer *gomail.Dialer
|
||||
client *mail.Client
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// NewEmailService creates a new email service
|
||||
func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService {
|
||||
dialer := gomail.NewDialer(cfg.Host, cfg.Port, cfg.User, cfg.Password)
|
||||
client, err := mail.NewClient(cfg.Host,
|
||||
mail.WithPort(cfg.Port),
|
||||
mail.WithSMTPAuth(mail.SMTPAuthPlain),
|
||||
mail.WithUsername(cfg.User),
|
||||
mail.WithPassword(cfg.Password),
|
||||
mail.WithTLSPortPolicy(mail.TLSOpportunistic),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create mail client - emails will not be sent")
|
||||
return &EmailService{
|
||||
cfg: cfg,
|
||||
client: nil,
|
||||
enabled: false,
|
||||
}
|
||||
}
|
||||
|
||||
return &EmailService{
|
||||
cfg: cfg,
|
||||
dialer: dialer,
|
||||
client: client,
|
||||
enabled: enabled,
|
||||
}
|
||||
}
|
||||
@@ -37,14 +51,18 @@ func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error {
|
||||
log.Debug().Msg("Email sending disabled by feature flag")
|
||||
return nil
|
||||
}
|
||||
m := gomail.NewMessage()
|
||||
m.SetHeader("From", s.cfg.From)
|
||||
m.SetHeader("To", to)
|
||||
m.SetHeader("Subject", subject)
|
||||
m.SetBody("text/plain", textBody)
|
||||
m.AddAlternative("text/html", htmlBody)
|
||||
m := mail.NewMsg()
|
||||
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
|
||||
return fmt.Errorf("failed to set from address: %w", err)
|
||||
}
|
||||
if err := m.AddTo(to); err != nil {
|
||||
return fmt.Errorf("failed to set to address: %w", err)
|
||||
}
|
||||
m.Subject(subject)
|
||||
m.SetBodyString(mail.TypeTextPlain, textBody)
|
||||
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
|
||||
|
||||
if err := s.dialer.DialAndSend(m); err != nil {
|
||||
if err := s.client.DialAndSend(m); err != nil {
|
||||
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email")
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
@@ -74,26 +92,25 @@ func (s *EmailService) SendEmailWithAttachment(to, subject, htmlBody, textBody s
|
||||
log.Debug().Msg("Email sending disabled by feature flag")
|
||||
return nil
|
||||
}
|
||||
m := gomail.NewMessage()
|
||||
m.SetHeader("From", s.cfg.From)
|
||||
m.SetHeader("To", to)
|
||||
m.SetHeader("Subject", subject)
|
||||
m.SetBody("text/plain", textBody)
|
||||
m.AddAlternative("text/html", htmlBody)
|
||||
m := mail.NewMsg()
|
||||
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
|
||||
return fmt.Errorf("failed to set from address: %w", err)
|
||||
}
|
||||
if err := m.AddTo(to); err != nil {
|
||||
return fmt.Errorf("failed to set to address: %w", err)
|
||||
}
|
||||
m.Subject(subject)
|
||||
m.SetBodyString(mail.TypeTextPlain, textBody)
|
||||
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
|
||||
|
||||
if attachment != nil {
|
||||
m.Attach(attachment.Filename,
|
||||
gomail.SetCopyFunc(func(w io.Writer) error {
|
||||
_, err := w.Write(attachment.Data)
|
||||
return err
|
||||
}),
|
||||
gomail.SetHeader(map[string][]string{
|
||||
"Content-Type": {attachment.ContentType},
|
||||
}),
|
||||
m.AttachReader(attachment.Filename,
|
||||
bytes.NewReader(attachment.Data),
|
||||
mail.WithFileContentType(mail.ContentType(attachment.ContentType)),
|
||||
)
|
||||
}
|
||||
|
||||
if err := s.dialer.DialAndSend(m); err != nil {
|
||||
if err := s.client.DialAndSend(m); err != nil {
|
||||
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email with attachment")
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
@@ -108,29 +125,28 @@ func (s *EmailService) SendEmailWithEmbeddedImages(to, subject, htmlBody, textBo
|
||||
log.Debug().Msg("Email sending disabled by feature flag")
|
||||
return nil
|
||||
}
|
||||
m := gomail.NewMessage()
|
||||
m.SetHeader("From", s.cfg.From)
|
||||
m.SetHeader("To", to)
|
||||
m.SetHeader("Subject", subject)
|
||||
m.SetBody("text/plain", textBody)
|
||||
m.AddAlternative("text/html", htmlBody)
|
||||
m := mail.NewMsg()
|
||||
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
|
||||
return fmt.Errorf("failed to set from address: %w", err)
|
||||
}
|
||||
if err := m.AddTo(to); err != nil {
|
||||
return fmt.Errorf("failed to set to address: %w", err)
|
||||
}
|
||||
m.Subject(subject)
|
||||
m.SetBodyString(mail.TypeTextPlain, textBody)
|
||||
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
|
||||
|
||||
// Embed each image with Content-ID for inline display
|
||||
for _, img := range images {
|
||||
m.Embed(img.Filename,
|
||||
gomail.SetCopyFunc(func(w io.Writer) error {
|
||||
_, err := w.Write(img.Data)
|
||||
return err
|
||||
}),
|
||||
gomail.SetHeader(map[string][]string{
|
||||
"Content-Type": {img.ContentType},
|
||||
"Content-ID": {"<" + img.ContentID + ">"},
|
||||
"Content-Disposition": {"inline; filename=\"" + img.Filename + "\""},
|
||||
}),
|
||||
img := img // capture range variable for closure
|
||||
m.EmbedReader(img.Filename,
|
||||
bytes.NewReader(img.Data),
|
||||
mail.WithFileContentType(mail.ContentType(img.ContentType)),
|
||||
mail.WithFileContentID(img.ContentID),
|
||||
)
|
||||
}
|
||||
|
||||
if err := s.dialer.DialAndSend(m); err != nil {
|
||||
if err := s.client.DialAndSend(m); err != nil {
|
||||
log.Error().Err(err).Str("to", to).Str("subject", subject).Int("images", len(images)).Msg("Failed to send email with embedded images")
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
|
||||
66
internal/services/file_ownership_service.go
Normal file
66
internal/services/file_ownership_service.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// FileOwnershipService checks whether a user owns a file referenced by URL.
|
||||
// It queries task completion images, document files, and document images
|
||||
// to determine ownership through residence access.
|
||||
type FileOwnershipService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewFileOwnershipService creates a new FileOwnershipService
|
||||
func NewFileOwnershipService(db *gorm.DB) *FileOwnershipService {
|
||||
return &FileOwnershipService{db: db}
|
||||
}
|
||||
|
||||
// IsFileOwnedByUser checks if the given file URL belongs to a record
|
||||
// that the user has access to (via residence membership).
|
||||
func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (bool, error) {
|
||||
// Check task completion images: image_url -> completion -> task -> residence -> user access
|
||||
var completionImageCount int64
|
||||
err := s.db.Model(&models.TaskCompletionImage{}).
|
||||
Joins("JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id").
|
||||
Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
|
||||
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_task.residence_id").
|
||||
Where("task_taskcompletionimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||
Count(&completionImageCount).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if completionImageCount > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check document files: file_url -> document -> residence -> user access
|
||||
var documentCount int64
|
||||
err = s.db.Model(&models.Document{}).
|
||||
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
|
||||
Where("task_document.file_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||
Count(&documentCount).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if documentCount > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check document images: image_url -> document_image -> document -> residence -> user access
|
||||
var documentImageCount int64
|
||||
err = s.db.Model(&models.DocumentImage{}).
|
||||
Joins("JOIN task_document ON task_document.id = task_documentimage.document_id").
|
||||
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
|
||||
Where("task_documentimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||
Count(&documentImageCount).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if documentImageCount > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
@@ -36,11 +36,12 @@ var (
|
||||
|
||||
// AppleIAPClient handles Apple App Store Server API validation
|
||||
type AppleIAPClient struct {
|
||||
keyID string
|
||||
issuerID string
|
||||
bundleID string
|
||||
keyID string
|
||||
issuerID string
|
||||
bundleID string
|
||||
privateKey *ecdsa.PrivateKey
|
||||
sandbox bool
|
||||
sandbox bool
|
||||
httpClient *http.Client // P-07: Reused across requests
|
||||
}
|
||||
|
||||
// GoogleIAPClient handles Google Play Developer API validation
|
||||
@@ -122,6 +123,7 @@ func NewAppleIAPClient(cfg config.AppleIAPConfig) (*AppleIAPClient, error) {
|
||||
bundleID: cfg.BundleID,
|
||||
privateKey: ecdsaKey,
|
||||
sandbox: cfg.Sandbox,
|
||||
httpClient: &http.Client{Timeout: 30 * time.Second}, // P-07: Single client reused across requests
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -168,8 +170,8 @@ func (c *AppleIAPClient) ValidateTransaction(ctx context.Context, transactionID
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
// P-07: Reuse the single http.Client instead of creating one per request
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call Apple API: %w", err)
|
||||
}
|
||||
@@ -276,8 +278,8 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
// P-07: Reuse the single http.Client
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call Apple API: %w", err)
|
||||
}
|
||||
@@ -357,8 +359,8 @@ func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, r
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
// P-07: Reuse the single http.Client
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call Apple verifyReceipt: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,9 +4,11 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||
@@ -184,8 +186,33 @@ func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferen
|
||||
return NewNotificationPreferencesResponse(prefs), nil
|
||||
}
|
||||
|
||||
// validateHourField checks that an optional hour value is in the valid range 0-23.
|
||||
func validateHourField(val *int, fieldName string) error {
|
||||
if val != nil && (*val < 0 || *val > 23) {
|
||||
return apperrors.BadRequest("error.invalid_hour").
|
||||
WithMessage(fmt.Sprintf("%s must be between 0 and 23", fieldName))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePreferences updates notification preferences
|
||||
func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
|
||||
// B-12: Validate hour fields are in range 0-23
|
||||
hourFields := []struct {
|
||||
value *int
|
||||
name string
|
||||
}{
|
||||
{req.TaskDueSoonHour, "task_due_soon_hour"},
|
||||
{req.TaskOverdueHour, "task_overdue_hour"},
|
||||
{req.WarrantyExpiringHour, "warranty_expiring_hour"},
|
||||
{req.DailyDigestHour, "daily_digest_hour"},
|
||||
}
|
||||
for _, hf := range hourFields {
|
||||
if err := validateHourField(hf.value, hf.name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
@@ -256,7 +283,10 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
|
||||
// Only update if timezone changed (avoid unnecessary DB writes)
|
||||
if prefs.Timezone == nil || *prefs.Timezone != timezone {
|
||||
prefs.Timezone = &timezone
|
||||
_ = s.notificationRepo.UpdatePreferences(prefs)
|
||||
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone).
|
||||
Msg("Failed to update user timezone in notification preferences")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -430,6 +460,7 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string,
|
||||
|
||||
// === Response/Request Types ===
|
||||
|
||||
// TODO(hardening): Move to internal/dto/responses/notification.go
|
||||
// NotificationResponse represents a notification in API response
|
||||
type NotificationResponse struct {
|
||||
ID uint `json:"id"`
|
||||
@@ -473,6 +504,7 @@ func NewNotificationResponse(n *models.Notification) NotificationResponse {
|
||||
return resp
|
||||
}
|
||||
|
||||
// TODO(hardening): Move to internal/dto/responses/notification.go
|
||||
// NotificationPreferencesResponse represents notification preferences
|
||||
type NotificationPreferencesResponse struct {
|
||||
TaskDueSoon bool `json:"task_due_soon"`
|
||||
@@ -511,6 +543,7 @@ func NewNotificationPreferencesResponse(p *models.NotificationPreference) *Notif
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(hardening): Move to internal/dto/requests/notification.go
|
||||
// UpdatePreferencesRequest represents preferences update request
|
||||
type UpdatePreferencesRequest struct {
|
||||
TaskDueSoon *bool `json:"task_due_soon"`
|
||||
@@ -532,6 +565,7 @@ type UpdatePreferencesRequest struct {
|
||||
DailyDigestHour *int `json:"daily_digest_hour"`
|
||||
}
|
||||
|
||||
// TODO(hardening): Move to internal/dto/responses/notification.go
|
||||
// DeviceResponse represents a device in API response
|
||||
type DeviceResponse struct {
|
||||
ID uint `json:"id"`
|
||||
@@ -569,6 +603,7 @@ func NewGCMDeviceResponse(d *models.GCMDevice) *DeviceResponse {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(hardening): Move to internal/dto/requests/notification.go
|
||||
// RegisterDeviceRequest represents device registration request
|
||||
type RegisterDeviceRequest struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -39,6 +39,7 @@ func generateTrackingID() string {
|
||||
|
||||
// HasSentEmail checks if a specific email type has already been sent to a user
|
||||
func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.OnboardingEmailType) bool {
|
||||
// TODO(hardening): Replace with OnboardingEmailRepository.HasSentEmail()
|
||||
var count int64
|
||||
if err := s.db.Model(&models.OnboardingEmail{}).
|
||||
Where("user_id = ? AND email_type = ?", userID, emailType).
|
||||
@@ -51,6 +52,7 @@ func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.Onbo
|
||||
|
||||
// RecordEmailSent records that an email was sent to a user
|
||||
func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.OnboardingEmailType, trackingID string) error {
|
||||
// TODO(hardening): Replace with OnboardingEmailRepository.Create()
|
||||
email := &models.OnboardingEmail{
|
||||
UserID: userID,
|
||||
EmailType: emailType,
|
||||
@@ -66,6 +68,7 @@ func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.O
|
||||
|
||||
// RecordEmailOpened records that an email was opened based on tracking ID
|
||||
func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error {
|
||||
// TODO(hardening): Replace with OnboardingEmailRepository.MarkOpened()
|
||||
now := time.Now().UTC()
|
||||
result := s.db.Model(&models.OnboardingEmail{}).
|
||||
Where("tracking_id = ? AND opened_at IS NULL", trackingID).
|
||||
@@ -84,6 +87,7 @@ func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error {
|
||||
|
||||
// GetEmailHistory gets all onboarding emails for a specific user
|
||||
func (s *OnboardingEmailService) GetEmailHistory(userID uint) ([]models.OnboardingEmail, error) {
|
||||
// TODO(hardening): Replace with OnboardingEmailRepository.FindByUserID()
|
||||
var emails []models.OnboardingEmail
|
||||
if err := s.db.Where("user_id = ?", userID).Order("sent_at DESC").Find(&emails).Error; err != nil {
|
||||
return nil, err
|
||||
@@ -105,11 +109,13 @@ func (s *OnboardingEmailService) GetAllEmailHistory(page, pageSize int) ([]model
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
// TODO(hardening): Replace with OnboardingEmailRepository.CountAll()
|
||||
// Count total
|
||||
if err := s.db.Model(&models.OnboardingEmail{}).Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// TODO(hardening): Replace with OnboardingEmailRepository.FindAllPaginated()
|
||||
// Get paginated results with user info
|
||||
if err := s.db.Preload("User").
|
||||
Order("sent_at DESC").
|
||||
@@ -126,6 +132,7 @@ func (s *OnboardingEmailService) GetAllEmailHistory(page, pageSize int) ([]model
|
||||
func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error) {
|
||||
stats := &OnboardingEmailStats{}
|
||||
|
||||
// TODO(hardening): Replace with OnboardingEmailRepository.GetStats()
|
||||
// No residence email stats
|
||||
var noResTotal, noResOpened int64
|
||||
if err := s.db.Model(&models.OnboardingEmail{}).
|
||||
@@ -159,6 +166,7 @@ func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error)
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// TODO(hardening): Move to internal/dto/responses/onboarding_email.go
|
||||
// OnboardingEmailStats represents statistics about onboarding emails
|
||||
type OnboardingEmailStats struct {
|
||||
NoResidenceTotal int64 `json:"no_residence_total"`
|
||||
@@ -173,6 +181,7 @@ func (s *OnboardingEmailService) UsersNeedingNoResidenceEmail() ([]models.User,
|
||||
|
||||
twoDaysAgo := time.Now().UTC().AddDate(0, 0, -2)
|
||||
|
||||
// TODO(hardening): Replace with OnboardingEmailRepository.FindUsersWithoutResidence()
|
||||
// Find users who:
|
||||
// 1. Are verified
|
||||
// 2. Registered 2+ days ago
|
||||
@@ -201,6 +210,7 @@ func (s *OnboardingEmailService) UsersNeedingNoTasksEmail() ([]models.User, erro
|
||||
|
||||
fiveDaysAgo := time.Now().UTC().AddDate(0, 0, -5)
|
||||
|
||||
// TODO(hardening): Replace with OnboardingEmailRepository.FindUsersWithoutTasks()
|
||||
// Find users who:
|
||||
// 1. Are verified
|
||||
// 2. Have at least one residence
|
||||
@@ -325,6 +335,7 @@ func (s *OnboardingEmailService) sendNoTasksEmail(user models.User) error {
|
||||
// SendOnboardingEmailToUser manually sends an onboarding email to a specific user
|
||||
// This is used by admin to force-send emails regardless of eligibility criteria
|
||||
func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailType models.OnboardingEmailType) error {
|
||||
// TODO(hardening): Replace with UserRepository.FindByID() (inject UserRepository)
|
||||
// Load the user
|
||||
var user models.User
|
||||
if err := s.db.First(&user, userID).Error; err != nil {
|
||||
@@ -362,6 +373,7 @@ func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailTyp
|
||||
// If already sent before, delete the old record first to allow re-recording
|
||||
// This allows admins to "resend" emails while still tracking them
|
||||
if alreadySent {
|
||||
// TODO(hardening): Replace with OnboardingEmailRepository.DeleteByUserAndType()
|
||||
if err := s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{}).Error; err != nil {
|
||||
log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to delete old onboarding email record before resend")
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jung-kurt/gofpdf"
|
||||
"github.com/go-pdf/fpdf"
|
||||
)
|
||||
|
||||
// PDFService handles PDF generation
|
||||
@@ -18,7 +18,7 @@ func NewPDFService() *PDFService {
|
||||
|
||||
// GenerateTasksReportPDF generates a PDF report from task report data
|
||||
func (s *PDFService) GenerateTasksReportPDF(report *TasksReportResponse) ([]byte, error) {
|
||||
pdf := gofpdf.New("P", "mm", "A4", "")
|
||||
pdf := fpdf.New("P", "mm", "A4", "")
|
||||
pdf.SetMargins(15, 15, 15)
|
||||
pdf.AddPage()
|
||||
|
||||
|
||||
@@ -133,14 +133,16 @@ func (s *ResidenceService) GetMyResidences(userID uint, now time.Time) (*respons
|
||||
}
|
||||
}
|
||||
|
||||
// Attach completion summaries (honeycomb grid data)
|
||||
for i := range residenceResponses {
|
||||
summary, err := s.taskRepo.GetCompletionSummary(residenceResponses[i].ID, now, 10)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Uint("residence_id", residenceResponses[i].ID).Msg("Failed to fetch completion summary")
|
||||
continue
|
||||
// P-01: Batch fetch completion summaries in 2 queries total instead of 2*N
|
||||
summaries, err := s.taskRepo.GetBatchCompletionSummaries(residenceIDs, now, 10)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to fetch batch completion summaries")
|
||||
} else {
|
||||
for i := range residenceResponses {
|
||||
if summary, ok := summaries[residenceResponses[i].ID]; ok {
|
||||
residenceResponses[i].CompletionSummary = summary
|
||||
}
|
||||
}
|
||||
residenceResponses[i].CompletionSummary = summary
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -17,7 +18,8 @@ import (
|
||||
|
||||
// StorageService handles file uploads to local filesystem
|
||||
type StorageService struct {
|
||||
cfg *config.StorageConfig
|
||||
cfg *config.StorageConfig
|
||||
allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups
|
||||
}
|
||||
|
||||
// UploadResult contains information about an uploaded file
|
||||
@@ -44,9 +46,18 @@ func NewStorageService(cfg *config.StorageConfig) (*StorageService, error) {
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Str("upload_dir", cfg.UploadDir).Msg("Storage service initialized")
|
||||
// P-12: Parse AllowedTypes once at initialization for O(1) lookups
|
||||
allowedTypes := make(map[string]struct{})
|
||||
for _, t := range strings.Split(cfg.AllowedTypes, ",") {
|
||||
trimmed := strings.TrimSpace(t)
|
||||
if trimmed != "" {
|
||||
allowedTypes[trimmed] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return &StorageService{cfg: cfg}, nil
|
||||
log.Info().Str("upload_dir", cfg.UploadDir).Int("allowed_types", len(allowedTypes)).Msg("Storage service initialized")
|
||||
|
||||
return &StorageService{cfg: cfg, allowedTypes: allowedTypes}, nil
|
||||
}
|
||||
|
||||
// Upload saves a file to the local filesystem
|
||||
@@ -56,17 +67,47 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
|
||||
return nil, fmt.Errorf("file size %d exceeds maximum allowed %d bytes", file.Size, s.cfg.MaxFileSize)
|
||||
}
|
||||
|
||||
// Get MIME type
|
||||
mimeType := file.Header.Get("Content-Type")
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
// Get claimed MIME type from header
|
||||
claimedMimeType := file.Header.Get("Content-Type")
|
||||
if claimedMimeType == "" {
|
||||
claimedMimeType = "application/octet-stream"
|
||||
}
|
||||
|
||||
// Validate MIME type
|
||||
// S-09: Detect actual content type from file bytes to prevent disguised uploads
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// Read the first 512 bytes for content type detection
|
||||
sniffBuf := make([]byte, 512)
|
||||
n, err := src.Read(sniffBuf)
|
||||
if err != nil && n == 0 {
|
||||
return nil, fmt.Errorf("failed to read file for content type detection: %w", err)
|
||||
}
|
||||
detectedMimeType := http.DetectContentType(sniffBuf[:n])
|
||||
|
||||
// Validate that the detected type matches the claimed type (at the category level)
|
||||
// Allow application/octet-stream from detection since DetectContentType may not
|
||||
// recognize all valid types, but the claimed type must still be in our allowed list
|
||||
if detectedMimeType != "application/octet-stream" && !s.mimeTypesCompatible(claimedMimeType, detectedMimeType) {
|
||||
return nil, fmt.Errorf("file content type mismatch: claimed %s but detected %s", claimedMimeType, detectedMimeType)
|
||||
}
|
||||
|
||||
// Use the claimed MIME type (which is more specific) if it's allowed
|
||||
mimeType := claimedMimeType
|
||||
|
||||
// Validate MIME type against allowed list
|
||||
if !s.isAllowedType(mimeType) {
|
||||
return nil, fmt.Errorf("file type %s is not allowed", mimeType)
|
||||
}
|
||||
|
||||
// Seek back to beginning after sniffing
|
||||
if _, err := src.Seek(0, io.SeekStart); err != nil {
|
||||
return nil, fmt.Errorf("failed to seek file: %w", err)
|
||||
}
|
||||
|
||||
// Generate unique filename
|
||||
ext := filepath.Ext(file.Filename)
|
||||
if ext == "" {
|
||||
@@ -83,15 +124,11 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
|
||||
subdir = "completions"
|
||||
}
|
||||
|
||||
// Full path
|
||||
destPath := filepath.Join(s.cfg.UploadDir, subdir, newFilename)
|
||||
|
||||
// Open source file
|
||||
src, err := file.Open()
|
||||
// S-18: Sanitize path to prevent traversal attacks
|
||||
destPath, err := SafeResolvePath(s.cfg.UploadDir, filepath.Join(subdir, newFilename))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
|
||||
return nil, fmt.Errorf("invalid upload path: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// Create destination file
|
||||
dst, err := os.Create(destPath)
|
||||
@@ -131,19 +168,11 @@ func (s *StorageService) Delete(fileURL string) error {
|
||||
// Convert URL to file path
|
||||
relativePath := strings.TrimPrefix(fileURL, s.cfg.BaseURL)
|
||||
relativePath = strings.TrimPrefix(relativePath, "/")
|
||||
fullPath := filepath.Join(s.cfg.UploadDir, relativePath)
|
||||
|
||||
// Security check: ensure path is within upload directory
|
||||
absUploadDir, err := filepath.Abs(s.cfg.UploadDir)
|
||||
// S-18: Use SafeResolvePath to prevent path traversal
|
||||
fullPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve upload directory: %w", err)
|
||||
}
|
||||
absFilePath, err := filepath.Abs(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve file path: %w", err)
|
||||
}
|
||||
if !strings.HasPrefix(absFilePath, absUploadDir+string(filepath.Separator)) && absFilePath != absUploadDir {
|
||||
return fmt.Errorf("invalid file path")
|
||||
return fmt.Errorf("invalid file path: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Remove(fullPath); err != nil {
|
||||
@@ -157,15 +186,23 @@ func (s *StorageService) Delete(fileURL string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// isAllowedType checks if the MIME type is in the allowed list
|
||||
// isAllowedType checks if the MIME type is in the allowed list.
|
||||
// P-12: Uses the pre-parsed allowedTypes map for O(1) lookups instead of
|
||||
// splitting the config string on every call.
|
||||
func (s *StorageService) isAllowedType(mimeType string) bool {
|
||||
allowed := strings.Split(s.cfg.AllowedTypes, ",")
|
||||
for _, t := range allowed {
|
||||
if strings.TrimSpace(t) == mimeType {
|
||||
return true
|
||||
}
|
||||
_, ok := s.allowedTypes[mimeType]
|
||||
return ok
|
||||
}
|
||||
|
||||
// mimeTypesCompatible checks if the claimed and detected MIME types are compatible.
|
||||
// Two MIME types are compatible if they share the same primary type (e.g., both "image/*").
|
||||
func (s *StorageService) mimeTypesCompatible(claimed, detected string) bool {
|
||||
claimedParts := strings.SplitN(claimed, "/", 2)
|
||||
detectedParts := strings.SplitN(detected, "/", 2)
|
||||
if len(claimedParts) < 1 || len(detectedParts) < 1 {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
return claimedParts[0] == detectedParts[0]
|
||||
}
|
||||
|
||||
// getExtensionFromMimeType returns a file extension for common MIME types
|
||||
@@ -191,5 +228,12 @@ func (s *StorageService) GetUploadDir() string {
|
||||
// NewStorageServiceForTest creates a StorageService without creating directories.
|
||||
// This is intended only for unit tests that need a StorageService with a known config.
|
||||
func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService {
|
||||
return &StorageService{cfg: cfg}
|
||||
allowedTypes := make(map[string]struct{})
|
||||
for _, t := range strings.Split(cfg.AllowedTypes, ",") {
|
||||
trimmed := strings.TrimSpace(t)
|
||||
if trimmed != "" {
|
||||
allowedTypes[trimmed] = struct{}{}
|
||||
}
|
||||
}
|
||||
return &StorageService{cfg: cfg, allowedTypes: allowedTypes}
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@ package services
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
portalsession "github.com/stripe/stripe-go/v81/billingportal/session"
|
||||
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
|
||||
@@ -34,7 +34,8 @@ func NewStripeService(
|
||||
subscriptionRepo *repositories.SubscriptionRepository,
|
||||
userRepo *repositories.UserRepository,
|
||||
) *StripeService {
|
||||
key := os.Getenv("STRIPE_SECRET_KEY")
|
||||
// S-21: Use Viper config instead of os.Getenv for consistent configuration management
|
||||
key := viper.GetString("STRIPE_SECRET_KEY")
|
||||
if key == "" {
|
||||
log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work")
|
||||
} else {
|
||||
@@ -42,7 +43,7 @@ func NewStripeService(
|
||||
log.Info().Msg("Stripe API key configured")
|
||||
}
|
||||
|
||||
webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")
|
||||
webhookSecret := viper.GetString("STRIPE_WEBHOOK_SECRET")
|
||||
if webhookSecret == "" {
|
||||
log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail")
|
||||
}
|
||||
|
||||
@@ -202,18 +202,19 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
||||
}
|
||||
|
||||
// getUserUsage calculates current usage for a user.
|
||||
// P-10: Uses CountByOwner for properties count instead of loading all owned residences.
|
||||
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
|
||||
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
|
||||
residences, err := s.residenceRepo.FindOwnedByUser(userID)
|
||||
// P-10: Use CountByOwner for an efficient COUNT query instead of loading all records
|
||||
propertiesCount, err := s.residenceRepo.CountByOwner(userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
propertiesCount := int64(len(residences))
|
||||
|
||||
// Collect residence IDs for batch queries
|
||||
residenceIDs := make([]uint, len(residences))
|
||||
for i, r := range residences {
|
||||
residenceIDs[i] = r.ID
|
||||
// Still need residence IDs for batch counting tasks/contractors/documents
|
||||
residenceIDs, err := s.residenceRepo.FindResidenceIDsByOwner(userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Count tasks, contractors, and documents across all residences with single queries each
|
||||
|
||||
@@ -130,7 +130,7 @@ func (s *TaskService) ListTasks(userID uint, daysThreshold int, now time.Time) (
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
resp := responses.NewKanbanBoardResponseForAll(board)
|
||||
resp := responses.NewKanbanBoardResponseForAll(board, now)
|
||||
// NOTE: Summary statistics are calculated client-side from kanban data
|
||||
return &resp, nil
|
||||
}
|
||||
@@ -157,7 +157,7 @@ func (s *TaskService) GetTasksByResidence(residenceID, userID uint, daysThreshol
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
resp := responses.NewKanbanBoardResponse(board, residenceID)
|
||||
resp := responses.NewKanbanBoardResponse(board, residenceID, now)
|
||||
// NOTE: Summary statistics are calculated client-side from kanban data
|
||||
return &resp, nil
|
||||
}
|
||||
@@ -601,8 +601,8 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
|
||||
task.InProgress = false
|
||||
}
|
||||
|
||||
// P1-5: Wrap completion creation and task update in a transaction.
|
||||
// If either operation fails, both are rolled back to prevent orphaned completions.
|
||||
// P1-5 + B-07: Wrap completion creation, task update, and image creation
|
||||
// in a single transaction for atomicity. If any operation fails, all are rolled back.
|
||||
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
|
||||
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
|
||||
return err
|
||||
@@ -610,6 +610,18 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
|
||||
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
|
||||
return err
|
||||
}
|
||||
// B-07: Create images inside the same transaction as completion
|
||||
for _, imageURL := range req.ImageURLs {
|
||||
if imageURL != "" {
|
||||
img := &models.TaskCompletionImage{
|
||||
CompletionID: completion.ID,
|
||||
ImageURL: imageURL,
|
||||
}
|
||||
if err := tx.Create(img).Error; err != nil {
|
||||
return fmt.Errorf("failed to create completion image: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
@@ -621,19 +633,6 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
|
||||
return nil, apperrors.Internal(txErr)
|
||||
}
|
||||
|
||||
// Create images if provided
|
||||
for _, imageURL := range req.ImageURLs {
|
||||
if imageURL != "" {
|
||||
img := &models.TaskCompletionImage{
|
||||
CompletionID: completion.ID,
|
||||
ImageURL: imageURL,
|
||||
}
|
||||
if err := s.taskRepo.CreateCompletionImage(img); err != nil {
|
||||
log.Error().Err(err).Uint("completion_id", completion.ID).Msg("Failed to create completion image")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reload completion with user info and images
|
||||
completion, err = s.taskRepo.FindCompletionByID(completion.ID)
|
||||
if err != nil {
|
||||
@@ -663,8 +662,10 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QuickComplete creates a minimal task completion (for widget use)
|
||||
// Returns only success/error, no response body
|
||||
// QuickComplete creates a minimal task completion (for widget use).
|
||||
// LE-01: The entire operation (completion creation + task update) is wrapped in a
|
||||
// transaction for atomicity.
|
||||
// Returns only success/error, no response body.
|
||||
func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
|
||||
// Get the task
|
||||
task, err := s.taskRepo.FindByID(taskID)
|
||||
@@ -697,10 +698,6 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
|
||||
CompletedFromColumn: completedFromColumn,
|
||||
}
|
||||
|
||||
if err := s.taskRepo.CreateCompletion(completion); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Update next_due_date and in_progress based on frequency
|
||||
// Determine interval days: Custom frequency uses task.CustomIntervalDays, otherwise use frequency.Days
|
||||
// Note: Frequency is no longer preloaded for performance, so we load it separately if needed
|
||||
@@ -729,7 +726,6 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
|
||||
} else {
|
||||
// Recurring task - calculate next due date from completion date + interval
|
||||
nextDue := completedAt.AddDate(0, 0, *quickIntervalDays)
|
||||
// frequencyName was already set when loading frequency above
|
||||
log.Info().
|
||||
Uint("task_id", task.ID).
|
||||
Str("frequency_name", frequencyName).
|
||||
@@ -742,12 +738,23 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
|
||||
// Reset in_progress to false
|
||||
task.InProgress = false
|
||||
}
|
||||
if err := s.taskRepo.Update(task); err != nil {
|
||||
if errors.Is(err, repositories.ErrVersionConflict) {
|
||||
|
||||
// LE-01: Wrap completion creation and task update in a transaction for atomicity
|
||||
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
|
||||
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
if errors.Is(txErr, repositories.ErrVersionConflict) {
|
||||
return apperrors.Conflict("error.version_conflict")
|
||||
}
|
||||
log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after quick completion")
|
||||
return apperrors.Internal(err) // Return error so caller knows the update failed
|
||||
log.Error().Err(txErr).Uint("task_id", task.ID).Msg("Failed to create completion and update task in QuickComplete")
|
||||
return apperrors.Internal(txErr)
|
||||
}
|
||||
log.Info().Uint("task_id", task.ID).Msg("QuickComplete: Task updated successfully")
|
||||
|
||||
@@ -813,8 +820,16 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio
|
||||
// Send email notification (to everyone INCLUDING the person who completed it)
|
||||
// Check user's email notification preferences first
|
||||
if s.emailService != nil && user.Email != "" && s.notificationService != nil {
|
||||
prefs, err := s.notificationService.GetPreferences(user.ID)
|
||||
if err != nil || (prefs != nil && prefs.EmailTaskCompleted) {
|
||||
prefs, prefsErr := s.notificationService.GetPreferences(user.ID)
|
||||
// LE-06: Log fail-open behavior when preferences cannot be loaded
|
||||
if prefsErr != nil {
|
||||
log.Warn().
|
||||
Err(prefsErr).
|
||||
Uint("user_id", user.ID).
|
||||
Uint("task_id", task.ID).
|
||||
Msg("Failed to load notification preferences, falling back to sending email (fail-open)")
|
||||
}
|
||||
if prefsErr != nil || (prefs != nil && prefs.EmailTaskCompleted) {
|
||||
// Send email if we couldn't get prefs (fail-open) or if email notifications are enabled
|
||||
if err := s.emailService.SendTaskCompletedEmail(
|
||||
user.Email,
|
||||
|
||||
@@ -32,7 +32,8 @@ func (s *UserService) ListUsersInSharedResidences(userID uint) ([]responses.User
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
var result []responses.UserSummary
|
||||
// F-23: Initialize as empty slice so JSON serialization produces [] instead of null
|
||||
result := make([]responses.UserSummary, 0, len(users))
|
||||
for _, u := range users {
|
||||
result = append(result, responses.UserSummary{
|
||||
ID: u.ID,
|
||||
@@ -72,7 +73,8 @@ func (s *UserService) ListProfilesInSharedResidences(userID uint) ([]responses.U
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
var result []responses.UserProfileSummary
|
||||
// F-23: Initialize as empty slice so JSON serialization produces [] instead of null
|
||||
result := make([]responses.UserProfileSummary, 0, len(profiles))
|
||||
for _, p := range profiles {
|
||||
result = append(result, responses.UserProfileSummary{
|
||||
ID: p.ID,
|
||||
|
||||
Reference in New Issue
Block a user