From 42a5533a56089fb03826b8b5f36d10a8e685df21 Mon Sep 17 00:00:00 2001 From: Trey t Date: Wed, 18 Mar 2026 23:14:13 -0500 Subject: [PATCH] Fix 113 hardening issues across entire Go backend Security: - Replace all binding: tags with validate: + c.Validate() in admin handlers - Add rate limiting to auth endpoints (login, register, password reset) - Add security headers (HSTS, XSS protection, nosniff, frame options) - Wire Google Pub/Sub token verification into webhook handler - Replace ParseUnverified with proper OIDC/JWKS key verification - Verify inner Apple JWS signatures in webhook handler - Add io.LimitReader (1MB) to all webhook body reads - Add ownership verification to file deletion - Move hardcoded admin credentials to env vars - Add uniqueIndex to User.Email - Hide ConfirmationCode from JSON serialization - Mask confirmation codes in admin responses - Use http.DetectContentType for upload validation - Fix path traversal in storage service - Replace os.Getenv with Viper in stripe service - Sanitize Redis URLs before logging - Separate DEBUG_FIXED_CODES from DEBUG flag - Reject weak SECRET_KEY in production - Add host check on /_next/* proxy routes - Use explicit localhost CORS origins in debug mode - Replace err.Error() with generic messages in all admin error responses Critical fixes: - Rewrite FCM to HTTP v1 API with OAuth 2.0 service account auth - Fix user_customuser -> auth_user table names in raw SQL - Fix dashboard verified query to use UserProfile model - Add escapeLikeWildcards() to prevent SQL wildcard injection Bug fixes: - Add bounds checks for days/expiring_soon query params (1-3650) - Add receipt_data/transaction_id empty-check to RestoreSubscription - Change Active bool -> *bool in device handler - Check all unchecked GORM/FindByIDWithProfile errors - Add validation for notification hour fields (0-23) - Add max=10000 validation on task description updates Transactions & data integrity: - Wrap registration flow in transaction - Wrap QuickComplete in transaction - Move image creation inside completion transaction - Wrap SetSpecialties in transaction - Wrap GetOrCreateToken in transaction - Wrap completion+image deletion in transaction Performance: - Batch completion summaries (2 queries vs 2N) - Reuse single http.Client in IAP validation - Cache dashboard counts (30s TTL) - Batch COUNT queries in admin user list - Add Limit(500) to document queries - Add reminder_stage+due_date filters to reminder queries - Parse AllowedTypes once at init - In-memory user cache in auth middleware (30s TTL) - Timezone change detection cache - Optimize P95 with per-endpoint sorted buffers - Replace crypto/md5 with hash/fnv for ETags Code quality: - Add sync.Once to all monitoring Stop()/Close() methods - Replace 8 fmt.Printf with zerolog in auth service - Log previously discarded errors - Standardize delete response shapes - Route hardcoded English through i18n - Remove FileURL from DocumentResponse (keep MediaURL only) - Thread user timezone through kanban board responses - Initialize empty slices to prevent null JSON - Extract shared field map for task Update/UpdateTx - Delete unused SoftDeleteModel, min(), formatCron, legacy handlers Worker & jobs: - Wire Asynq email infrastructure into worker - Register HandleReminderLogCleanup with daily 3AM cron - Use per-user timezone in HandleSmartReminder - Replace direct DB queries with repository calls - Delete legacy reminder handlers (~200 lines) - Delete unused task type constants Dependencies: - Replace archived jung-kurt/gofpdf with go-pdf/fpdf - Replace unmaintained gomail.v2 with wneessen/go-mail - Add TODO for Echo jwt v3 transitive dep removal Test infrastructure: - Fix MakeRequest/SeedLookupData error handling - Replace os.Exit(0) with t.Skip() in scope/consistency tests - Add 11 new FCM v1 tests Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/api/main.go | 2 +- cmd/worker/main.go | 26 +- go.mod | 13 +- go.sum | 18 +- internal/admin/dto/requests.go | 29 +- internal/admin/handlers/admin_user_handler.go | 32 +- .../handlers/apple_social_auth_handler.go | 11 +- internal/admin/handlers/auth_handler.go | 14 +- internal/admin/handlers/auth_token_handler.go | 4 +- internal/admin/handlers/completion_handler.go | 6 +- .../handlers/completion_image_handler.go | 51 ++- .../handlers/confirmation_code_handler.go | 17 +- internal/admin/handlers/contractor_handler.go | 8 +- internal/admin/handlers/dashboard_handler.go | 30 +- internal/admin/handlers/device_handler.go | 31 +- internal/admin/handlers/document_handler.go | 15 +- .../admin/handlers/document_image_handler.go | 20 +- .../admin/handlers/feature_benefit_handler.go | 15 +- .../admin/handlers/limitations_handler.go | 57 ++- internal/admin/handlers/lookup_handler.go | 70 +++- .../admin/handlers/notification_handler.go | 31 +- .../handlers/notification_prefs_handler.go | 4 +- internal/admin/handlers/onboarding_handler.go | 11 +- .../handlers/password_reset_code_handler.go | 4 +- internal/admin/handlers/promotion_handler.go | 19 +- internal/admin/handlers/residence_handler.go | 11 +- internal/admin/handlers/settings_handler.go | 84 ++--- internal/admin/handlers/share_code_handler.go | 6 +- .../admin/handlers/subscription_handler.go | 8 +- internal/admin/handlers/task_handler.go | 16 +- .../admin/handlers/task_template_handler.go | 25 +- internal/admin/handlers/user_handler.go | 97 ++++- .../admin/handlers/user_profile_handler.go | 11 +- internal/admin/routes.go | 19 +- internal/config/config.go | 309 +++++++++------- internal/database/database.go | 72 ++-- internal/dto/requests/task.go | 2 +- internal/dto/responses/document.go | 13 +- internal/dto/responses/task.go | 16 +- internal/handlers/contractor_handler.go | 4 +- internal/handlers/document_handler.go | 16 +- internal/handlers/notification_handler.go | 20 +- internal/handlers/subscription_handler.go | 14 +- .../handlers/subscription_webhook_handler.go | 194 +++++++++- internal/handlers/task_handler.go | 20 +- internal/handlers/upload_handler.go | 58 ++- internal/handlers/upload_handler_test.go | 21 +- internal/i18n/translations/en.json | 5 + internal/middleware/auth.go | 43 ++- internal/middleware/host_check.go | 40 ++ internal/middleware/rate_limit.go | 68 ++++ internal/middleware/timezone.go | 46 ++- internal/middleware/user_cache.go | 113 ++++++ internal/models/base.go | 6 - internal/models/notification.go | 8 +- internal/models/task.go | 82 ----- internal/models/task_test.go | 170 --------- internal/models/user.go | 4 +- internal/monitoring/collector.go | 8 +- internal/monitoring/handler.go | 24 +- internal/monitoring/middleware.go | 99 +++-- internal/monitoring/service.go | 24 +- internal/monitoring/writer.go | 19 +- internal/push/apns.go | 13 +- internal/push/client.go | 10 +- internal/push/fcm.go | 340 ++++++++++++----- internal/push/fcm_test.go | 344 +++++++++++------- internal/repositories/contractor_repo.go | 40 +- internal/repositories/document_repo.go | 2 +- internal/repositories/reminder_repo.go | 29 +- internal/repositories/residence_repo.go | 14 + internal/repositories/subscription_repo.go | 17 +- internal/repositories/task_repo.go | 212 ++++++++--- internal/repositories/user_repo.go | 44 ++- internal/router/router.go | 47 ++- internal/services/auth_service.go | 141 ++++--- internal/services/cache_service.go | 83 +++-- internal/services/email_service.go | 104 +++--- internal/services/file_ownership_service.go | 66 ++++ internal/services/iap_validation.go | 22 +- internal/services/notification_service.go | 37 +- internal/services/onboarding_email_service.go | 12 + internal/services/pdf_service.go | 4 +- internal/services/residence_service.go | 16 +- internal/services/storage_service.go | 112 ++++-- internal/services/stripe_service.go | 7 +- internal/services/subscription_service.go | 13 +- internal/services/task_service.go | 75 ++-- internal/services/user_service.go | 6 +- internal/task/consistency_test.go | 23 +- internal/task/scopes/scopes_test.go | 22 +- internal/testutil/testutil.go | 80 +++- internal/worker/jobs/handler.go | 286 +++------------ internal/worker/scheduler.go | 110 +----- pkg/utils/logger.go | 11 +- 95 files changed, 2892 insertions(+), 1783 deletions(-) create mode 100644 internal/middleware/host_check.go create mode 100644 internal/middleware/rate_limit.go create mode 100644 internal/middleware/user_cache.go create mode 100644 internal/services/file_ownership_service.go diff --git a/cmd/api/main.go b/cmd/api/main.go index e04c632..3cb0dd6 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -47,7 +47,7 @@ func main() { Int("db_port", cfg.Database.Port). Str("db_name", cfg.Database.Database). Str("db_user", cfg.Database.User). - Str("redis_url", cfg.Redis.URL). + Str("redis_url", config.MaskURLCredentials(cfg.Redis.URL)). Msg("Starting HoneyDue API server") // Connect to database (retry with backoff) diff --git a/cmd/worker/main.go b/cmd/worker/main.go index 56a881d..6cefc00 100644 --- a/cmd/worker/main.go +++ b/cmd/worker/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "fmt" "os" "os/signal" "syscall" @@ -115,12 +114,6 @@ func main() { } } - // Check worker kill switch - if !cfg.Features.WorkerEnabled { - log.Warn().Msg("Worker disabled by FEATURE_WORKER_ENABLED=false, exiting") - return - } - // Create Asynq server srv := asynq.NewServer( redisOpt, @@ -151,6 +144,13 @@ func main() { mux.HandleFunc(jobs.TypeSendEmail, jobHandler.HandleSendEmail) mux.HandleFunc(jobs.TypeSendPush, jobHandler.HandleSendPush) mux.HandleFunc(jobs.TypeOnboardingEmails, jobHandler.HandleOnboardingEmails) + mux.HandleFunc(jobs.TypeReminderLogCleanup, jobHandler.HandleReminderLogCleanup) + + // Register email job handlers (welcome, verification, password reset, password changed) + if emailService != nil { + emailJobHandler := jobs.NewEmailJobHandler(emailService) + emailJobHandler.RegisterHandlers(mux) + } // Start scheduler for periodic tasks scheduler := asynq.NewScheduler(redisOpt, nil) @@ -177,6 +177,13 @@ func main() { } log.Info().Str("cron", "0 10 * * *").Msg("Registered onboarding emails job (runs daily at 10:00 AM UTC)") + // Schedule reminder log cleanup (runs daily at 3:00 AM UTC) + // Removes reminder logs older than 90 days to prevent table bloat + if _, err := scheduler.Register("0 3 * * *", asynq.NewTask(jobs.TypeReminderLogCleanup, nil)); err != nil { + log.Fatal().Err(err).Msg("Failed to register reminder log cleanup job") + } + log.Info().Str("cron", "0 3 * * *").Msg("Registered reminder log cleanup job (runs daily at 3:00 AM UTC)") + // Handle graceful shutdown quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) @@ -205,8 +212,3 @@ func main() { log.Info().Msg("Worker stopped") } - -// formatCron creates a cron expression for a specific hour (runs at minute 0) -func formatCron(hour int) string { - return fmt.Sprintf("0 %02d * * *", hour) -} diff --git a/go.mod b/go.mod index 208a84a..a71c349 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,12 @@ module github.com/treytartt/honeydue-api go 1.24.0 require ( + github.com/go-pdf/fpdf v0.9.0 github.com/go-playground/validator/v10 v10.23.0 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/hibiken/asynq v0.25.1 - github.com/jung-kurt/gofpdf v1.16.2 github.com/labstack/echo/v4 v4.11.4 github.com/nicksnyder/go-i18n/v2 v2.6.0 github.com/redis/go-redis/v9 v9.17.1 @@ -18,11 +18,14 @@ require ( github.com/sideshow/apns2 v0.25.0 github.com/spf13/viper v1.20.1 github.com/stretchr/testify v1.11.1 + github.com/stripe/stripe-go/v81 v81.4.0 + github.com/wneessen/go-mail v0.7.2 golang.org/x/crypto v0.45.0 golang.org/x/oauth2 v0.34.0 golang.org/x/text v0.31.0 + golang.org/x/time v0.14.0 google.golang.org/api v0.257.0 - gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df + gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.6.0 gorm.io/driver/sqlite v1.6.0 gorm.io/gorm v1.31.1 @@ -44,7 +47,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect - github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect; TODO(S-19): Pulled by echo/v4 middleware — upgrade Echo to v4.12+ which removes built-in JWT middleware (uses echo-jwt/v4 with jwt/v5 instead), eliminating this vulnerable transitive dep github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect @@ -71,7 +74,6 @@ require ( github.com/spf13/afero v1.14.0 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect - github.com/stripe/stripe-go/v81 v81.4.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect @@ -86,10 +88,7 @@ require ( golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.18.0 // indirect golang.org/x/sys v0.38.0 // indirect - golang.org/x/time v0.14.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect google.golang.org/grpc v1.77.0 // indirect google.golang.org/protobuf v1.36.10 // indirect - gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e11a261..b3bb2e0 100644 --- a/go.sum +++ b/go.sum @@ -8,7 +8,6 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20201120081800-1786d5ef83d4/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE= -github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -36,6 +35,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-pdf/fpdf v0.9.0 h1:PPvSaUuo1iMi9KkaAn90NuKi+P4gwMedWPHhj8YlJQw= +github.com/go-pdf/fpdf v0.9.0/go.mod h1:oO8N111TkmKb9D7VvWGLvLJlaZUQVPM+6V42pp3iV4Y= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -83,9 +84,6 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= -github.com/jung-kurt/gofpdf v1.16.2 h1:jgbatWHfRlPYiK85qgevsZTHviWXKwB1TTiKdz5PtRc= -github.com/jung-kurt/gofpdf v1.16.2/go.mod h1:1hl7y57EsiPAkLbOwzpzqgx1A30nQCk/YmFV8S2vmK0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -110,8 +108,6 @@ github.com/nicksnyder/go-i18n/v2 v2.6.0 h1:C/m2NNWNiTB6SK4Ao8df5EWm3JETSTIGNXBpM github.com/nicksnyder/go-i18n/v2 v2.6.0/go.mod h1:88sRqr0C6OPyJn0/KRNaEz1uWorjxIKP7rUUcvycecE= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= -github.com/phpdave11/gofpdi v1.0.7/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -126,7 +122,6 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7 github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= -github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w= github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k= github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk= github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= @@ -150,7 +145,6 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4= github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -168,6 +162,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/wneessen/go-mail v0.7.2 h1:xxPnhZ6IZLSgxShebmZ6DPKh1b6OJcoHfzy7UjOkzS8= +github.com/wneessen/go-mail v0.7.2/go.mod h1:+TkW6QP3EVkgTEqHtVmnAE/1MRhmzb8Y9/W3pweuS+k= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= @@ -189,7 +185,6 @@ go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20170512130425-ab89591268e0/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= @@ -213,7 +208,6 @@ golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= @@ -237,13 +231,9 @@ google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHh google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= -gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= -gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= -gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/admin/dto/requests.go b/internal/admin/dto/requests.go index fde6477..c320d18 100644 --- a/internal/admin/dto/requests.go +++ b/internal/admin/dto/requests.go @@ -1,6 +1,9 @@ package dto -import "github.com/treytartt/honeydue-api/internal/middleware" +import ( + "github.com/shopspring/decimal" + "github.com/treytartt/honeydue-api/internal/middleware" +) // PaginationParams holds pagination query parameters type PaginationParams struct { @@ -115,9 +118,9 @@ type UpdateResidenceRequest struct { YearBuilt *int `json:"year_built"` Description *string `json:"description"` PurchaseDate *string `json:"purchase_date"` - PurchasePrice *float64 `json:"purchase_price"` - IsActive *bool `json:"is_active"` - IsPrimary *bool `json:"is_primary"` + PurchasePrice *decimal.Decimal `json:"purchase_price"` + IsActive *bool `json:"is_active"` + IsPrimary *bool `json:"is_primary"` } // TaskFilters holds task-specific filter parameters @@ -144,8 +147,8 @@ type UpdateTaskRequest struct { InProgress *bool `json:"in_progress"` DueDate *string `json:"due_date"` NextDueDate *string `json:"next_due_date"` - EstimatedCost *float64 `json:"estimated_cost"` - ActualCost *float64 `json:"actual_cost"` + EstimatedCost *decimal.Decimal `json:"estimated_cost"` + ActualCost *decimal.Decimal `json:"actual_cost"` ContractorID *uint `json:"contractor_id"` ParentTaskID *uint `json:"parent_task_id"` IsCancelled *bool `json:"is_cancelled"` @@ -201,8 +204,8 @@ type UpdateDocumentRequest struct { MimeType *string `json:"mime_type" validate:"omitempty,max=100"` PurchaseDate *string `json:"purchase_date"` ExpiryDate *string `json:"expiry_date"` - PurchasePrice *float64 `json:"purchase_price"` - Vendor *string `json:"vendor" validate:"omitempty,max=200"` + PurchasePrice *decimal.Decimal `json:"purchase_price"` + Vendor *string `json:"vendor" validate:"omitempty,max=200"` SerialNumber *string `json:"serial_number" validate:"omitempty,max=100"` ModelNumber *string `json:"model_number" validate:"omitempty,max=100"` Provider *string `json:"provider" validate:"omitempty,max=200"` @@ -292,9 +295,9 @@ type CreateTaskRequest struct { FrequencyID *uint `json:"frequency_id"` InProgress bool `json:"in_progress"` AssignedToID *uint `json:"assigned_to_id"` - DueDate *string `json:"due_date"` - EstimatedCost *float64 `json:"estimated_cost"` - ContractorID *uint `json:"contractor_id"` + DueDate *string `json:"due_date"` + EstimatedCost *decimal.Decimal `json:"estimated_cost"` + ContractorID *uint `json:"contractor_id"` } // CreateContractorRequest for creating a new contractor @@ -328,8 +331,8 @@ type CreateDocumentRequest struct { MimeType string `json:"mime_type" validate:"max=100"` PurchaseDate *string `json:"purchase_date"` ExpiryDate *string `json:"expiry_date"` - PurchasePrice *float64 `json:"purchase_price"` - Vendor string `json:"vendor" validate:"max=200"` + PurchasePrice *decimal.Decimal `json:"purchase_price"` + Vendor string `json:"vendor" validate:"max=200"` SerialNumber string `json:"serial_number" validate:"max=100"` ModelNumber string `json:"model_number" validate:"max=100"` TaskID *uint `json:"task_id"` diff --git a/internal/admin/handlers/admin_user_handler.go b/internal/admin/handlers/admin_user_handler.go index ac4ae0e..91e0421 100644 --- a/internal/admin/handlers/admin_user_handler.go +++ b/internal/admin/handlers/admin_user_handler.go @@ -33,21 +33,21 @@ type AdminUserFilters struct { // CreateAdminUserRequest for creating a new admin user type CreateAdminUserRequest struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required,min=8"` - FirstName string `json:"first_name" binding:"max=100"` - LastName string `json:"last_name" binding:"max=100"` - Role string `json:"role" binding:"omitempty,oneof=admin super_admin"` + Email string `json:"email" validate:"required,email"` + Password string `json:"password" validate:"required,min=8"` + FirstName string `json:"first_name" validate:"max=100"` + LastName string `json:"last_name" validate:"max=100"` + Role string `json:"role" validate:"omitempty,oneof=admin super_admin"` IsActive *bool `json:"is_active"` } // UpdateAdminUserRequest for updating an admin user type UpdateAdminUserRequest struct { - Email *string `json:"email" binding:"omitempty,email"` - Password *string `json:"password" binding:"omitempty,min=8"` - FirstName *string `json:"first_name" binding:"omitempty,max=100"` - LastName *string `json:"last_name" binding:"omitempty,max=100"` - Role *string `json:"role" binding:"omitempty,oneof=admin super_admin"` + Email *string `json:"email" validate:"omitempty,email"` + Password *string `json:"password" validate:"omitempty,min=8"` + FirstName *string `json:"first_name" validate:"omitempty,max=100"` + LastName *string `json:"last_name" validate:"omitempty,max=100"` + Role *string `json:"role" validate:"omitempty,oneof=admin super_admin"` IsActive *bool `json:"is_active"` } @@ -55,7 +55,7 @@ type UpdateAdminUserRequest struct { func (h *AdminUserManagementHandler) List(c echo.Context) error { var filters AdminUserFilters if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var adminUsers []models.AdminUser @@ -134,7 +134,10 @@ func (h *AdminUserManagementHandler) Create(c echo.Context) error { var req CreateAdminUserRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } // Check if email already exists @@ -199,7 +202,10 @@ func (h *AdminUserManagementHandler) Update(c echo.Context) error { var req UpdateAdminUserRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } if req.Email != nil { diff --git a/internal/admin/handlers/apple_social_auth_handler.go b/internal/admin/handlers/apple_social_auth_handler.go index 462f8c8..f98570d 100644 --- a/internal/admin/handlers/apple_social_auth_handler.go +++ b/internal/admin/handlers/apple_social_auth_handler.go @@ -44,7 +44,7 @@ type UpdateAppleSocialAuthRequest struct { func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var entries []models.AppleSocialAuth @@ -139,7 +139,7 @@ func (h *AdminAppleSocialAuthHandler) Update(c echo.Context) error { var req UpdateAppleSocialAuthRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } if req.Email != nil { @@ -183,14 +183,15 @@ func (h *AdminAppleSocialAuthHandler) Delete(c echo.Context) error { func (h *AdminAppleSocialAuthHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } - if err := h.db.Where("id IN ?", req.IDs).Delete(&models.AppleSocialAuth{}).Error; err != nil { + result := h.db.Where("id IN ?", req.IDs).Delete(&models.AppleSocialAuth{}) + if result.Error != nil { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth entries"}) } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entries deleted successfully", "count": len(req.IDs)}) + return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entries deleted successfully", "count": result.RowsAffected}) } // toResponse converts an AppleSocialAuth model to AppleSocialAuthResponse diff --git a/internal/admin/handlers/auth_handler.go b/internal/admin/handlers/auth_handler.go index 5e18962..957a948 100644 --- a/internal/admin/handlers/auth_handler.go +++ b/internal/admin/handlers/auth_handler.go @@ -27,8 +27,8 @@ func NewAdminAuthHandler(adminRepo *repositories.AdminRepository, cfg *config.Co // LoginRequest represents the admin login request type LoginRequest struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required"` + Email string `json:"email" validate:"required,email"` + Password string `json:"password" validate:"required"` } // LoginResponse represents the admin login response @@ -71,7 +71,10 @@ func NewAdminUserResponse(admin *models.AdminUser) AdminUserResponse { func (h *AdminAuthHandler) Login(c echo.Context) error { var req LoginRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request: " + err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } // Find admin by email @@ -100,7 +103,10 @@ func (h *AdminAuthHandler) Login(c echo.Context) error { _ = h.adminRepo.UpdateLastLogin(admin.ID) // Refresh admin data after updating last login - admin, _ = h.adminRepo.FindByID(admin.ID) + admin, err = h.adminRepo.FindByID(admin.ID) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to refresh admin data"}) + } return c.JSON(http.StatusOK, LoginResponse{ Token: token, diff --git a/internal/admin/handlers/auth_token_handler.go b/internal/admin/handlers/auth_token_handler.go index 0c938fe..1f056f1 100644 --- a/internal/admin/handlers/auth_token_handler.go +++ b/internal/admin/handlers/auth_token_handler.go @@ -34,7 +34,7 @@ type AuthTokenResponse struct { func (h *AdminAuthTokenHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var tokens []models.AuthToken @@ -132,7 +132,7 @@ func (h *AdminAuthTokenHandler) Delete(c echo.Context) error { func (h *AdminAuthTokenHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } result := h.db.Where("user_id IN ?", req.IDs).Delete(&models.AuthToken{}) diff --git a/internal/admin/handlers/completion_handler.go b/internal/admin/handlers/completion_handler.go index 768a45f..30d87a5 100644 --- a/internal/admin/handlers/completion_handler.go +++ b/internal/admin/handlers/completion_handler.go @@ -58,7 +58,7 @@ type CompletionFilters struct { func (h *AdminCompletionHandler) List(c echo.Context) error { var filters CompletionFilters if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var completions []models.TaskCompletion @@ -167,7 +167,7 @@ func (h *AdminCompletionHandler) Delete(c echo.Context) error { func (h *AdminCompletionHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } result := h.db.Where("id IN ?", req.IDs).Delete(&models.TaskCompletion{}) @@ -201,7 +201,7 @@ func (h *AdminCompletionHandler) Update(c echo.Context) error { var req UpdateCompletionRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } if req.Notes != nil { diff --git a/internal/admin/handlers/completion_image_handler.go b/internal/admin/handlers/completion_image_handler.go index 5a05f9c..7feb42e 100644 --- a/internal/admin/handlers/completion_image_handler.go +++ b/internal/admin/handlers/completion_image_handler.go @@ -35,8 +35,8 @@ type AdminCompletionImageResponse struct { // CreateCompletionImageRequest represents the request to create a completion image type CreateCompletionImageRequest struct { - CompletionID uint `json:"completion_id" binding:"required"` - ImageURL string `json:"image_url" binding:"required"` + CompletionID uint `json:"completion_id" validate:"required"` + ImageURL string `json:"image_url" validate:"required"` Caption string `json:"caption"` } @@ -50,7 +50,7 @@ type UpdateCompletionImageRequest struct { func (h *AdminCompletionImageHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Optional completion_id filter @@ -91,10 +91,39 @@ func (h *AdminCompletionImageHandler) List(c echo.Context) error { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch completion images"}) } + // Batch-load completion+task info to avoid N+1 queries + completionIDs := make([]uint, 0, len(images)) + for _, img := range images { + completionIDs = append(completionIDs, img.CompletionID) + } + + completionMap := make(map[uint]models.TaskCompletion) + if len(completionIDs) > 0 { + var completions []models.TaskCompletion + h.db.Preload("Task").Where("id IN ?", completionIDs).Find(&completions) + for _, c := range completions { + completionMap[c.ID] = c + } + } + // Build response with task info responses := make([]AdminCompletionImageResponse, len(images)) for i, image := range images { - responses[i] = h.toResponse(&image) + response := AdminCompletionImageResponse{ + ID: image.ID, + CompletionID: image.CompletionID, + ImageURL: image.ImageURL, + Caption: image.Caption, + CreatedAt: image.CreatedAt.Format("2006-01-02T15:04:05Z"), + UpdatedAt: image.UpdatedAt.Format("2006-01-02T15:04:05Z"), + } + if comp, ok := completionMap[image.CompletionID]; ok { + response.TaskID = comp.TaskID + if comp.Task.ID != 0 { + response.TaskTitle = comp.Task.Title + } + } + responses[i] = response } return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage())) @@ -122,7 +151,10 @@ func (h *AdminCompletionImageHandler) Get(c echo.Context) error { func (h *AdminCompletionImageHandler) Create(c echo.Context) error { var req CreateCompletionImageRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } // Verify completion exists @@ -164,7 +196,7 @@ func (h *AdminCompletionImageHandler) Update(c echo.Context) error { var req UpdateCompletionImageRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } if req.ImageURL != nil { @@ -207,14 +239,15 @@ func (h *AdminCompletionImageHandler) Delete(c echo.Context) error { func (h *AdminCompletionImageHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } - if err := h.db.Where("id IN ?", req.IDs).Delete(&models.TaskCompletionImage{}).Error; err != nil { + result := h.db.Where("id IN ?", req.IDs).Delete(&models.TaskCompletionImage{}) + if result.Error != nil { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete completion images"}) } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Completion images deleted successfully", "count": len(req.IDs)}) + return c.JSON(http.StatusOK, map[string]interface{}{"message": "Completion images deleted successfully", "count": result.RowsAffected}) } // toResponse converts a TaskCompletionImage model to AdminCompletionImageResponse diff --git a/internal/admin/handlers/confirmation_code_handler.go b/internal/admin/handlers/confirmation_code_handler.go index e09e0ea..758cca5 100644 --- a/internal/admin/handlers/confirmation_code_handler.go +++ b/internal/admin/handlers/confirmation_code_handler.go @@ -3,6 +3,7 @@ package handlers import ( "net/http" "strconv" + "strings" "github.com/labstack/echo/v4" "gorm.io/gorm" @@ -11,6 +12,14 @@ import ( "github.com/treytartt/honeydue-api/internal/models" ) +// maskCode masks a confirmation code, showing only the last 4 characters. +func maskCode(code string) string { + if len(code) <= 4 { + return strings.Repeat("*", len(code)) + } + return strings.Repeat("*", len(code)-4) + code[len(code)-4:] +} + // AdminConfirmationCodeHandler handles admin confirmation code management endpoints type AdminConfirmationCodeHandler struct { db *gorm.DB @@ -37,7 +46,7 @@ type ConfirmationCodeResponse struct { func (h *AdminConfirmationCodeHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var codes []models.ConfirmationCode @@ -79,7 +88,7 @@ func (h *AdminConfirmationCodeHandler) List(c echo.Context) error { UserID: code.UserID, Username: code.User.Username, Email: code.User.Email, - Code: code.Code, + Code: maskCode(code.Code), ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"), IsUsed: code.IsUsed, CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"), @@ -109,7 +118,7 @@ func (h *AdminConfirmationCodeHandler) Get(c echo.Context) error { UserID: code.UserID, Username: code.User.Username, Email: code.User.Email, - Code: code.Code, + Code: maskCode(code.Code), ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"), IsUsed: code.IsUsed, CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"), @@ -141,7 +150,7 @@ func (h *AdminConfirmationCodeHandler) Delete(c echo.Context) error { func (h *AdminConfirmationCodeHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } result := h.db.Where("id IN ?", req.IDs).Delete(&models.ConfirmationCode{}) diff --git a/internal/admin/handlers/contractor_handler.go b/internal/admin/handlers/contractor_handler.go index 5c7f6fe..e7e8d33 100644 --- a/internal/admin/handlers/contractor_handler.go +++ b/internal/admin/handlers/contractor_handler.go @@ -25,7 +25,7 @@ func NewAdminContractorHandler(db *gorm.DB) *AdminContractorHandler { func (h *AdminContractorHandler) List(c echo.Context) error { var filters dto.ContractorFilters if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var contractors []models.Contractor @@ -130,7 +130,7 @@ func (h *AdminContractorHandler) Update(c echo.Context) error { var req dto.UpdateContractorRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Verify residence if changing @@ -213,7 +213,7 @@ func (h *AdminContractorHandler) Update(c echo.Context) error { func (h *AdminContractorHandler) Create(c echo.Context) error { var req dto.CreateContractorRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Verify residence exists if provided @@ -290,7 +290,7 @@ func (h *AdminContractorHandler) Delete(c echo.Context) error { func (h *AdminContractorHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Soft delete - deactivate all diff --git a/internal/admin/handlers/dashboard_handler.go b/internal/admin/handlers/dashboard_handler.go index 8cc16f8..fcb5e67 100644 --- a/internal/admin/handlers/dashboard_handler.go +++ b/internal/admin/handlers/dashboard_handler.go @@ -2,6 +2,7 @@ package handlers import ( "net/http" + "sync" "time" "github.com/labstack/echo/v4" @@ -11,6 +12,18 @@ import ( "github.com/treytartt/honeydue-api/internal/task/scopes" ) +// dashboardCache provides short-lived in-memory caching for dashboard stats +// to avoid expensive COUNT queries on every request. +type dashboardCache struct { + mu sync.RWMutex + stats *DashboardStats + expiresAt time.Time +} + +var dashCache = &dashboardCache{} + +const dashboardCacheTTL = 30 * time.Second + // AdminDashboardHandler handles admin dashboard endpoints type AdminDashboardHandler struct { db *gorm.DB @@ -94,6 +107,15 @@ type SubscriptionStats struct { // GetStats handles GET /api/admin/dashboard/stats func (h *AdminDashboardHandler) GetStats(c echo.Context) error { + // Check cache first + dashCache.mu.RLock() + if dashCache.stats != nil && time.Now().Before(dashCache.expiresAt) { + cached := *dashCache.stats + dashCache.mu.RUnlock() + return c.JSON(http.StatusOK, cached) + } + dashCache.mu.RUnlock() + stats := DashboardStats{} now := time.Now() thirtyDaysAgo := now.AddDate(0, 0, -30) @@ -101,7 +123,7 @@ func (h *AdminDashboardHandler) GetStats(c echo.Context) error { // User stats h.db.Model(&models.User{}).Count(&stats.Users.Total) h.db.Model(&models.User{}).Where("is_active = ?", true).Count(&stats.Users.Active) - h.db.Model(&models.User{}).Where("verified = ?", true).Count(&stats.Users.Verified) + h.db.Model(&models.UserProfile{}).Where("verified = ?", true).Count(&stats.Users.Verified) h.db.Model(&models.User{}).Where("date_joined >= ?", thirtyDaysAgo).Count(&stats.Users.New30d) // Residence stats @@ -164,5 +186,11 @@ func (h *AdminDashboardHandler) GetStats(c echo.Context) error { h.db.Model(&models.UserSubscription{}).Where("tier = ?", "premium").Count(&stats.Subscriptions.Premium) h.db.Model(&models.UserSubscription{}).Where("tier = ?", "pro").Count(&stats.Subscriptions.Pro) + // Cache the result + dashCache.mu.Lock() + dashCache.stats = &stats + dashCache.expiresAt = time.Now().Add(dashboardCacheTTL) + dashCache.mu.Unlock() + return c.JSON(http.StatusOK, stats) } diff --git a/internal/admin/handlers/device_handler.go b/internal/admin/handlers/device_handler.go index 160574a..5a8eb2f 100644 --- a/internal/admin/handlers/device_handler.go +++ b/internal/admin/handlers/device_handler.go @@ -50,7 +50,7 @@ type GCMDeviceResponse struct { func (h *AdminDeviceHandler) ListAPNS(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var devices []models.APNSDevice @@ -106,7 +106,7 @@ func (h *AdminDeviceHandler) ListAPNS(c echo.Context) error { func (h *AdminDeviceHandler) ListGCM(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var devices []models.GCMDevice @@ -174,13 +174,15 @@ func (h *AdminDeviceHandler) UpdateAPNS(c echo.Context) error { } var req struct { - Active bool `json:"active"` + Active *bool `json:"active"` } if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } - device.Active = req.Active + if req.Active != nil { + device.Active = *req.Active + } if err := h.db.Save(&device).Error; err != nil { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update device"}) } @@ -204,13 +206,15 @@ func (h *AdminDeviceHandler) UpdateGCM(c echo.Context) error { } var req struct { - Active bool `json:"active"` + Active *bool `json:"active"` } if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } - device.Active = req.Active + if req.Active != nil { + device.Active = *req.Active + } if err := h.db.Save(&device).Error; err != nil { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update device"}) } @@ -260,7 +264,7 @@ func (h *AdminDeviceHandler) DeleteGCM(c echo.Context) error { func (h *AdminDeviceHandler) BulkDeleteAPNS(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } result := h.db.Where("id IN ?", req.IDs).Delete(&models.APNSDevice{}) @@ -275,7 +279,7 @@ func (h *AdminDeviceHandler) BulkDeleteAPNS(c echo.Context) error { func (h *AdminDeviceHandler) BulkDeleteGCM(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } result := h.db.Where("id IN ?", req.IDs).Delete(&models.GCMDevice{}) @@ -307,10 +311,3 @@ func (h *AdminDeviceHandler) GetStats(c echo.Context) error { "total": apnsTotal + gcmTotal, }) } - -func min(a, b int) int { - if a < b { - return a - } - return b -} diff --git a/internal/admin/handlers/document_handler.go b/internal/admin/handlers/document_handler.go index 3ceea50..2407c32 100644 --- a/internal/admin/handlers/document_handler.go +++ b/internal/admin/handlers/document_handler.go @@ -6,7 +6,6 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/shopspring/decimal" "gorm.io/gorm" "github.com/treytartt/honeydue-api/internal/admin/dto" @@ -27,7 +26,7 @@ func NewAdminDocumentHandler(db *gorm.DB) *AdminDocumentHandler { func (h *AdminDocumentHandler) List(c echo.Context) error { var filters dto.DocumentFilters if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var documents []models.Document @@ -132,7 +131,7 @@ func (h *AdminDocumentHandler) Update(c echo.Context) error { var req dto.UpdateDocumentRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Verify residence if changing @@ -183,8 +182,7 @@ func (h *AdminDocumentHandler) Update(c echo.Context) error { } } if req.PurchasePrice != nil { - d := decimal.NewFromFloat(*req.PurchasePrice) - document.PurchasePrice = &d + document.PurchasePrice = req.PurchasePrice } if req.Vendor != nil { document.Vendor = *req.Vendor @@ -232,7 +230,7 @@ func (h *AdminDocumentHandler) Update(c echo.Context) error { func (h *AdminDocumentHandler) Create(c echo.Context) error { var req dto.CreateDocumentRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Verify residence exists @@ -282,8 +280,7 @@ func (h *AdminDocumentHandler) Create(c echo.Context) error { } } if req.PurchasePrice != nil { - d := decimal.NewFromFloat(*req.PurchasePrice) - document.PurchasePrice = &d + document.PurchasePrice = req.PurchasePrice } if err := h.db.Create(&document).Error; err != nil { @@ -322,7 +319,7 @@ func (h *AdminDocumentHandler) Delete(c echo.Context) error { func (h *AdminDocumentHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Soft delete - deactivate all diff --git a/internal/admin/handlers/document_image_handler.go b/internal/admin/handlers/document_image_handler.go index c8f5773..a5edc19 100644 --- a/internal/admin/handlers/document_image_handler.go +++ b/internal/admin/handlers/document_image_handler.go @@ -36,8 +36,8 @@ type DocumentImageResponse struct { // CreateDocumentImageRequest represents the request to create a document image type CreateDocumentImageRequest struct { - DocumentID uint `json:"document_id" binding:"required"` - ImageURL string `json:"image_url" binding:"required"` + DocumentID uint `json:"document_id" validate:"required"` + ImageURL string `json:"image_url" validate:"required"` Caption string `json:"caption"` } @@ -51,7 +51,7 @@ type UpdateDocumentImageRequest struct { func (h *AdminDocumentImageHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Optional document_id filter @@ -123,7 +123,10 @@ func (h *AdminDocumentImageHandler) Get(c echo.Context) error { func (h *AdminDocumentImageHandler) Create(c echo.Context) error { var req CreateDocumentImageRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } // Verify document exists @@ -165,7 +168,7 @@ func (h *AdminDocumentImageHandler) Update(c echo.Context) error { var req UpdateDocumentImageRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } if req.ImageURL != nil { @@ -208,14 +211,15 @@ func (h *AdminDocumentImageHandler) Delete(c echo.Context) error { func (h *AdminDocumentImageHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } - if err := h.db.Where("id IN ?", req.IDs).Delete(&models.DocumentImage{}).Error; err != nil { + result := h.db.Where("id IN ?", req.IDs).Delete(&models.DocumentImage{}) + if result.Error != nil { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete document images"}) } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document images deleted successfully", "count": len(req.IDs)}) + return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document images deleted successfully", "count": result.RowsAffected}) } // toResponse converts a DocumentImage model to DocumentImageResponse diff --git a/internal/admin/handlers/feature_benefit_handler.go b/internal/admin/handlers/feature_benefit_handler.go index da96980..8bc7f9d 100644 --- a/internal/admin/handlers/feature_benefit_handler.go +++ b/internal/admin/handlers/feature_benefit_handler.go @@ -37,7 +37,7 @@ type FeatureBenefitResponse struct { func (h *AdminFeatureBenefitHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var benefits []models.FeatureBenefit @@ -112,15 +112,18 @@ func (h *AdminFeatureBenefitHandler) Get(c echo.Context) error { // Create handles POST /api/admin/feature-benefits func (h *AdminFeatureBenefitHandler) Create(c echo.Context) error { var req struct { - FeatureName string `json:"feature_name" binding:"required"` - FreeTierText string `json:"free_tier_text" binding:"required"` - ProTierText string `json:"pro_tier_text" binding:"required"` + FeatureName string `json:"feature_name" validate:"required"` + FreeTierText string `json:"free_tier_text" validate:"required"` + ProTierText string `json:"pro_tier_text" validate:"required"` DisplayOrder int `json:"display_order"` IsActive *bool `json:"is_active"` } if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } benefit := models.FeatureBenefit{ @@ -175,7 +178,7 @@ func (h *AdminFeatureBenefitHandler) Update(c echo.Context) error { } if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } if req.FeatureName != nil { diff --git a/internal/admin/handlers/limitations_handler.go b/internal/admin/handlers/limitations_handler.go index 4627a52..9e3ec0a 100644 --- a/internal/admin/handlers/limitations_handler.go +++ b/internal/admin/handlers/limitations_handler.go @@ -34,7 +34,9 @@ func (h *AdminLimitationsHandler) GetSettings(c echo.Context) error { if err == gorm.ErrRecordNotFound { // Create default settings settings = models.SubscriptionSettings{ID: 1, EnableLimitations: false} - h.db.Create(&settings) + if err := h.db.Create(&settings).Error; err != nil { + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default settings"}) + } } else { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"}) } @@ -54,7 +56,7 @@ type UpdateLimitationsSettingsRequest struct { func (h *AdminLimitationsHandler) UpdateSettings(c echo.Context) error { var req UpdateLimitationsSettingsRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var settings models.SubscriptionSettings @@ -117,8 +119,12 @@ func (h *AdminLimitationsHandler) ListTierLimits(c echo.Context) error { if len(limits) == 0 { freeLimits := models.GetDefaultFreeLimits() proLimits := models.GetDefaultProLimits() - h.db.Create(&freeLimits) - h.db.Create(&proLimits) + if err := h.db.Create(&freeLimits).Error; err != nil { + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default free tier limits"}) + } + if err := h.db.Create(&proLimits).Error; err != nil { + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default pro tier limits"}) + } limits = []models.TierLimits{freeLimits, proLimits} } @@ -149,7 +155,9 @@ func (h *AdminLimitationsHandler) GetTierLimits(c echo.Context) error { } else { limits = models.GetDefaultProLimits() } - h.db.Create(&limits) + if err := h.db.Create(&limits).Error; err != nil { + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default tier limits"}) + } } else { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch tier limits"}) } @@ -175,7 +183,7 @@ func (h *AdminLimitationsHandler) UpdateTierLimits(c echo.Context) error { var req UpdateTierLimitsRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var limits models.TierLimits @@ -188,13 +196,23 @@ func (h *AdminLimitationsHandler) UpdateTierLimits(c echo.Context) error { } } - // Update fields - note: we need to handle nil vs zero difference - // A nil pointer in the request means "don't change" - // The actual limit value can be nil (unlimited) or a number - limits.PropertiesLimit = req.PropertiesLimit - limits.TasksLimit = req.TasksLimit - limits.ContractorsLimit = req.ContractorsLimit - limits.DocumentsLimit = req.DocumentsLimit + // Update fields only when explicitly provided in the request body. + // JSON unmarshaling sets *int to nil when the key is absent, and to + // a non-nil *int (possibly pointing to 0) when the key is present. + // We rely on Bind populating these before calling this handler, so + // a nil pointer here means "don't change". + if req.PropertiesLimit != nil { + limits.PropertiesLimit = req.PropertiesLimit + } + if req.TasksLimit != nil { + limits.TasksLimit = req.TasksLimit + } + if req.ContractorsLimit != nil { + limits.ContractorsLimit = req.ContractorsLimit + } + if req.DocumentsLimit != nil { + limits.DocumentsLimit = req.DocumentsLimit + } if err := h.db.Save(&limits).Error; err != nil { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update tier limits"}) @@ -297,9 +315,9 @@ func (h *AdminLimitationsHandler) GetUpgradeTrigger(c echo.Context) error { // CreateUpgradeTriggerRequest represents the create request type CreateUpgradeTriggerRequest struct { - TriggerKey string `json:"trigger_key" binding:"required"` - Title string `json:"title" binding:"required"` - Message string `json:"message" binding:"required"` + TriggerKey string `json:"trigger_key" validate:"required"` + Title string `json:"title" validate:"required"` + Message string `json:"message" validate:"required"` PromoHTML string `json:"promo_html"` ButtonText string `json:"button_text"` IsActive *bool `json:"is_active"` @@ -309,7 +327,10 @@ type CreateUpgradeTriggerRequest struct { func (h *AdminLimitationsHandler) CreateUpgradeTrigger(c echo.Context) error { var req CreateUpgradeTriggerRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } // Validate trigger key @@ -380,7 +401,7 @@ func (h *AdminLimitationsHandler) UpdateUpgradeTrigger(c echo.Context) error { var req UpdateUpgradeTriggerRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } if req.TriggerKey != nil { diff --git a/internal/admin/handlers/lookup_handler.go b/internal/admin/handlers/lookup_handler.go index b76101e..9050a4c 100644 --- a/internal/admin/handlers/lookup_handler.go +++ b/internal/admin/handlers/lookup_handler.go @@ -162,10 +162,10 @@ type TaskCategoryResponse struct { } type CreateUpdateCategoryRequest struct { - Name string `json:"name" binding:"required,max=50"` + Name string `json:"name" validate:"required,max=50"` Description string `json:"description"` - Icon string `json:"icon" binding:"max=50"` - Color string `json:"color" binding:"max=7"` + Icon string `json:"icon" validate:"max=50"` + Color string `json:"color" validate:"max=7"` DisplayOrder *int `json:"display_order"` } @@ -193,7 +193,10 @@ func (h *AdminLookupHandler) ListCategories(c echo.Context) error { func (h *AdminLookupHandler) CreateCategory(c echo.Context) error { var req CreateUpdateCategoryRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } category := models.TaskCategory{ @@ -239,7 +242,10 @@ func (h *AdminLookupHandler) UpdateCategory(c echo.Context) error { var req CreateUpdateCategoryRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } category.Name = req.Name @@ -301,9 +307,9 @@ type TaskPriorityResponse struct { } type CreateUpdatePriorityRequest struct { - Name string `json:"name" binding:"required,max=20"` - Level int `json:"level" binding:"required,min=1,max=10"` - Color string `json:"color" binding:"max=7"` + Name string `json:"name" validate:"required,max=20"` + Level int `json:"level" validate:"required,min=1,max=10"` + Color string `json:"color" validate:"max=7"` DisplayOrder *int `json:"display_order"` } @@ -330,7 +336,10 @@ func (h *AdminLookupHandler) ListPriorities(c echo.Context) error { func (h *AdminLookupHandler) CreatePriority(c echo.Context) error { var req CreateUpdatePriorityRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } priority := models.TaskPriority{ @@ -374,7 +383,10 @@ func (h *AdminLookupHandler) UpdatePriority(c echo.Context) error { var req CreateUpdatePriorityRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } priority.Name = req.Name @@ -434,7 +446,7 @@ type TaskFrequencyResponse struct { } type CreateUpdateFrequencyRequest struct { - Name string `json:"name" binding:"required,max=20"` + Name string `json:"name" validate:"required,max=20"` Days *int `json:"days"` DisplayOrder *int `json:"display_order"` } @@ -480,7 +492,10 @@ func (h *AdminLookupHandler) ListFrequencies(c echo.Context) error { func (h *AdminLookupHandler) CreateFrequency(c echo.Context) error { var req CreateUpdateFrequencyRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } frequency := models.TaskFrequency{ @@ -528,7 +543,10 @@ func (h *AdminLookupHandler) UpdateFrequency(c echo.Context) error { var req CreateUpdateFrequencyRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } frequency.Name = req.Name @@ -588,7 +606,7 @@ type ResidenceTypeResponse struct { } type CreateUpdateResidenceTypeRequest struct { - Name string `json:"name" binding:"required,max=20"` + Name string `json:"name" validate:"required,max=20"` } func (h *AdminLookupHandler) ListResidenceTypes(c echo.Context) error { @@ -611,7 +629,10 @@ func (h *AdminLookupHandler) ListResidenceTypes(c echo.Context) error { func (h *AdminLookupHandler) CreateResidenceType(c echo.Context) error { var req CreateUpdateResidenceTypeRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } residenceType := models.ResidenceType{Name: req.Name} @@ -644,7 +665,10 @@ func (h *AdminLookupHandler) UpdateResidenceType(c echo.Context) error { var req CreateUpdateResidenceTypeRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } residenceType.Name = req.Name @@ -694,9 +718,9 @@ type ContractorSpecialtyResponse struct { } type CreateUpdateSpecialtyRequest struct { - Name string `json:"name" binding:"required,max=50"` + Name string `json:"name" validate:"required,max=50"` Description string `json:"description"` - Icon string `json:"icon" binding:"max=50"` + Icon string `json:"icon" validate:"max=50"` DisplayOrder *int `json:"display_order"` } @@ -723,7 +747,10 @@ func (h *AdminLookupHandler) ListSpecialties(c echo.Context) error { func (h *AdminLookupHandler) CreateSpecialty(c echo.Context) error { var req CreateUpdateSpecialtyRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } specialty := models.ContractorSpecialty{ @@ -767,7 +794,10 @@ func (h *AdminLookupHandler) UpdateSpecialty(c echo.Context) error { var req CreateUpdateSpecialtyRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } specialty.Name = req.Name diff --git a/internal/admin/handlers/notification_handler.go b/internal/admin/handlers/notification_handler.go index bea8886..0da419b 100644 --- a/internal/admin/handlers/notification_handler.go +++ b/internal/admin/handlers/notification_handler.go @@ -8,6 +8,7 @@ import ( "time" "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" "gorm.io/gorm" "github.com/treytartt/honeydue-api/internal/admin/dto" @@ -36,7 +37,7 @@ func NewAdminNotificationHandler(db *gorm.DB, emailService *services.EmailServic func (h *AdminNotificationHandler) List(c echo.Context) error { var filters dto.NotificationFilters if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var notifications []models.Notification @@ -151,7 +152,7 @@ func (h *AdminNotificationHandler) Update(c echo.Context) error { var req dto.UpdateNotificationRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } now := time.Now().UTC() @@ -235,7 +236,7 @@ func (h *AdminNotificationHandler) toNotificationDetailResponse(notif *models.No func (h *AdminNotificationHandler) SendTestNotification(c echo.Context) error { var req dto.SendTestNotificationRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Verify user exists @@ -294,13 +295,10 @@ func (h *AdminNotificationHandler) SendTestNotification(c echo.Context) error { err := h.pushClient.SendToAll(ctx, iosTokens, androidTokens, req.Title, req.Body, pushData) if err != nil { - // Update notification with error - h.db.Model(¬ification).Updates(map[string]interface{}{ - "error": err.Error(), - }) + // Log the real error for debugging + log.Error().Err(err).Uint("notification_id", notification.ID).Msg("Failed to send push notification") return c.JSON(http.StatusInternalServerError, map[string]interface{}{ - "error": "Failed to send push notification", - "details": err.Error(), + "error": "Failed to send push notification", }) } } else { @@ -327,7 +325,7 @@ func (h *AdminNotificationHandler) SendTestNotification(c echo.Context) error { func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error { var req dto.SendTestEmailRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Verify user exists @@ -369,9 +367,9 @@ func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error { err := h.emailService.SendEmail(user.Email, req.Subject, htmlBody, req.Body) if err != nil { + log.Error().Err(err).Str("email", user.Email).Msg("Failed to send test email") return c.JSON(http.StatusInternalServerError, map[string]interface{}{ - "error": "Failed to send email", - "details": err.Error(), + "error": "Failed to send email", }) } @@ -384,11 +382,14 @@ func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error { // SendPostVerificationEmail handles POST /api/admin/emails/send-post-verification func (h *AdminNotificationHandler) SendPostVerificationEmail(c echo.Context) error { var req struct { - UserID uint `json:"user_id" binding:"required"` + UserID uint `json:"user_id" validate:"required"` } if err := c.Bind(&req); err != nil { return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "user_id is required"}) } + if err := c.Validate(&req); err != nil { + return err + } // Verify user exists var user models.User @@ -410,9 +411,9 @@ func (h *AdminNotificationHandler) SendPostVerificationEmail(c echo.Context) err err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName) if err != nil { + log.Error().Err(err).Str("email", user.Email).Msg("Failed to send post-verification email") return c.JSON(http.StatusInternalServerError, map[string]interface{}{ - "error": "Failed to send email", - "details": err.Error(), + "error": "Failed to send email", }) } diff --git a/internal/admin/handlers/notification_prefs_handler.go b/internal/admin/handlers/notification_prefs_handler.go index 273e4d4..0231650 100644 --- a/internal/admin/handlers/notification_prefs_handler.go +++ b/internal/admin/handlers/notification_prefs_handler.go @@ -55,7 +55,7 @@ type NotificationPrefResponse struct { func (h *AdminNotificationPrefsHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var prefs []models.NotificationPreference @@ -212,7 +212,7 @@ func (h *AdminNotificationPrefsHandler) Update(c echo.Context) error { var req UpdateNotificationPrefRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Apply updates diff --git a/internal/admin/handlers/onboarding_handler.go b/internal/admin/handlers/onboarding_handler.go index 61927f2..d2daf50 100644 --- a/internal/admin/handlers/onboarding_handler.go +++ b/internal/admin/handlers/onboarding_handler.go @@ -240,7 +240,7 @@ func (h *AdminOnboardingHandler) Delete(c echo.Context) error { // DELETE /api/admin/onboarding-emails/bulk func (h *AdminOnboardingHandler) BulkDelete(c echo.Context) error { var req struct { - IDs []uint `json:"ids" binding:"required"` + IDs []uint `json:"ids" validate:"required"` } if err := c.Bind(&req); err != nil { return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request"}) @@ -263,8 +263,8 @@ func (h *AdminOnboardingHandler) BulkDelete(c echo.Context) error { // SendOnboardingEmailRequest represents a request to send an onboarding email type SendOnboardingEmailRequest struct { - UserID uint `json:"user_id" binding:"required"` - EmailType string `json:"email_type" binding:"required"` + UserID uint `json:"user_id" validate:"required"` + EmailType string `json:"email_type" validate:"required"` } // Send sends an onboarding email to a specific user @@ -278,6 +278,9 @@ func (h *AdminOnboardingHandler) Send(c echo.Context) error { if err := c.Bind(&req); err != nil { return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request: user_id and email_type are required"}) } + if err := c.Validate(&req); err != nil { + return err + } // Validate email type var emailType models.OnboardingEmailType @@ -301,7 +304,7 @@ func (h *AdminOnboardingHandler) Send(c echo.Context) error { // Send the email if err := h.onboardingService.SendOnboardingEmailToUser(req.UserID, emailType); err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to send onboarding email"}) } return c.JSON(http.StatusOK, map[string]interface{}{ diff --git a/internal/admin/handlers/password_reset_code_handler.go b/internal/admin/handlers/password_reset_code_handler.go index c99c4bc..5ac139a 100644 --- a/internal/admin/handlers/password_reset_code_handler.go +++ b/internal/admin/handlers/password_reset_code_handler.go @@ -39,7 +39,7 @@ type PasswordResetCodeResponse struct { func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var codes []models.PasswordResetCode @@ -147,7 +147,7 @@ func (h *AdminPasswordResetCodeHandler) Delete(c echo.Context) error { func (h *AdminPasswordResetCodeHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } result := h.db.Where("id IN ?", req.IDs).Delete(&models.PasswordResetCode{}) diff --git a/internal/admin/handlers/promotion_handler.go b/internal/admin/handlers/promotion_handler.go index 8a064ae..b1395e9 100644 --- a/internal/admin/handlers/promotion_handler.go +++ b/internal/admin/handlers/promotion_handler.go @@ -41,7 +41,7 @@ type PromotionResponse struct { func (h *AdminPromotionHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var promotions []models.Promotion @@ -123,18 +123,21 @@ func (h *AdminPromotionHandler) Get(c echo.Context) error { // Create handles POST /api/admin/promotions func (h *AdminPromotionHandler) Create(c echo.Context) error { var req struct { - PromotionID string `json:"promotion_id" binding:"required"` - Title string `json:"title" binding:"required"` - Message string `json:"message" binding:"required"` + PromotionID string `json:"promotion_id" validate:"required"` + Title string `json:"title" validate:"required"` + Message string `json:"message" validate:"required"` Link *string `json:"link"` - StartDate string `json:"start_date" binding:"required"` - EndDate string `json:"end_date" binding:"required"` + StartDate string `json:"start_date" validate:"required"` + EndDate string `json:"end_date" validate:"required"` TargetTier string `json:"target_tier"` IsActive *bool `json:"is_active"` } if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } startDate, err := time.Parse("2006-01-02T15:04:05Z", req.StartDate) @@ -219,7 +222,7 @@ func (h *AdminPromotionHandler) Update(c echo.Context) error { } if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } if req.PromotionID != nil { diff --git a/internal/admin/handlers/residence_handler.go b/internal/admin/handlers/residence_handler.go index fb3dd76..d0940aa 100644 --- a/internal/admin/handlers/residence_handler.go +++ b/internal/admin/handlers/residence_handler.go @@ -27,7 +27,7 @@ func NewAdminResidenceHandler(db *gorm.DB) *AdminResidenceHandler { func (h *AdminResidenceHandler) List(c echo.Context) error { var filters dto.ResidenceFilters if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var residences []models.Residence @@ -143,7 +143,7 @@ func (h *AdminResidenceHandler) Update(c echo.Context) error { var req dto.UpdateResidenceRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } if req.OwnerID != nil { @@ -204,8 +204,7 @@ func (h *AdminResidenceHandler) Update(c echo.Context) error { } } if req.PurchasePrice != nil { - d := decimal.NewFromFloat(*req.PurchasePrice) - residence.PurchasePrice = &d + residence.PurchasePrice = req.PurchasePrice } if req.IsActive != nil { residence.IsActive = *req.IsActive @@ -226,7 +225,7 @@ func (h *AdminResidenceHandler) Update(c echo.Context) error { func (h *AdminResidenceHandler) Create(c echo.Context) error { var req dto.CreateResidenceRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Verify owner exists @@ -300,7 +299,7 @@ func (h *AdminResidenceHandler) Delete(c echo.Context) error { func (h *AdminResidenceHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Soft delete - deactivate all diff --git a/internal/admin/handlers/settings_handler.go b/internal/admin/handlers/settings_handler.go index cb5ed50..350723c 100644 --- a/internal/admin/handlers/settings_handler.go +++ b/internal/admin/handlers/settings_handler.go @@ -47,7 +47,9 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error { TrialEnabled: true, TrialDurationDays: 14, } - h.db.Create(&settings) + if err := h.db.Create(&settings).Error; err != nil { + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default settings"}) + } } else { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"}) } @@ -73,7 +75,7 @@ type UpdateSettingsRequest struct { func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error { var req UpdateSettingsRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var settings models.SubscriptionSettings @@ -123,12 +125,12 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error { func (h *AdminSettingsHandler) SeedLookups(c echo.Context) error { // First seed lookup tables if err := h.runSeedFile("001_lookups.sql"); err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed lookups: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed lookups"}) } // Then seed task templates if err := h.runSeedFile("003_task_templates.sql"); err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates"}) } // Cache all lookups in Redis @@ -349,7 +351,7 @@ func parseTags(tags string) []string { // SeedTestData handles POST /api/admin/settings/seed-test-data func (h *AdminSettingsHandler) SeedTestData(c echo.Context) error { if err := h.runSeedFile("002_test_data.sql"); err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed test data: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed test data"}) } return c.JSON(http.StatusOK, map[string]interface{}{"message": "Test data seeded successfully"}) @@ -358,7 +360,7 @@ func (h *AdminSettingsHandler) SeedTestData(c echo.Context) error { // SeedTaskTemplates handles POST /api/admin/settings/seed-task-templates func (h *AdminSettingsHandler) SeedTaskTemplates(c echo.Context) error { if err := h.runSeedFile("003_task_templates.sql"); err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates"}) } return c.JSON(http.StatusOK, map[string]interface{}{"message": "Task templates seeded successfully"}) @@ -590,38 +592,38 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { // 1. Delete task completion images if err := tx.Exec("DELETE FROM task_taskcompletionimage").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completion images: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completion images"}) } // 2. Delete task completions if err := tx.Exec("DELETE FROM task_taskcompletion").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completions: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completions"}) } // 3. Delete notifications (must be before tasks since notifications have task_id FK) if len(preservedUserIDs) > 0 { if err := tx.Exec("DELETE FROM notifications_notification WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications"}) } } else { if err := tx.Exec("DELETE FROM notifications_notification").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications"}) } } // 4. Delete document images if err := tx.Exec("DELETE FROM task_documentimage").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete document images: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete document images"}) } // 5. Delete documents if err := tx.Exec("DELETE FROM task_document").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete documents: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete documents"}) } // 6. Delete task reminder logs (must be before tasks since reminder logs have task_id FK) @@ -631,64 +633,64 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { if tableExists { if err := tx.Exec("DELETE FROM task_reminderlog").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task reminder logs: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task reminder logs"}) } } // 7. Delete tasks (must be before contractors since tasks reference contractors) if err := tx.Exec("DELETE FROM task_task").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete tasks: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete tasks"}) } // 8. Delete contractor specialties (many-to-many) if err := tx.Exec("DELETE FROM task_contractor_specialties").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractor specialties: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractor specialties"}) } // 9. Delete contractors if err := tx.Exec("DELETE FROM task_contractor").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractors: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractors"}) } // 10. Delete residence_users (many-to-many for shared residences) if err := tx.Exec("DELETE FROM residence_residence_users").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence users: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence users"}) } // 11. Delete residence share codes (must be before residences since share codes have residence_id FK) if err := tx.Exec("DELETE FROM residence_residencesharecode").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence share codes: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence share codes"}) } // 12. Delete residences if err := tx.Exec("DELETE FROM residence_residence").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residences: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residences"}) } // 13. Delete push devices for non-superusers (both APNS and GCM) if len(preservedUserIDs) > 0 { if err := tx.Exec("DELETE FROM push_notifications_apnsdevice WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices"}) } if err := tx.Exec("DELETE FROM push_notifications_gcmdevice WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices"}) } } else { if err := tx.Exec("DELETE FROM push_notifications_apnsdevice").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices"}) } if err := tx.Exec("DELETE FROM push_notifications_gcmdevice").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices"}) } } @@ -696,12 +698,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { if len(preservedUserIDs) > 0 { if err := tx.Exec("DELETE FROM notifications_notificationpreference WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences"}) } } else { if err := tx.Exec("DELETE FROM notifications_notificationpreference").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences"}) } } @@ -709,12 +711,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { if len(preservedUserIDs) > 0 { if err := tx.Exec("DELETE FROM subscription_usersubscription WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions"}) } } else { if err := tx.Exec("DELETE FROM subscription_usersubscription").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions"}) } } @@ -722,12 +724,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { if len(preservedUserIDs) > 0 { if err := tx.Exec("DELETE FROM user_passwordresetcode WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes"}) } } else { if err := tx.Exec("DELETE FROM user_passwordresetcode").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes"}) } } @@ -735,12 +737,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { if len(preservedUserIDs) > 0 { if err := tx.Exec("DELETE FROM user_confirmationcode WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes"}) } } else { if err := tx.Exec("DELETE FROM user_confirmationcode").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes"}) } } @@ -748,12 +750,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { if len(preservedUserIDs) > 0 { if err := tx.Exec("DELETE FROM user_authtoken WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens"}) } } else { if err := tx.Exec("DELETE FROM user_authtoken").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens"}) } } @@ -761,12 +763,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { if len(preservedUserIDs) > 0 { if err := tx.Exec("DELETE FROM user_applesocialauth WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth"}) } } else { if err := tx.Exec("DELETE FROM user_applesocialauth").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth"}) } } @@ -774,12 +776,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { if len(preservedUserIDs) > 0 { if err := tx.Exec("DELETE FROM user_userprofile WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles"}) } } else { if err := tx.Exec("DELETE FROM user_userprofile").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles"}) } } @@ -790,12 +792,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { if len(preservedUserIDs) > 0 { if err := tx.Exec("DELETE FROM onboarding_emails WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails"}) } } else { if err := tx.Exec("DELETE FROM onboarding_emails").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails"}) } } } @@ -804,12 +806,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error { // Always filter by is_superuser to be safe, regardless of preservedUserIDs if err := tx.Exec("DELETE FROM auth_user WHERE is_superuser = false").Error; err != nil { tx.Rollback() - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete users: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete users"}) } // Commit the transaction if err := tx.Commit().Error; err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to commit transaction: " + err.Error()}) + return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to commit transaction"}) } return c.JSON(http.StatusOK, ClearAllDataResponse{ diff --git a/internal/admin/handlers/share_code_handler.go b/internal/admin/handlers/share_code_handler.go index 9e82baa..2df7819 100644 --- a/internal/admin/handlers/share_code_handler.go +++ b/internal/admin/handlers/share_code_handler.go @@ -38,7 +38,7 @@ type ShareCodeResponse struct { func (h *AdminShareCodeHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var codes []models.ResidenceShareCode @@ -156,7 +156,7 @@ func (h *AdminShareCodeHandler) Update(c echo.Context) error { IsActive *bool `json:"is_active"` } if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Only update IsActive when explicitly provided (non-nil). @@ -216,7 +216,7 @@ func (h *AdminShareCodeHandler) Delete(c echo.Context) error { func (h *AdminShareCodeHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } result := h.db.Where("id IN ?", req.IDs).Delete(&models.ResidenceShareCode{}) diff --git a/internal/admin/handlers/subscription_handler.go b/internal/admin/handlers/subscription_handler.go index b9d640c..68136bd 100644 --- a/internal/admin/handlers/subscription_handler.go +++ b/internal/admin/handlers/subscription_handler.go @@ -26,7 +26,7 @@ func NewAdminSubscriptionHandler(db *gorm.DB) *AdminSubscriptionHandler { func (h *AdminSubscriptionHandler) List(c echo.Context) error { var filters dto.SubscriptionFilters if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var subscriptions []models.UserSubscription @@ -38,8 +38,8 @@ func (h *AdminSubscriptionHandler) List(c echo.Context) error { // Apply search (search by user email) if filters.Search != "" { search := "%" + filters.Search + "%" - query = query.Joins("JOIN users ON users.id = subscription_usersubscription.user_id"). - Where("users.email ILIKE ? OR users.username ILIKE ?", search, search) + query = query.Joins("JOIN auth_user ON auth_user.id = subscription_usersubscription.user_id"). + Where("auth_user.email ILIKE ? OR auth_user.username ILIKE ?", search, search) } // Apply filters @@ -140,7 +140,7 @@ func (h *AdminSubscriptionHandler) Update(c echo.Context) error { var req dto.UpdateSubscriptionRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } if req.Tier != nil { diff --git a/internal/admin/handlers/task_handler.go b/internal/admin/handlers/task_handler.go index c7e6008..a4dc077 100644 --- a/internal/admin/handlers/task_handler.go +++ b/internal/admin/handlers/task_handler.go @@ -6,7 +6,6 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/shopspring/decimal" "gorm.io/gorm" "github.com/treytartt/honeydue-api/internal/admin/dto" @@ -27,7 +26,7 @@ func NewAdminTaskHandler(db *gorm.DB) *AdminTaskHandler { func (h *AdminTaskHandler) List(c echo.Context) error { var filters dto.TaskFilters if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var tasks []models.Task @@ -149,7 +148,7 @@ func (h *AdminTaskHandler) Update(c echo.Context) error { var req dto.UpdateTaskRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Verify residence if changing @@ -216,10 +215,10 @@ func (h *AdminTaskHandler) Update(c echo.Context) error { } } if req.EstimatedCost != nil { - updates["estimated_cost"] = decimal.NewFromFloat(*req.EstimatedCost) + updates["estimated_cost"] = *req.EstimatedCost } if req.ActualCost != nil { - updates["actual_cost"] = decimal.NewFromFloat(*req.ActualCost) + updates["actual_cost"] = *req.ActualCost } if req.ContractorID != nil { updates["contractor_id"] = *req.ContractorID @@ -248,7 +247,7 @@ func (h *AdminTaskHandler) Update(c echo.Context) error { func (h *AdminTaskHandler) Create(c echo.Context) error { var req dto.CreateTaskRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Verify residence exists @@ -285,8 +284,7 @@ func (h *AdminTaskHandler) Create(c echo.Context) error { } } if req.EstimatedCost != nil { - d := decimal.NewFromFloat(*req.EstimatedCost) - task.EstimatedCost = &d + task.EstimatedCost = req.EstimatedCost } if err := h.db.Create(&task).Error; err != nil { @@ -326,7 +324,7 @@ func (h *AdminTaskHandler) Delete(c echo.Context) error { func (h *AdminTaskHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Soft delete - archive and cancel all diff --git a/internal/admin/handlers/task_template_handler.go b/internal/admin/handlers/task_template_handler.go index 3069725..cde0c16 100644 --- a/internal/admin/handlers/task_template_handler.go +++ b/internal/admin/handlers/task_template_handler.go @@ -28,6 +28,8 @@ func NewAdminTaskTemplateHandler(db *gorm.DB) *AdminTaskTemplateHandler { func (h *AdminTaskTemplateHandler) refreshTaskTemplatesCache(ctx context.Context) { cache := services.GetCache() if cache == nil { + log.Warn().Msg("Cache service unavailable, skipping task templates cache refresh") + return } var templates []models.TaskTemplate @@ -68,12 +70,12 @@ type TaskTemplateResponse struct { // CreateUpdateTaskTemplateRequest represents the request body for creating/updating templates type CreateUpdateTaskTemplateRequest struct { - Title string `json:"title" binding:"required,max=200"` + Title string `json:"title" validate:"required,max=200"` Description string `json:"description"` CategoryID *uint `json:"category_id"` FrequencyID *uint `json:"frequency_id"` - IconIOS string `json:"icon_ios" binding:"max=100"` - IconAndroid string `json:"icon_android" binding:"max=100"` + IconIOS string `json:"icon_ios" validate:"max=100"` + IconAndroid string `json:"icon_android" validate:"max=100"` Tags string `json:"tags"` DisplayOrder *int `json:"display_order"` IsActive *bool `json:"is_active"` @@ -140,7 +142,10 @@ func (h *AdminTaskTemplateHandler) GetTemplate(c echo.Context) error { func (h *AdminTaskTemplateHandler) CreateTemplate(c echo.Context) error { var req CreateUpdateTaskTemplateRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } template := models.TaskTemplate{ @@ -191,7 +196,10 @@ func (h *AdminTaskTemplateHandler) UpdateTemplate(c echo.Context) error { var req CreateUpdateTaskTemplateRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } template.Title = req.Title @@ -271,10 +279,13 @@ func (h *AdminTaskTemplateHandler) ToggleActive(c echo.Context) error { // BulkCreate handles POST /admin/api/task-templates/bulk/ func (h *AdminTaskTemplateHandler) BulkCreate(c echo.Context) error { var req struct { - Templates []CreateUpdateTaskTemplateRequest `json:"templates" binding:"required,dive"` + Templates []CreateUpdateTaskTemplateRequest `json:"templates" validate:"required,dive"` } if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) + } + if err := c.Validate(&req); err != nil { + return err } templates := make([]models.TaskTemplate, len(req.Templates)) diff --git a/internal/admin/handlers/user_handler.go b/internal/admin/handlers/user_handler.go index f23ffc2..8c99c09 100644 --- a/internal/admin/handlers/user_handler.go +++ b/internal/admin/handlers/user_handler.go @@ -3,14 +3,25 @@ package handlers import ( "net/http" "strconv" + "strings" "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" "gorm.io/gorm" "github.com/treytartt/honeydue-api/internal/admin/dto" "github.com/treytartt/honeydue-api/internal/models" ) +// escapeLikeWildcards escapes SQL LIKE wildcards (%, _) in user input +// to prevent wildcard injection in LIKE queries. +func escapeLikeWildcards(s string) string { + s = strings.ReplaceAll(s, "\\", "\\\\") + s = strings.ReplaceAll(s, "%", "\\%") + s = strings.ReplaceAll(s, "_", "\\_") + return s +} + // AdminUserHandler handles admin user management endpoints type AdminUserHandler struct { db *gorm.DB @@ -25,7 +36,7 @@ func NewAdminUserHandler(db *gorm.DB) *AdminUserHandler { func (h *AdminUserHandler) List(c echo.Context) error { var filters dto.UserFilters if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var users []models.User @@ -35,7 +46,7 @@ func (h *AdminUserHandler) List(c echo.Context) error { // Apply search if filters.Search != "" { - search := "%" + filters.Search + "%" + search := "%" + escapeLikeWildcards(filters.Search) + "%" query = query.Where( "username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?", search, search, search, search, @@ -70,10 +81,49 @@ func (h *AdminUserHandler) List(c echo.Context) error { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch users"}) } + // Batch COUNT queries for residence and task counts instead of N+1 + userIDs := make([]uint, len(users)) + for i, user := range users { + userIDs[i] = user.ID + } + + type countResult struct { + OwnerID uint + Count int64 + } + + residenceCounts := make(map[uint]int64) + taskCounts := make(map[uint]int64) + + if len(userIDs) > 0 { + var resCounts []countResult + h.db.Model(&models.Residence{}). + Select("owner_id, COUNT(*) as count"). + Where("owner_id IN ?", userIDs). + Group("owner_id"). + Scan(&resCounts) + for _, rc := range resCounts { + residenceCounts[rc.OwnerID] = rc.Count + } + + var tskCounts []struct { + CreatedByID uint + Count int64 + } + h.db.Model(&models.Task{}). + Select("created_by_id, COUNT(*) as count"). + Where("created_by_id IN ?", userIDs). + Group("created_by_id"). + Scan(&tskCounts) + for _, tc := range tskCounts { + taskCounts[tc.CreatedByID] = tc.Count + } + } + // Build response responses := make([]dto.UserResponse, len(users)) for i, user := range users { - responses[i] = h.toUserResponse(&user) + responses[i] = h.toUserResponseWithCounts(&user, int(residenceCounts[user.ID]), int(taskCounts[user.ID])) } return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage())) @@ -122,7 +172,7 @@ func (h *AdminUserHandler) Get(c echo.Context) error { func (h *AdminUserHandler) Create(c echo.Context) error { var req dto.CreateUserRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Check if username exists @@ -170,7 +220,9 @@ func (h *AdminUserHandler) Create(c echo.Context) error { UserID: user.ID, PhoneNumber: req.PhoneNumber, } - h.db.Create(&profile) + if err := h.db.Create(&profile).Error; err != nil { + log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create user profile") + } // Reload with profile h.db.Preload("Profile").First(&user, user.ID) @@ -194,7 +246,7 @@ func (h *AdminUserHandler) Update(c echo.Context) error { var req dto.UpdateUserRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Check username uniqueness if changing @@ -298,7 +350,7 @@ func (h *AdminUserHandler) Delete(c echo.Context) error { func (h *AdminUserHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } // Soft delete - deactivate all @@ -309,8 +361,30 @@ func (h *AdminUserHandler) BulkDelete(c echo.Context) error { return c.JSON(http.StatusOK, map[string]interface{}{"message": "Users deactivated successfully", "count": len(req.IDs)}) } -// toUserResponse converts a User model to UserResponse DTO +// toUserResponseWithCounts converts a User model to UserResponse DTO with pre-computed counts +func (h *AdminUserHandler) toUserResponseWithCounts(user *models.User, residenceCount, taskCount int) dto.UserResponse { + response := h.buildUserResponse(user) + response.ResidenceCount = residenceCount + response.TaskCount = taskCount + return response +} + +// toUserResponse converts a User model to UserResponse DTO (fetches counts individually) func (h *AdminUserHandler) toUserResponse(user *models.User) dto.UserResponse { + response := h.buildUserResponse(user) + + // Get counts individually (used for single-user views) + var residenceCount, taskCount int64 + h.db.Model(&models.Residence{}).Where("owner_id = ?", user.ID).Count(&residenceCount) + h.db.Model(&models.Task{}).Where("created_by_id = ?", user.ID).Count(&taskCount) + response.ResidenceCount = int(residenceCount) + response.TaskCount = int(taskCount) + + return response +} + +// buildUserResponse creates the base UserResponse without counts +func (h *AdminUserHandler) buildUserResponse(user *models.User) dto.UserResponse { response := dto.UserResponse{ ID: user.ID, Username: user.Username, @@ -335,12 +409,5 @@ func (h *AdminUserHandler) toUserResponse(user *models.User) dto.UserResponse { } } - // Get counts - var residenceCount, taskCount int64 - h.db.Model(&models.Residence{}).Where("owner_id = ?", user.ID).Count(&residenceCount) - h.db.Model(&models.Task{}).Where("created_by_id = ?", user.ID).Count(&taskCount) - response.ResidenceCount = int(residenceCount) - response.TaskCount = int(taskCount) - return response } diff --git a/internal/admin/handlers/user_profile_handler.go b/internal/admin/handlers/user_profile_handler.go index 63c8d04..cfc59a9 100644 --- a/internal/admin/handlers/user_profile_handler.go +++ b/internal/admin/handlers/user_profile_handler.go @@ -50,7 +50,7 @@ type UpdateUserProfileRequest struct { func (h *AdminUserProfileHandler) List(c echo.Context) error { var filters dto.PaginationParams if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } var profiles []models.UserProfile @@ -144,7 +144,7 @@ func (h *AdminUserProfileHandler) Update(c echo.Context) error { var req UpdateUserProfileRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } if req.Verified != nil { @@ -205,14 +205,15 @@ func (h *AdminUserProfileHandler) Delete(c echo.Context) error { func (h *AdminUserProfileHandler) BulkDelete(c echo.Context) error { var req dto.BulkDeleteRequest if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()}) + return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) } - if err := h.db.Where("id IN ?", req.IDs).Delete(&models.UserProfile{}).Error; err != nil { + result := h.db.Where("id IN ?", req.IDs).Delete(&models.UserProfile{}) + if result.Error != nil { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles"}) } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "User profiles deleted successfully", "count": len(req.IDs)}) + return c.JSON(http.StatusOK, map[string]interface{}{"message": "User profiles deleted successfully", "count": result.RowsAffected}) } // toProfileResponse converts a UserProfile model to UserProfileResponse diff --git a/internal/admin/routes.go b/internal/admin/routes.go index 86baa5f..c2d6141 100644 --- a/internal/admin/routes.go +++ b/internal/admin/routes.go @@ -379,9 +379,10 @@ func SetupRoutes(router *echo.Echo, db *gorm.DB, cfg *config.Config, deps *Depen documentImages.DELETE("/:id", documentImageHandler.Delete) } - // System settings management + // System settings management (super admin only) settingsHandler := handlers.NewAdminSettingsHandler(db) settings := protected.Group("/settings") + settings.Use(middleware.RequireSuperAdmin()) { settings.GET("", settingsHandler.GetSettings) settings.PUT("", settingsHandler.UpdateSettings) @@ -500,9 +501,21 @@ func setupAdminProxy(router *echo.Echo) { }) } - // Proxy Next.js static assets (served from /_next/ regardless of host) + // Proxy Next.js static assets — only when the request comes from a + // known host to prevent SSRF via crafted Host headers. + var allowedProxyHosts []string + if adminHost != "" { + allowedProxyHosts = append(allowedProxyHosts, adminHost) + } + // Also allow localhost variants for development + allowedProxyHosts = append(allowedProxyHosts, + "localhost:3001", "127.0.0.1:3001", + "localhost:8000", "127.0.0.1:8000", + ) + hostCheck := middleware.HostCheck(allowedProxyHosts) + router.Any("/_next/*", func(c echo.Context) error { proxy.ServeHTTP(c.Response(), c.Request()) return nil - }) + }, hostCheck) } diff --git a/internal/config/config.go b/internal/config/config.go index 8d46305..66a6342 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,6 +6,7 @@ import ( "os" "strconv" "strings" + "sync" "time" "github.com/spf13/viper" @@ -32,6 +33,7 @@ type Config struct { type ServerConfig struct { Port int Debug bool + DebugFixedCodes bool // Separate from Debug: enables fixed confirmation codes for local testing AllowedHosts []string CorsAllowedOrigins []string // Comma-separated origins for CORS (production only; debug uses wildcard) Timezone string @@ -75,7 +77,12 @@ type PushConfig struct { APNSSandbox bool APNSProduction bool // If true, use production APNs; if false, use sandbox - // FCM (Android) - uses direct HTTP to FCM legacy API + // FCM (Android) - uses FCM HTTP v1 API with OAuth 2.0 + FCMProjectID string // Firebase project ID (required for v1 API) + FCMServiceAccountPath string // Path to Google service account JSON file + FCMServiceAccountJSON string // Raw service account JSON (alternative to path, e.g. for env var injection) + + // Deprecated: FCMServerKey is for the legacy HTTP API. Use FCMProjectID + service account instead. FCMServerKey string } @@ -147,135 +154,166 @@ type FeatureFlags struct { WorkerEnabled bool // FEATURE_WORKER_ENABLED (default: true) } -var cfg *Config +var ( + cfg *Config + cfgOnce sync.Once +) + +// knownWeakSecretKeys contains well-known default or placeholder secret keys +// that must not be used in production. +var knownWeakSecretKeys = map[string]bool{ + "secret": true, + "changeme": true, + "change-me": true, + "password": true, + "change-me-in-production-secret-key-12345": true, +} // Load reads configuration from environment variables func Load() (*Config, error) { - viper.SetEnvPrefix("") - viper.AutomaticEnv() - viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + var loadErr error - // Set defaults - setDefaults() + cfgOnce.Do(func() { + viper.SetEnvPrefix("") + viper.AutomaticEnv() + viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - // Parse DATABASE_URL if set (Dokku-style) - dbConfig := DatabaseConfig{ - Host: viper.GetString("DB_HOST"), - Port: viper.GetInt("DB_PORT"), - User: viper.GetString("POSTGRES_USER"), - Password: viper.GetString("POSTGRES_PASSWORD"), - Database: viper.GetString("POSTGRES_DB"), - SSLMode: viper.GetString("DB_SSLMODE"), - MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"), - MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"), - MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"), - } + // Set defaults + setDefaults() - // Override with DATABASE_URL if present - if databaseURL := viper.GetString("DATABASE_URL"); databaseURL != "" { - parsed, err := parseDatabaseURL(databaseURL) - if err == nil { - dbConfig.Host = parsed.Host - dbConfig.Port = parsed.Port - dbConfig.User = parsed.User - dbConfig.Password = parsed.Password - dbConfig.Database = parsed.Database - if parsed.SSLMode != "" { - dbConfig.SSLMode = parsed.SSLMode + // Parse DATABASE_URL if set (Dokku-style) + dbConfig := DatabaseConfig{ + Host: viper.GetString("DB_HOST"), + Port: viper.GetInt("DB_PORT"), + User: viper.GetString("POSTGRES_USER"), + Password: viper.GetString("POSTGRES_PASSWORD"), + Database: viper.GetString("POSTGRES_DB"), + SSLMode: viper.GetString("DB_SSLMODE"), + MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"), + MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"), + MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"), + } + + // Override with DATABASE_URL if present (F-16: log warning on parse failure) + if databaseURL := viper.GetString("DATABASE_URL"); databaseURL != "" { + parsed, err := parseDatabaseURL(databaseURL) + if err != nil { + maskedURL := MaskURLCredentials(databaseURL) + fmt.Printf("WARNING: Failed to parse DATABASE_URL (%s): %v — falling back to individual DB_* env vars\n", maskedURL, err) + } else { + dbConfig.Host = parsed.Host + dbConfig.Port = parsed.Port + dbConfig.User = parsed.User + dbConfig.Password = parsed.Password + dbConfig.Database = parsed.Database + if parsed.SSLMode != "" { + dbConfig.SSLMode = parsed.SSLMode + } } } - } - cfg = &Config{ - Server: ServerConfig{ - Port: viper.GetInt("PORT"), - Debug: viper.GetBool("DEBUG"), - AllowedHosts: strings.Split(viper.GetString("ALLOWED_HOSTS"), ","), - CorsAllowedOrigins: parseCorsOrigins(viper.GetString("CORS_ALLOWED_ORIGINS")), - Timezone: viper.GetString("TIMEZONE"), - StaticDir: viper.GetString("STATIC_DIR"), - BaseURL: viper.GetString("BASE_URL"), - }, - Database: dbConfig, - Redis: RedisConfig{ - URL: viper.GetString("REDIS_URL"), - Password: viper.GetString("REDIS_PASSWORD"), - DB: viper.GetInt("REDIS_DB"), - }, - Email: EmailConfig{ - Host: viper.GetString("EMAIL_HOST"), - Port: viper.GetInt("EMAIL_PORT"), - User: viper.GetString("EMAIL_HOST_USER"), - Password: viper.GetString("EMAIL_HOST_PASSWORD"), - From: viper.GetString("DEFAULT_FROM_EMAIL"), - UseTLS: viper.GetBool("EMAIL_USE_TLS"), - }, - Push: PushConfig{ - APNSKeyPath: viper.GetString("APNS_AUTH_KEY_PATH"), - APNSKeyID: viper.GetString("APNS_AUTH_KEY_ID"), - APNSTeamID: viper.GetString("APNS_TEAM_ID"), - APNSTopic: viper.GetString("APNS_TOPIC"), - APNSSandbox: viper.GetBool("APNS_USE_SANDBOX"), - APNSProduction: viper.GetBool("APNS_PRODUCTION"), - FCMServerKey: viper.GetString("FCM_SERVER_KEY"), - }, - Worker: WorkerConfig{ - TaskReminderHour: viper.GetInt("TASK_REMINDER_HOUR"), - OverdueReminderHour: viper.GetInt("OVERDUE_REMINDER_HOUR"), - DailyNotifHour: viper.GetInt("DAILY_DIGEST_HOUR"), - }, - Security: SecurityConfig{ - SecretKey: viper.GetString("SECRET_KEY"), - TokenCacheTTL: 5 * time.Minute, - PasswordResetExpiry: 15 * time.Minute, - ConfirmationExpiry: 24 * time.Hour, - MaxPasswordResetRate: 3, - }, - Storage: StorageConfig{ - UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"), - BaseURL: viper.GetString("STORAGE_BASE_URL"), - MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"), - AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"), - }, - AppleAuth: AppleAuthConfig{ - ClientID: viper.GetString("APPLE_CLIENT_ID"), - TeamID: viper.GetString("APPLE_TEAM_ID"), - }, - GoogleAuth: GoogleAuthConfig{ - ClientID: viper.GetString("GOOGLE_CLIENT_ID"), - AndroidClientID: viper.GetString("GOOGLE_ANDROID_CLIENT_ID"), - IOSClientID: viper.GetString("GOOGLE_IOS_CLIENT_ID"), - }, - AppleIAP: AppleIAPConfig{ - KeyPath: viper.GetString("APPLE_IAP_KEY_PATH"), - KeyID: viper.GetString("APPLE_IAP_KEY_ID"), - IssuerID: viper.GetString("APPLE_IAP_ISSUER_ID"), - BundleID: viper.GetString("APPLE_IAP_BUNDLE_ID"), - Sandbox: viper.GetBool("APPLE_IAP_SANDBOX"), - }, - GoogleIAP: GoogleIAPConfig{ - ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"), - PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"), - }, - Stripe: StripeConfig{ - SecretKey: viper.GetString("STRIPE_SECRET_KEY"), - WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"), - PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"), - PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"), - }, - Features: FeatureFlags{ - PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"), - EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"), - WebhooksEnabled: viper.GetBool("FEATURE_WEBHOOKS_ENABLED"), - OnboardingEmailsEnabled: viper.GetBool("FEATURE_ONBOARDING_EMAILS_ENABLED"), - PDFReportsEnabled: viper.GetBool("FEATURE_PDF_REPORTS_ENABLED"), - WorkerEnabled: viper.GetBool("FEATURE_WORKER_ENABLED"), - }, - } + cfg = &Config{ + Server: ServerConfig{ + Port: viper.GetInt("PORT"), + Debug: viper.GetBool("DEBUG"), + DebugFixedCodes: viper.GetBool("DEBUG_FIXED_CODES"), + AllowedHosts: strings.Split(viper.GetString("ALLOWED_HOSTS"), ","), + CorsAllowedOrigins: parseCorsOrigins(viper.GetString("CORS_ALLOWED_ORIGINS")), + Timezone: viper.GetString("TIMEZONE"), + StaticDir: viper.GetString("STATIC_DIR"), + BaseURL: viper.GetString("BASE_URL"), + }, + Database: dbConfig, + Redis: RedisConfig{ + URL: viper.GetString("REDIS_URL"), + Password: viper.GetString("REDIS_PASSWORD"), + DB: viper.GetInt("REDIS_DB"), + }, + Email: EmailConfig{ + Host: viper.GetString("EMAIL_HOST"), + Port: viper.GetInt("EMAIL_PORT"), + User: viper.GetString("EMAIL_HOST_USER"), + Password: viper.GetString("EMAIL_HOST_PASSWORD"), + From: viper.GetString("DEFAULT_FROM_EMAIL"), + UseTLS: viper.GetBool("EMAIL_USE_TLS"), + }, + Push: PushConfig{ + APNSKeyPath: viper.GetString("APNS_AUTH_KEY_PATH"), + APNSKeyID: viper.GetString("APNS_AUTH_KEY_ID"), + APNSTeamID: viper.GetString("APNS_TEAM_ID"), + APNSTopic: viper.GetString("APNS_TOPIC"), + APNSSandbox: viper.GetBool("APNS_USE_SANDBOX"), + APNSProduction: viper.GetBool("APNS_PRODUCTION"), + FCMProjectID: viper.GetString("FCM_PROJECT_ID"), + FCMServiceAccountPath: viper.GetString("FCM_SERVICE_ACCOUNT_PATH"), + FCMServiceAccountJSON: viper.GetString("FCM_SERVICE_ACCOUNT_JSON"), + FCMServerKey: viper.GetString("FCM_SERVER_KEY"), + }, + Worker: WorkerConfig{ + TaskReminderHour: viper.GetInt("TASK_REMINDER_HOUR"), + OverdueReminderHour: viper.GetInt("OVERDUE_REMINDER_HOUR"), + DailyNotifHour: viper.GetInt("DAILY_DIGEST_HOUR"), + }, + Security: SecurityConfig{ + SecretKey: viper.GetString("SECRET_KEY"), + TokenCacheTTL: 5 * time.Minute, + PasswordResetExpiry: 15 * time.Minute, + ConfirmationExpiry: 24 * time.Hour, + MaxPasswordResetRate: 3, + }, + Storage: StorageConfig{ + UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"), + BaseURL: viper.GetString("STORAGE_BASE_URL"), + MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"), + AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"), + }, + AppleAuth: AppleAuthConfig{ + ClientID: viper.GetString("APPLE_CLIENT_ID"), + TeamID: viper.GetString("APPLE_TEAM_ID"), + }, + GoogleAuth: GoogleAuthConfig{ + ClientID: viper.GetString("GOOGLE_CLIENT_ID"), + AndroidClientID: viper.GetString("GOOGLE_ANDROID_CLIENT_ID"), + IOSClientID: viper.GetString("GOOGLE_IOS_CLIENT_ID"), + }, + AppleIAP: AppleIAPConfig{ + KeyPath: viper.GetString("APPLE_IAP_KEY_PATH"), + KeyID: viper.GetString("APPLE_IAP_KEY_ID"), + IssuerID: viper.GetString("APPLE_IAP_ISSUER_ID"), + BundleID: viper.GetString("APPLE_IAP_BUNDLE_ID"), + Sandbox: viper.GetBool("APPLE_IAP_SANDBOX"), + }, + GoogleIAP: GoogleIAPConfig{ + ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"), + PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"), + }, + Stripe: StripeConfig{ + SecretKey: viper.GetString("STRIPE_SECRET_KEY"), + WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"), + PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"), + PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"), + }, + Features: FeatureFlags{ + PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"), + EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"), + WebhooksEnabled: viper.GetBool("FEATURE_WEBHOOKS_ENABLED"), + OnboardingEmailsEnabled: viper.GetBool("FEATURE_ONBOARDING_EMAILS_ENABLED"), + PDFReportsEnabled: viper.GetBool("FEATURE_PDF_REPORTS_ENABLED"), + WorkerEnabled: viper.GetBool("FEATURE_WORKER_ENABLED"), + }, + } - // Validate required fields - if err := validate(cfg); err != nil { - return nil, err + // Validate required fields + if err := validate(cfg); err != nil { + loadErr = err + // Reset so a subsequent call can retry after env is fixed + cfg = nil + cfgOnce = sync.Once{} + } + }) + + if loadErr != nil { + return nil, loadErr } return cfg, nil @@ -290,6 +328,7 @@ func setDefaults() { // Server defaults viper.SetDefault("PORT", 8000) viper.SetDefault("DEBUG", false) + viper.SetDefault("DEBUG_FIXED_CODES", false) // Separate flag for fixed confirmation codes viper.SetDefault("ALLOWED_HOSTS", "localhost,127.0.0.1") viper.SetDefault("TIMEZONE", "UTC") viper.SetDefault("STATIC_DIR", "/app/static") @@ -347,7 +386,13 @@ func setDefaults() { viper.SetDefault("FEATURE_WORKER_ENABLED", true) } +// isWeakSecretKey checks if the provided key is a known weak/default value. +func isWeakSecretKey(key string) bool { + return knownWeakSecretKeys[strings.ToLower(strings.TrimSpace(key))] +} + func validate(cfg *Config) error { + // S-08: Validate SECRET_KEY against known weak defaults if cfg.Security.SecretKey == "" { if cfg.Server.Debug { // In debug mode, use a default key with a warning for local development @@ -358,9 +403,12 @@ func validate(cfg *Config) error { // In production, refuse to start without a proper secret key return fmt.Errorf("FATAL: SECRET_KEY environment variable is required in production (DEBUG=false)") } - } else if cfg.Security.SecretKey == "change-me-in-production-secret-key-12345" { - // Warn if someone explicitly set the well-known debug key - fmt.Println("WARNING: SECRET_KEY is set to the well-known debug default. Change it for production use.") + } else if isWeakSecretKey(cfg.Security.SecretKey) { + if cfg.Server.Debug { + fmt.Printf("WARNING: SECRET_KEY is set to a well-known weak value (%q). Change it for production use.\n", cfg.Security.SecretKey) + } else { + return fmt.Errorf("FATAL: SECRET_KEY is set to a well-known weak value (%q). Use a strong, unique secret in production", cfg.Security.SecretKey) + } } // Database password might come from DATABASE_URL, don't require it separately @@ -369,6 +417,21 @@ func validate(cfg *Config) error { return nil } +// MaskURLCredentials parses a URL and replaces any password with "***". +// If parsing fails, it returns the string "" to avoid leaking credentials. +func MaskURLCredentials(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return "" + } + if u.User != nil { + if _, hasPassword := u.User.Password(); hasPassword { + u.User = url.UserPassword(u.User.Username(), "***") + } + } + return u.Redacted() +} + // DSN returns the database connection string func (d *DatabaseConfig) DSN() string { return fmt.Sprintf( diff --git a/internal/database/database.go b/internal/database/database.go index 8f2b4c7..1e7a0cd 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -5,6 +5,8 @@ import ( "time" "github.com/rs/zerolog/log" + "github.com/spf13/viper" + "golang.org/x/crypto/bcrypt" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -376,13 +378,23 @@ func migrateGoAdmin() error { } db.Exec(`CREATE INDEX IF NOT EXISTS idx_goadmin_site_key ON goadmin_site(key)`) - // Seed default admin user only on first run (ON CONFLICT DO NOTHING). + // Seed default GoAdmin user only on first run (ON CONFLICT DO NOTHING). // Password is NOT reset on subsequent migrations to preserve operator changes. - db.Exec(` - INSERT INTO goadmin_users (username, password, name, avatar) - VALUES ('admin', '$2a$10$t.GCU24EqIWLSl7F51Hdz.IkkgFK.Qa9/BzEc5Bi2C/I2bXf1nJgm', 'Administrator', '') - ON CONFLICT DO NOTHING - `) + goAdminUsername := viper.GetString("GOADMIN_ADMIN_USERNAME") + goAdminPassword := viper.GetString("GOADMIN_ADMIN_PASSWORD") + if goAdminUsername == "" || goAdminPassword == "" { + log.Warn().Msg("GOADMIN_ADMIN_USERNAME and/or GOADMIN_ADMIN_PASSWORD not set; skipping GoAdmin admin user seed") + } else { + goAdminHash, err := bcrypt.GenerateFromPassword([]byte(goAdminPassword), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash GoAdmin admin password: %w", err) + } + db.Exec(` + INSERT INTO goadmin_users (username, password, name, avatar) + VALUES (?, ?, 'Administrator', '') + ON CONFLICT DO NOTHING + `, goAdminUsername, string(goAdminHash)) + } // Seed default roles db.Exec(`INSERT INTO goadmin_roles (name, slug) VALUES ('Administrator', 'administrator') ON CONFLICT DO NOTHING`) @@ -393,15 +405,17 @@ func migrateGoAdmin() error { db.Exec(`INSERT INTO goadmin_permissions (name, slug, http_method, http_path) VALUES ('Dashboard', 'dashboard', 'GET', '/') ON CONFLICT DO NOTHING`) // Assign admin user to administrator role (if not already assigned) - db.Exec(` - INSERT INTO goadmin_role_users (role_id, user_id) - SELECT r.id, u.id FROM goadmin_roles r, goadmin_users u - WHERE r.slug = 'administrator' AND u.username = 'admin' - AND NOT EXISTS ( - SELECT 1 FROM goadmin_role_users ru - WHERE ru.role_id = r.id AND ru.user_id = u.id - ) - `) + if goAdminUsername != "" { + db.Exec(` + INSERT INTO goadmin_role_users (role_id, user_id) + SELECT r.id, u.id FROM goadmin_roles r, goadmin_users u + WHERE r.slug = 'administrator' AND u.username = ? + AND NOT EXISTS ( + SELECT 1 FROM goadmin_role_users ru + WHERE ru.role_id = r.id AND ru.user_id = u.id + ) + `, goAdminUsername) + } // Assign all permissions to administrator role (if not already assigned) db.Exec(` @@ -448,15 +462,25 @@ func migrateGoAdmin() error { // Seed default Next.js admin user only on first run. // Password is NOT reset on subsequent migrations to preserve operator changes. - var adminCount int64 - db.Raw(`SELECT COUNT(*) FROM admin_users WHERE email = 'admin@honeydue.com'`).Scan(&adminCount) - if adminCount == 0 { - log.Info().Msg("Seeding default admin user for Next.js admin panel...") - db.Exec(` - INSERT INTO admin_users (email, password, first_name, last_name, role, is_active, created_at, updated_at) - VALUES ('admin@honeydue.com', '$2a$10$t5hGjdXQLxr9Z0193qx.Tef6hd1vYI3JvrfX/piKx2qS9UvQ41I9O', 'Admin', 'User', 'super_admin', true, NOW(), NOW()) - `) - log.Info().Msg("Default admin user created: admin@honeydue.com") + adminEmail := viper.GetString("ADMIN_EMAIL") + adminPassword := viper.GetString("ADMIN_PASSWORD") + if adminEmail == "" || adminPassword == "" { + log.Warn().Msg("ADMIN_EMAIL and/or ADMIN_PASSWORD not set; skipping Next.js admin user seed") + } else { + var adminCount int64 + db.Raw(`SELECT COUNT(*) FROM admin_users WHERE email = ?`, adminEmail).Scan(&adminCount) + if adminCount == 0 { + log.Info().Str("email", adminEmail).Msg("Seeding default admin user for Next.js admin panel...") + adminHash, err := bcrypt.GenerateFromPassword([]byte(adminPassword), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash admin password: %w", err) + } + db.Exec(` + INSERT INTO admin_users (email, password, first_name, last_name, role, is_active, created_at, updated_at) + VALUES (?, ?, 'Admin', 'User', 'super_admin', true, NOW(), NOW()) + `, adminEmail, string(adminHash)) + log.Info().Str("email", adminEmail).Msg("Default admin user created") + } } return nil diff --git a/internal/dto/requests/task.go b/internal/dto/requests/task.go index 8fe3ff0..fc2bb05 100644 --- a/internal/dto/requests/task.go +++ b/internal/dto/requests/task.go @@ -71,7 +71,7 @@ type CreateTaskRequest struct { // UpdateTaskRequest represents the request to update a task type UpdateTaskRequest struct { Title *string `json:"title" validate:"omitempty,min=1,max=200"` - Description *string `json:"description"` + Description *string `json:"description" validate:"omitempty,max=10000"` CategoryID *uint `json:"category_id"` PriorityID *uint `json:"priority_id"` FrequencyID *uint `json:"frequency_id"` diff --git a/internal/dto/responses/document.go b/internal/dto/responses/document.go index 63cf69f..fdcd279 100644 --- a/internal/dto/responses/document.go +++ b/internal/dto/responses/document.go @@ -21,8 +21,9 @@ type DocumentUserResponse struct { type DocumentImageResponse struct { ID uint `json:"id"` ImageURL string `json:"image_url"` - MediaURL string `json:"media_url"` // Authenticated endpoint: /api/media/document-image/{id} + MediaURL string `json:"media_url"` // Authenticated endpoint: /api/media/document-image/{id} Caption string `json:"caption"` + Error string `json:"error,omitempty"` // Non-empty when the image could not be resolved } // DocumentResponse represents a document in the API response @@ -35,7 +36,6 @@ type DocumentResponse struct { Title string `json:"title"` Description string `json:"description"` DocumentType models.DocumentType `json:"document_type"` - FileURL string `json:"file_url"` MediaURL string `json:"media_url"` // Authenticated endpoint: /api/media/document/{id} FileName string `json:"file_name"` FileSize *int64 `json:"file_size"` @@ -80,7 +80,6 @@ func NewDocumentResponse(d *models.Document) DocumentResponse { Title: d.Title, Description: d.Description, DocumentType: d.DocumentType, - FileURL: d.FileURL, MediaURL: fmt.Sprintf("/api/media/document/%d", d.ID), // Authenticated endpoint FileName: d.FileName, FileSize: d.FileSize, @@ -104,12 +103,16 @@ func NewDocumentResponse(d *models.Document) DocumentResponse { // Convert images with authenticated media URLs for _, img := range d.Images { - resp.Images = append(resp.Images, DocumentImageResponse{ + imgResp := DocumentImageResponse{ ID: img.ID, ImageURL: img.ImageURL, MediaURL: fmt.Sprintf("/api/media/document-image/%d", img.ID), // Authenticated endpoint Caption: img.Caption, - }) + } + if img.ImageURL == "" { + imgResp.Error = "image source URL is missing" + } + resp.Images = append(resp.Images, imgResp) } return resp diff --git a/internal/dto/responses/task.go b/internal/dto/responses/task.go index fc1fac3..3520dad 100644 --- a/internal/dto/responses/task.go +++ b/internal/dto/responses/task.go @@ -281,13 +281,15 @@ func NewTaskListResponse(tasks []models.Task) []TaskResponse { return results } -// NewKanbanBoardResponse creates a KanbanBoardResponse from a KanbanBoard model -func NewKanbanBoardResponse(board *models.KanbanBoard, residenceID uint) KanbanBoardResponse { +// NewKanbanBoardResponse creates a KanbanBoardResponse from a KanbanBoard model. +// The `now` parameter should be the start of day in the user's timezone so that +// individual task kanban columns are categorized consistently with the board query. +func NewKanbanBoardResponse(board *models.KanbanBoard, residenceID uint, now time.Time) KanbanBoardResponse { columns := make([]KanbanColumnResponse, len(board.Columns)) for i, col := range board.Columns { tasks := make([]TaskResponse, len(col.Tasks)) for j, t := range col.Tasks { - tasks[j] = NewTaskResponse(&t) + tasks[j] = NewTaskResponseWithTime(&t, board.DaysThreshold, now) } columns[i] = KanbanColumnResponse{ Name: col.Name, @@ -306,13 +308,15 @@ func NewKanbanBoardResponse(board *models.KanbanBoard, residenceID uint) KanbanB } } -// NewKanbanBoardResponseForAll creates a KanbanBoardResponse for all residences (no specific residence ID) -func NewKanbanBoardResponseForAll(board *models.KanbanBoard) KanbanBoardResponse { +// NewKanbanBoardResponseForAll creates a KanbanBoardResponse for all residences (no specific residence ID). +// The `now` parameter should be the start of day in the user's timezone so that +// individual task kanban columns are categorized consistently with the board query. +func NewKanbanBoardResponseForAll(board *models.KanbanBoard, now time.Time) KanbanBoardResponse { columns := make([]KanbanColumnResponse, len(board.Columns)) for i, col := range board.Columns { tasks := make([]TaskResponse, len(col.Tasks)) for j, t := range col.Tasks { - tasks[j] = NewTaskResponse(&t) + tasks[j] = NewTaskResponseWithTime(&t, board.DaysThreshold, now) } columns[i] = KanbanColumnResponse{ Name: col.Name, diff --git a/internal/handlers/contractor_handler.go b/internal/handlers/contractor_handler.go index 0fb67a5..d8def5c 100644 --- a/internal/handlers/contractor_handler.go +++ b/internal/handlers/contractor_handler.go @@ -8,6 +8,8 @@ import ( "github.com/treytartt/honeydue-api/internal/apperrors" "github.com/treytartt/honeydue-api/internal/dto/requests" + "github.com/treytartt/honeydue-api/internal/dto/responses" + "github.com/treytartt/honeydue-api/internal/i18n" "github.com/treytartt/honeydue-api/internal/middleware" "github.com/treytartt/honeydue-api/internal/services" ) @@ -115,7 +117,7 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error { if err != nil { return err } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Contractor deleted successfully"}) + return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.contractor_deleted")}) } // ToggleFavorite handles POST /api/contractors/:id/toggle-favorite/ diff --git a/internal/handlers/document_handler.go b/internal/handlers/document_handler.go index 0237424..be334a0 100644 --- a/internal/handlers/document_handler.go +++ b/internal/handlers/document_handler.go @@ -12,6 +12,8 @@ import ( "github.com/treytartt/honeydue-api/internal/apperrors" "github.com/treytartt/honeydue-api/internal/dto/requests" + "github.com/treytartt/honeydue-api/internal/dto/responses" + "github.com/treytartt/honeydue-api/internal/i18n" "github.com/treytartt/honeydue-api/internal/middleware" "github.com/treytartt/honeydue-api/internal/models" "github.com/treytartt/honeydue-api/internal/repositories" @@ -60,6 +62,9 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error { } if es := c.QueryParam("expiring_soon"); es != "" { if parsed, err := strconv.Atoi(es); err == nil { + if parsed < 1 || parsed > 3650 { + return apperrors.BadRequest("error.days_out_of_range") + } filter.ExpiringSoon = &parsed } } @@ -192,7 +197,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error { } } - if uploadedFile != nil && h.storageService != nil { + if uploadedFile != nil { + if h.storageService == nil { + return apperrors.Internal(nil) + } result, err := h.storageService.Upload(uploadedFile, "documents") if err != nil { return apperrors.BadRequest("error.failed_to_upload_file") @@ -262,7 +270,7 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error { if err != nil { return err } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document deleted successfully"}) + return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.document_deleted")}) } // ActivateDocument handles POST /api/documents/:id/activate/ @@ -280,7 +288,7 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error { if err != nil { return err } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document activated successfully", "document": response}) + return c.JSON(http.StatusOK, response) } // DeactivateDocument handles POST /api/documents/:id/deactivate/ @@ -298,7 +306,7 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error { if err != nil { return err } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document deactivated successfully", "document": response}) + return c.JSON(http.StatusOK, response) } // UploadDocumentImage handles POST /api/documents/:id/images/ diff --git a/internal/handlers/notification_handler.go b/internal/handlers/notification_handler.go index 309c1a3..04d5dad 100644 --- a/internal/handlers/notification_handler.go +++ b/internal/handlers/notification_handler.go @@ -7,6 +7,8 @@ import ( "github.com/labstack/echo/v4" "github.com/treytartt/honeydue-api/internal/apperrors" + "github.com/treytartt/honeydue-api/internal/dto/responses" + "github.com/treytartt/honeydue-api/internal/i18n" "github.com/treytartt/honeydue-api/internal/middleware" "github.com/treytartt/honeydue-api/internal/services" ) @@ -87,7 +89,7 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error { return err } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.notification_marked_read"}) + return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.notification_marked_read")}) } // MarkAllAsRead handles POST /api/notifications/mark-all-read/ @@ -102,7 +104,7 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error { return err } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.all_notifications_marked_read"}) + return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.all_notifications_marked_read")}) } // GetPreferences handles GET /api/notifications/preferences/ @@ -200,7 +202,10 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error { return apperrors.BadRequest("error.registration_id_required") } if req.Platform == "" { - req.Platform = "ios" // Default to iOS + return apperrors.BadRequest("error.platform_required") + } + if req.Platform != "ios" && req.Platform != "android" { + return apperrors.BadRequest("error.invalid_platform") } err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID) @@ -208,7 +213,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error { return err } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.device_unregistered"}) + return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.device_removed")}) } // DeleteDevice handles DELETE /api/notifications/devices/:id/ @@ -225,7 +230,10 @@ func (h *NotificationHandler) DeleteDevice(c echo.Context) error { platform := c.QueryParam("platform") if platform == "" { - platform = "ios" // Default to iOS + return apperrors.BadRequest("error.platform_required") + } + if platform != "ios" && platform != "android" { + return apperrors.BadRequest("error.invalid_platform") } err = h.notificationService.DeleteDevice(uint(deviceID), platform, user.ID) @@ -233,5 +241,5 @@ func (h *NotificationHandler) DeleteDevice(c echo.Context) error { return err } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.device_removed"}) + return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.device_removed")}) } diff --git a/internal/handlers/subscription_handler.go b/internal/handlers/subscription_handler.go index f8d3a50..80699b2 100644 --- a/internal/handlers/subscription_handler.go +++ b/internal/handlers/subscription_handler.go @@ -6,6 +6,7 @@ import ( "github.com/labstack/echo/v4" "github.com/treytartt/honeydue-api/internal/apperrors" + "github.com/treytartt/honeydue-api/internal/i18n" "github.com/treytartt/honeydue-api/internal/middleware" "github.com/treytartt/honeydue-api/internal/services" ) @@ -139,7 +140,7 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error { } return c.JSON(http.StatusOK, map[string]interface{}{ - "message": "message.subscription_upgraded", + "message": i18n.LocalizedMessage(c, "message.subscription_upgraded"), "subscription": subscription, }) } @@ -157,7 +158,7 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error { } return c.JSON(http.StatusOK, map[string]interface{}{ - "message": "message.subscription_cancelled", + "message": i18n.LocalizedMessage(c, "message.subscription_cancelled"), "subscription": subscription, }) } @@ -182,8 +183,15 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error { switch req.Platform { case "ios": + // B-14: Validate that at least one of receipt_data or transaction_id is provided + if req.ReceiptData == "" && req.TransactionID == "" { + return apperrors.BadRequest("error.receipt_data_required") + } subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID) case "android": + if req.PurchaseToken == "" { + return apperrors.BadRequest("error.purchase_token_required") + } subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID) default: return apperrors.BadRequest("error.invalid_platform") @@ -194,7 +202,7 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error { } return c.JSON(http.StatusOK, map[string]interface{}{ - "message": "message.subscription_restored", + "message": i18n.LocalizedMessage(c, "message.subscription_restored"), "subscription": subscription, }) } diff --git a/internal/handlers/subscription_webhook_handler.go b/internal/handlers/subscription_webhook_handler.go index 61b61db..a678d38 100644 --- a/internal/handlers/subscription_webhook_handler.go +++ b/internal/handlers/subscription_webhook_handler.go @@ -2,15 +2,18 @@ package handlers import ( "crypto/ecdsa" + "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "io" + "math/big" "net/http" "os" "strings" + "sync" "time" "github.com/golang-jwt/jwt/v5" @@ -23,6 +26,11 @@ import ( "github.com/treytartt/honeydue-api/internal/services" ) +// maxWebhookBodySize is the maximum allowed request body size for webhook +// payloads (1 MB). This prevents a malicious or misbehaving sender from +// forcing the server to allocate unbounded memory. +const maxWebhookBodySize = 1 << 20 // 1 MB + // SubscriptionWebhookHandler handles subscription webhook callbacks type SubscriptionWebhookHandler struct { subscriptionRepo *repositories.SubscriptionRepository @@ -112,7 +120,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error { return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"}) } - body, err := io.ReadAll(c.Request().Body) + body, err := io.ReadAll(io.LimitReader(c.Request().Body, maxWebhookBodySize)) if err != nil { log.Error().Err(err).Msg("Apple Webhook: Failed to read body") return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"}) @@ -211,13 +219,22 @@ func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload stri return ¬ification, nil } -// decodeAppleTransaction decodes a signed transaction info JWS +// decodeAppleTransaction decodes and verifies a signed transaction info JWS. +// The inner JWS signature is verified using the same Apple certificate chain +// validation as the outer notification payload. func (h *SubscriptionWebhookHandler) decodeAppleTransaction(signedTransaction string) (*AppleTransactionInfo, error) { parts := strings.Split(signedTransaction, ".") if len(parts) != 3 { return nil, fmt.Errorf("invalid JWS format") } + // S-16: Verify the inner JWS signature for signedTransactionInfo. + // Apple signs each inner JWS independently with the same x5c certificate + // chain as the outer notification, so the same verification applies. + if err := h.VerifyAppleSignature(signedTransaction); err != nil { + return nil, fmt.Errorf("Apple transaction JWS signature verification failed: %w", err) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, fmt.Errorf("failed to decode payload: %w", err) @@ -231,13 +248,20 @@ func (h *SubscriptionWebhookHandler) decodeAppleTransaction(signedTransaction st return &info, nil } -// decodeAppleRenewalInfo decodes signed renewal info JWS +// decodeAppleRenewalInfo decodes and verifies a signed renewal info JWS. +// The inner JWS signature is verified using the same Apple certificate chain +// validation as the outer notification payload. func (h *SubscriptionWebhookHandler) decodeAppleRenewalInfo(signedRenewal string) (*AppleRenewalInfo, error) { parts := strings.Split(signedRenewal, ".") if len(parts) != 3 { return nil, fmt.Errorf("invalid JWS format") } + // S-16: Verify the inner JWS signature for signedRenewalInfo. + if err := h.VerifyAppleSignature(signedRenewal); err != nil { + return nil, fmt.Errorf("Apple renewal JWS signature verification failed: %w", err) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, fmt.Errorf("failed to decode payload: %w", err) @@ -484,7 +508,12 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error { return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"}) } - body, err := io.ReadAll(c.Request().Body) + // C-01: Verify the Google Pub/Sub push authentication token before processing + if !h.VerifyGooglePubSubToken(c) { + return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "unauthorized"}) + } + + body, err := io.ReadAll(io.LimitReader(c.Request().Body, maxWebhookBodySize)) if err != nil { log.Error().Err(err).Msg("Google Webhook: Failed to read body") return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"}) @@ -781,7 +810,7 @@ func (h *SubscriptionWebhookHandler) HandleStripeWebhook(c echo.Context) error { return c.JSON(http.StatusOK, map[string]interface{}{"status": "not_configured"}) } - body, err := io.ReadAll(c.Request().Body) + body, err := io.ReadAll(io.LimitReader(c.Request().Body, maxWebhookBodySize)) if err != nil { log.Error().Err(err).Msg("Stripe Webhook: Failed to read body") return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"}) @@ -884,10 +913,109 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string) return nil } +// googleOIDCCertsURL is the endpoint that serves Google's public OAuth2 +// certificates used to verify JWTs issued by accounts.google.com (including +// Pub/Sub push tokens). +const googleOIDCCertsURL = "https://www.googleapis.com/oauth2/v3/certs" + +// googleJWKSCache caches the fetched Google JWKS keys so we don't hit the +// network on every webhook request. Keys are refreshed when the cache expires. +var ( + googleJWKSCache map[string]*rsa.PublicKey + googleJWKSCacheMu sync.RWMutex + googleJWKSCacheTime time.Time + googleJWKSCacheTTL = 1 * time.Hour +) + +// googleJWKSResponse represents the JSON Web Key Set response from Google. +type googleJWKSResponse struct { + Keys []googleJWK `json:"keys"` +} + +// googleJWK represents a single JSON Web Key from Google's OIDC endpoint. +type googleJWK struct { + Kid string `json:"kid"` // Key ID + Kty string `json:"kty"` // Key type (RSA) + Alg string `json:"alg"` // Algorithm (RS256) + N string `json:"n"` // RSA modulus (base64url) + E string `json:"e"` // RSA exponent (base64url) +} + +// fetchGoogleOIDCKeys fetches Google's public OIDC keys from their well-known +// endpoint, returning a map of key-id to RSA public key. Results are cached +// for googleJWKSCacheTTL to avoid excessive network calls. +func fetchGoogleOIDCKeys() (map[string]*rsa.PublicKey, error) { + googleJWKSCacheMu.RLock() + if googleJWKSCache != nil && time.Since(googleJWKSCacheTime) < googleJWKSCacheTTL { + cached := googleJWKSCache + googleJWKSCacheMu.RUnlock() + return cached, nil + } + googleJWKSCacheMu.RUnlock() + + googleJWKSCacheMu.Lock() + defer googleJWKSCacheMu.Unlock() + + // Double-check after acquiring write lock + if googleJWKSCache != nil && time.Since(googleJWKSCacheTime) < googleJWKSCacheTTL { + return googleJWKSCache, nil + } + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Get(googleOIDCCertsURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch Google OIDC keys: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Google OIDC keys endpoint returned status %d", resp.StatusCode) + } + + var jwks googleJWKSResponse + if err := json.NewDecoder(io.LimitReader(resp.Body, maxWebhookBodySize)).Decode(&jwks); err != nil { + return nil, fmt.Errorf("failed to decode Google OIDC keys: %w", err) + } + + keys := make(map[string]*rsa.PublicKey, len(jwks.Keys)) + for _, k := range jwks.Keys { + if k.Kty != "RSA" { + continue + } + nBytes, err := base64.RawURLEncoding.DecodeString(k.N) + if err != nil { + log.Warn().Str("kid", k.Kid).Err(err).Msg("Google OIDC: failed to decode modulus") + continue + } + eBytes, err := base64.RawURLEncoding.DecodeString(k.E) + if err != nil { + log.Warn().Str("kid", k.Kid).Err(err).Msg("Google OIDC: failed to decode exponent") + continue + } + + n := new(big.Int).SetBytes(nBytes) + e := 0 + for _, b := range eBytes { + e = e<<8 + int(b) + } + + keys[k.Kid] = &rsa.PublicKey{N: n, E: e} + } + + if len(keys) == 0 { + return nil, fmt.Errorf("no usable RSA keys found in Google OIDC response") + } + + googleJWKSCache = keys + googleJWKSCacheTime = time.Now() + return keys, nil +} + // VerifyGooglePubSubToken verifies the Pub/Sub push authentication token. -// Returns false (deny) when the Authorization header is missing or the token -// cannot be validated. This prevents unauthenticated callers from injecting -// webhook events. +// The token is a JWT signed by Google (accounts.google.com). This function +// verifies the signature against Google's published OIDC public keys, checks +// the issuer claim, and validates the email claim is a Google service account. +// Returns false (deny) when verification fails for any reason. func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) bool { authHeader := c.Request().Header.Get("Authorization") if authHeader == "" { @@ -907,12 +1035,52 @@ func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) boo return false } - // Parse the token as a JWT. Google Pub/Sub push tokens are signed JWTs - // issued by accounts.google.com. We verify the claims to ensure the - // token was intended for our service. - token, _, err := jwt.NewParser().ParseUnverified(bearerToken, jwt.MapClaims{}) + // Fetch Google's OIDC public keys for signature verification + googleKeys, err := fetchGoogleOIDCKeys() if err != nil { - log.Warn().Err(err).Msg("Google Webhook: failed to parse Bearer token") + log.Error().Err(err).Msg("Google Webhook: failed to fetch OIDC keys, denying request") + return false + } + + // Parse and verify the JWT signature against Google's published public keys + token, err := jwt.Parse(bearerToken, func(token *jwt.Token) (interface{}, error) { + // Ensure the signing method is RSA (Google uses RS256) + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + kid, _ := token.Header["kid"].(string) + if kid == "" { + return nil, fmt.Errorf("missing kid header in token") + } + + key, ok := googleKeys[kid] + if !ok { + // Key may have rotated; try refreshing once + googleJWKSCacheMu.Lock() + googleJWKSCacheTime = time.Time{} // Force refresh + googleJWKSCacheMu.Unlock() + + refreshedKeys, err := fetchGoogleOIDCKeys() + if err != nil { + return nil, fmt.Errorf("failed to refresh Google OIDC keys: %w", err) + } + key, ok = refreshedKeys[kid] + if !ok { + return nil, fmt.Errorf("unknown key ID: %s", kid) + } + } + + return key, nil + }, jwt.WithValidMethods([]string{"RS256"})) + + if err != nil { + log.Warn().Err(err).Msg("Google Webhook: JWT signature verification failed") + return false + } + + if !token.Valid { + log.Warn().Msg("Google Webhook: token is invalid after verification") return false } diff --git a/internal/handlers/task_handler.go b/internal/handlers/task_handler.go index 32df60c..0fea817 100644 --- a/internal/handlers/task_handler.go +++ b/internal/handlers/task_handler.go @@ -38,19 +38,29 @@ func (h *TaskHandler) ListTasks(c echo.Context) error { userNow := middleware.GetUserNow(c) // Auto-capture timezone from header for background job calculations (e.g., daily digest) - // Runs synchronously — this is a lightweight DB upsert that should complete quickly + // Only write to DB if the timezone has actually changed from the cached value if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" { - h.taskService.UpdateUserTimezone(user.ID, tzHeader) + cachedTZ, _ := c.Get("user_timezone").(string) + if cachedTZ != tzHeader { + h.taskService.UpdateUserTimezone(user.ID, tzHeader) + c.Set("user_timezone", tzHeader) + } } daysThreshold := 30 // Support "days" param first, fall back to "days_threshold" for backward compatibility if d := c.QueryParam("days"); d != "" { if parsed, err := strconv.Atoi(d); err == nil { + if parsed < 1 || parsed > 3650 { + return apperrors.BadRequest("error.days_out_of_range") + } daysThreshold = parsed } } else if d := c.QueryParam("days_threshold"); d != "" { if parsed, err := strconv.Atoi(d); err == nil { + if parsed < 1 || parsed > 3650 { + return apperrors.BadRequest("error.days_out_of_range") + } daysThreshold = parsed } } @@ -97,10 +107,16 @@ func (h *TaskHandler) GetTasksByResidence(c echo.Context) error { // Support "days" param first, fall back to "days_threshold" for backward compatibility if d := c.QueryParam("days"); d != "" { if parsed, err := strconv.Atoi(d); err == nil { + if parsed < 1 || parsed > 3650 { + return apperrors.BadRequest("error.days_out_of_range") + } daysThreshold = parsed } } else if d := c.QueryParam("days_threshold"); d != "" { if parsed, err := strconv.Atoi(d); err == nil { + if parsed < 1 || parsed > 3650 { + return apperrors.BadRequest("error.days_out_of_range") + } daysThreshold = parsed } } diff --git a/internal/handlers/upload_handler.go b/internal/handlers/upload_handler.go index 9532c21..88728ec 100644 --- a/internal/handlers/upload_handler.go +++ b/internal/handlers/upload_handler.go @@ -7,19 +7,31 @@ import ( "github.com/rs/zerolog/log" "github.com/treytartt/honeydue-api/internal/apperrors" + "github.com/treytartt/honeydue-api/internal/dto/responses" + "github.com/treytartt/honeydue-api/internal/i18n" "github.com/treytartt/honeydue-api/internal/middleware" - "github.com/treytartt/honeydue-api/internal/models" "github.com/treytartt/honeydue-api/internal/services" ) +// FileOwnershipChecker verifies whether a user owns a file referenced by URL. +// Implementations should check associated records (e.g., task completion images, +// document files, document images) to determine ownership. +type FileOwnershipChecker interface { + IsFileOwnedByUser(fileURL string, userID uint) (bool, error) +} + // UploadHandler handles file upload endpoints type UploadHandler struct { - storageService *services.StorageService + storageService *services.StorageService + fileOwnershipChecker FileOwnershipChecker } // NewUploadHandler creates a new upload handler -func NewUploadHandler(storageService *services.StorageService) *UploadHandler { - return &UploadHandler{storageService: storageService} +func NewUploadHandler(storageService *services.StorageService, fileOwnershipChecker FileOwnershipChecker) *UploadHandler { + return &UploadHandler{ + storageService: storageService, + fileOwnershipChecker: fileOwnershipChecker, + } } // UploadImage handles POST /api/uploads/image @@ -83,13 +95,14 @@ type DeleteFileRequest struct { // DeleteFile handles DELETE /api/uploads // Expects JSON body with "url" field. -// -// TODO(SEC-18): Add ownership verification. Currently any authenticated user can delete -// any file if they know the URL. The upload system does not track which user uploaded -// which file, so a proper fix requires adding an uploads table or file ownership metadata. -// For now, deletions are logged with user ID for audit trail, and StorageService.Delete -// enforces path containment to prevent deleting files outside the upload directory. +// Verifies that the requesting user owns the file by checking associated records +// (task completion images, document files/images) before allowing deletion. func (h *UploadHandler) DeleteFile(c echo.Context) error { + user, err := middleware.MustGetAuthUser(c) + if err != nil { + return err + } + var req DeleteFileRequest if err := c.Bind(&req); err != nil { @@ -100,17 +113,28 @@ func (h *UploadHandler) DeleteFile(c echo.Context) error { return apperrors.BadRequest("error.url_required") } - // Log the deletion with user ID for audit trail - if user, ok := c.Get(middleware.AuthUserKey).(*models.User); ok { - log.Info(). - Uint("user_id", user.ID). - Str("file_url", req.URL). - Msg("File deletion requested") + // Verify ownership: the user must own a record that references this file URL + if h.fileOwnershipChecker != nil { + owned, err := h.fileOwnershipChecker.IsFileOwnedByUser(req.URL, user.ID) + if err != nil { + log.Error().Err(err).Uint("user_id", user.ID).Str("file_url", req.URL).Msg("Failed to check file ownership") + return apperrors.Internal(err) + } + if !owned { + log.Warn().Uint("user_id", user.ID).Str("file_url", req.URL).Msg("Unauthorized file deletion attempt") + return apperrors.Forbidden("error.file_access_denied") + } } + // Log the deletion with user ID for audit trail + log.Info(). + Uint("user_id", user.ID). + Str("file_url", req.URL). + Msg("File deletion requested") + if err := h.storageService.Delete(req.URL); err != nil { return err } - return c.JSON(http.StatusOK, map[string]interface{}{"message": "File deleted successfully"}) + return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.file_deleted")}) } diff --git a/internal/handlers/upload_handler_test.go b/internal/handlers/upload_handler_test.go index a9ed33b..e3055e4 100644 --- a/internal/handlers/upload_handler_test.go +++ b/internal/handlers/upload_handler_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/treytartt/honeydue-api/internal/i18n" + "github.com/treytartt/honeydue-api/internal/models" "github.com/treytartt/honeydue-api/internal/testutil" ) @@ -18,12 +19,16 @@ func init() { func TestDeleteFile_MissingURL_Returns400(t *testing.T) { // Use a test storage service — DeleteFile won't reach storage since validation fails first storageSvc := newTestStorageService("/var/uploads") - handler := NewUploadHandler(storageSvc) + handler := NewUploadHandler(storageSvc, nil) e := testutil.SetupTestRouter() - // Register route - e.DELETE("/api/uploads/", handler.DeleteFile) + // Register route with mock auth middleware + testUser := &models.User{FirstName: "Test", Email: "test@test.com"} + testUser.ID = 1 + authGroup := e.Group("/api") + authGroup.Use(testutil.MockAuthMiddleware(testUser)) + authGroup.DELETE("/uploads/", handler.DeleteFile) // Send request with empty JSON body (url field missing) w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{}, "test-token") @@ -32,10 +37,16 @@ func TestDeleteFile_MissingURL_Returns400(t *testing.T) { func TestDeleteFile_EmptyURL_Returns400(t *testing.T) { storageSvc := newTestStorageService("/var/uploads") - handler := NewUploadHandler(storageSvc) + handler := NewUploadHandler(storageSvc, nil) e := testutil.SetupTestRouter() - e.DELETE("/api/uploads/", handler.DeleteFile) + + // Register route with mock auth middleware + testUser := &models.User{FirstName: "Test", Email: "test@test.com"} + testUser.ID = 1 + authGroup := e.Group("/api") + authGroup.Use(testutil.MockAuthMiddleware(testUser)) + authGroup.DELETE("/uploads/", handler.DeleteFile) // Send request with empty url field w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{"url": ""}, "test-token") diff --git a/internal/i18n/translations/en.json b/internal/i18n/translations/en.json index 2adb7af..30564b3 100644 --- a/internal/i18n/translations/en.json +++ b/internal/i18n/translations/en.json @@ -111,6 +111,11 @@ "error.purchase_token_required": "purchase_token is required for Android", "error.no_file_provided": "No file provided", + "error.url_required": "File URL is required", + "error.file_access_denied": "You don't have access to this file", + "error.days_out_of_range": "Days parameter must be between 1 and 3650", + "error.platform_required": "Platform is required (ios or android)", + "error.registration_id_required": "Registration ID is required", "error.failed_to_fetch_residence_types": "Failed to fetch residence types", "error.failed_to_fetch_task_categories": "Failed to fetch task categories", diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 45c19d9..df4a6b7 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -25,19 +25,24 @@ const ( TokenCacheTTL = 5 * time.Minute // TokenCachePrefix is the prefix for token cache keys TokenCachePrefix = "auth_token_" + // UserCacheTTL is how long full user records are cached in memory to + // avoid hitting the database on every authenticated request. + UserCacheTTL = 30 * time.Second ) // AuthMiddleware provides token authentication middleware type AuthMiddleware struct { - db *gorm.DB - cache *services.CacheService + db *gorm.DB + cache *services.CacheService + userCache *UserCache } // NewAuthMiddleware creates a new auth middleware instance func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware { return &AuthMiddleware{ - db: db, - cache: cache, + db: db, + cache: cache, + userCache: NewUserCache(UserCacheTTL), } } @@ -138,7 +143,8 @@ func extractToken(c echo.Context) (string, error) { return token, nil } -// getUserFromCache tries to get user from Redis cache +// getUserFromCache tries to get user from Redis cache, then from the +// in-memory user cache, before falling back to the database. func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) { if m.cache == nil { return nil, fmt.Errorf("cache not available") @@ -152,10 +158,20 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m return nil, err } - // Get user from database by ID + // Try in-memory user cache first to avoid a DB round-trip + if cached := m.userCache.Get(userID); cached != nil { + if !cached.IsActive { + m.cache.InvalidateAuthToken(ctx, token) + m.userCache.Invalidate(userID) + return nil, fmt.Errorf("user is inactive") + } + return cached, nil + } + + // In-memory cache miss — fetch from database var user models.User if err := m.db.First(&user, userID).Error; err != nil { - // User was deleted - invalidate cache + // User was deleted - invalidate caches m.cache.InvalidateAuthToken(ctx, token) return nil, err } @@ -166,10 +182,13 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m return nil, fmt.Errorf("user is inactive") } + // Store in in-memory cache for subsequent requests + m.userCache.Set(&user) return &user, nil } -// getUserFromDatabase looks up the token in the database +// getUserFromDatabase looks up the token in the database and caches the +// resulting user record in memory. func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) { var authToken models.AuthToken if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil { @@ -181,6 +200,8 @@ func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) return nil, fmt.Errorf("user is inactive") } + // Store in in-memory cache for subsequent requests + m.userCache.Set(&authToken.User) return &authToken.User, nil } @@ -220,7 +241,11 @@ func GetAuthToken(c echo.Context) string { if token == nil { return "" } - return token.(string) + tokenStr, ok := token.(string) + if !ok { + return "" + } + return tokenStr } // MustGetAuthUser retrieves the authenticated user or returns error with 401 diff --git a/internal/middleware/host_check.go b/internal/middleware/host_check.go new file mode 100644 index 0000000..7836934 --- /dev/null +++ b/internal/middleware/host_check.go @@ -0,0 +1,40 @@ +package middleware + +import ( + "net/http" + + "github.com/labstack/echo/v4" + + "github.com/treytartt/honeydue-api/internal/dto/responses" +) + +// HostCheck returns middleware that validates the request Host header against +// a set of allowed hosts. This prevents SSRF attacks where an attacker crafts +// a request with an arbitrary Host header to reach internal services via the +// reverse proxy. +// +// If allowedHosts is empty the middleware is a no-op (all hosts pass). +func HostCheck(allowedHosts []string) echo.MiddlewareFunc { + allowed := make(map[string]struct{}, len(allowedHosts)) + for _, h := range allowedHosts { + allowed[h] = struct{}{} + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + // If no allowed hosts configured, skip the check + if len(allowed) == 0 { + return next(c) + } + + host := c.Request().Host + if _, ok := allowed[host]; !ok { + return c.JSON(http.StatusForbidden, responses.ErrorResponse{ + Error: "Forbidden", + }) + } + + return next(c) + } + } +} diff --git a/internal/middleware/rate_limit.go b/internal/middleware/rate_limit.go new file mode 100644 index 0000000..a46cb5a --- /dev/null +++ b/internal/middleware/rate_limit.go @@ -0,0 +1,68 @@ +package middleware + +import ( + "net/http" + "time" + + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "golang.org/x/time/rate" + + "github.com/treytartt/honeydue-api/internal/dto/responses" +) + +// AuthRateLimiter returns rate-limiting middleware tuned for authentication +// endpoints. It uses Echo's built-in in-memory rate limiter keyed by client +// IP address. +// +// Parameters: +// - ratePerSecond: sustained request rate (e.g., 10/60.0 for ~10 per minute) +// - burst: maximum burst size above the sustained rate +func AuthRateLimiter(ratePerSecond rate.Limit, burst int) echo.MiddlewareFunc { + store := middleware.NewRateLimiterMemoryStoreWithConfig( + middleware.RateLimiterMemoryStoreConfig{ + Rate: ratePerSecond, + Burst: burst, + ExpiresIn: 5 * time.Minute, + }, + ) + + return middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{ + Skipper: middleware.DefaultSkipper, + IdentifierExtractor: func(c echo.Context) (string, error) { + return c.RealIP(), nil + }, + Store: store, + DenyHandler: func(c echo.Context, _ string, _ error) error { + return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{ + Error: "Too many requests. Please try again later.", + }) + }, + ErrorHandler: func(c echo.Context, err error) error { + return c.JSON(http.StatusForbidden, responses.ErrorResponse{ + Error: "Unable to process request.", + }) + }, + }) +} + +// LoginRateLimiter returns rate-limiting middleware for login endpoints. +// Allows 10 requests per minute with a burst of 5. +func LoginRateLimiter() echo.MiddlewareFunc { + // 10 requests per 60 seconds = ~0.167 req/s, burst 5 + return AuthRateLimiter(rate.Limit(10.0/60.0), 5) +} + +// RegistrationRateLimiter returns rate-limiting middleware for registration +// endpoints. Allows 5 requests per minute with a burst of 3. +func RegistrationRateLimiter() echo.MiddlewareFunc { + // 5 requests per 60 seconds = ~0.083 req/s, burst 3 + return AuthRateLimiter(rate.Limit(5.0/60.0), 3) +} + +// PasswordResetRateLimiter returns rate-limiting middleware for password +// reset endpoints. Allows 3 requests per minute with a burst of 2. +func PasswordResetRateLimiter() echo.MiddlewareFunc { + // 3 requests per 60 seconds = 0.05 req/s, burst 2 + return AuthRateLimiter(rate.Limit(3.0/60.0), 2) +} diff --git a/internal/middleware/timezone.go b/internal/middleware/timezone.go index 1bd9d52..12a352f 100644 --- a/internal/middleware/timezone.go +++ b/internal/middleware/timezone.go @@ -7,14 +7,25 @@ import ( ) const ( - // TimezoneKey is the key used to store the user's timezone in the context + // TimezoneKey is the key used to store the user's timezone *time.Location in the context TimezoneKey = "user_timezone" + // TimezoneNameKey stores the raw IANA timezone string from the request header + TimezoneNameKey = "user_timezone_name" + // TimezoneChangedKey is a bool context key indicating whether the timezone + // differs from the previously cached value for this user. Handlers should + // only persist the timezone to DB when this is true. + TimezoneChangedKey = "timezone_changed" // UserNowKey is the key used to store the timezone-aware "now" time in the context UserNowKey = "user_now" // TimezoneHeader is the HTTP header name for the user's timezone TimezoneHeader = "X-Timezone" ) +// package-level timezone cache shared across requests. It is safe for +// concurrent use and has no TTL — entries are only updated when a new +// timezone value is observed for a given user. +var tzCache = NewTimezoneCache() + // TimezoneMiddleware extracts the user's timezone from the request header // and stores a timezone-aware "now" time in the context. // @@ -22,14 +33,31 @@ const ( // or a UTC offset (e.g., "-08:00", "+05:30"). // // If no timezone is provided or it's invalid, UTC is used as the default. +// +// The middleware also compares the incoming timezone with a cached value per +// user and sets TimezoneChangedKey in the context so downstream handlers +// know whether a DB write is needed. func TimezoneMiddleware() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { tzName := c.Request().Header.Get(TimezoneHeader) loc := parseTimezone(tzName) - // Store the location and the current time in that timezone + // Store the location and the raw name in the context c.Set(TimezoneKey, loc) + c.Set(TimezoneNameKey, tzName) + + // Determine whether the timezone changed for this user so handlers + // can skip unnecessary DB writes. + changed := false + if tzName != "" { + if user := GetAuthUser(c); user != nil { + if !tzCache.GetAndCompare(user.ID, tzName) { + changed = true + } + } + } + c.Set(TimezoneChangedKey, changed) // Calculate "now" in the user's timezone, then get start of day // For date comparisons, we want to compare against the START of the user's current day @@ -42,6 +70,20 @@ func TimezoneMiddleware() echo.MiddlewareFunc { } } +// IsTimezoneChanged returns true when the user's timezone header differs from +// the previously observed value. Handlers should only persist the timezone to +// DB when this returns true. +func IsTimezoneChanged(c echo.Context) bool { + val, ok := c.Get(TimezoneChangedKey).(bool) + return ok && val +} + +// GetTimezoneName returns the raw timezone string from the request header. +func GetTimezoneName(c echo.Context) string { + val, _ := c.Get(TimezoneNameKey).(string) + return val +} + // parseTimezone parses a timezone string and returns a *time.Location. // Supports IANA timezone names (e.g., "America/Los_Angeles") and // UTC offsets (e.g., "-08:00", "+05:30"). diff --git a/internal/middleware/user_cache.go b/internal/middleware/user_cache.go new file mode 100644 index 0000000..c442a4a --- /dev/null +++ b/internal/middleware/user_cache.go @@ -0,0 +1,113 @@ +package middleware + +import ( + "sync" + "time" + + "github.com/treytartt/honeydue-api/internal/models" +) + +// userCacheEntry holds a cached user record with an expiration time. +type userCacheEntry struct { + user *models.User + expiresAt time.Time +} + +// UserCache is a concurrency-safe in-memory cache for User records, keyed by +// user ID. Entries expire after a configurable TTL. The cache uses a sync.Map +// for lock-free reads on the hot path, with periodic lazy eviction of stale +// entries during Set operations. +type UserCache struct { + store sync.Map + ttl time.Duration + lastGC time.Time + gcMu sync.Mutex + gcEvery time.Duration +} + +// NewUserCache creates a UserCache with the given TTL for entries. +func NewUserCache(ttl time.Duration) *UserCache { + return &UserCache{ + ttl: ttl, + lastGC: time.Now(), + gcEvery: 2 * time.Minute, + } +} + +// Get returns a cached user by ID, or nil if not found or expired. +func (c *UserCache) Get(userID uint) *models.User { + val, ok := c.store.Load(userID) + if !ok { + return nil + } + entry := val.(*userCacheEntry) + if time.Now().After(entry.expiresAt) { + c.store.Delete(userID) + return nil + } + // Return a shallow copy so callers cannot mutate the cached value. + user := *entry.user + return &user +} + +// Set stores a user in the cache. It also triggers a background garbage- +// collection sweep if enough time has elapsed since the last one. +func (c *UserCache) Set(user *models.User) { + // Store a copy to prevent external mutation of the cached object. + copied := *user + c.store.Store(user.ID, &userCacheEntry{ + user: &copied, + expiresAt: time.Now().Add(c.ttl), + }) + c.maybeGC() +} + +// Invalidate removes a user from the cache by ID. +func (c *UserCache) Invalidate(userID uint) { + c.store.Delete(userID) +} + +// maybeGC lazily sweeps expired entries at most once per gcEvery interval. +func (c *UserCache) maybeGC() { + c.gcMu.Lock() + if time.Since(c.lastGC) < c.gcEvery { + c.gcMu.Unlock() + return + } + c.lastGC = time.Now() + c.gcMu.Unlock() + + now := time.Now() + c.store.Range(func(key, value any) bool { + entry := value.(*userCacheEntry) + if now.After(entry.expiresAt) { + c.store.Delete(key) + } + return true + }) +} + +// TimezoneCache tracks the last-known timezone per user ID so the timezone +// middleware only writes to the database when the value actually changes. +type TimezoneCache struct { + store sync.Map +} + +// NewTimezoneCache creates a new TimezoneCache. +func NewTimezoneCache() *TimezoneCache { + return &TimezoneCache{} +} + +// GetAndCompare returns true if the cached timezone for the user matches tz. +// If the timezone is different (or not yet cached), it updates the cache and +// returns false, signaling that a DB write is needed. +func (tc *TimezoneCache) GetAndCompare(userID uint, tz string) (unchanged bool) { + val, loaded := tc.store.Load(userID) + if loaded { + if cached, ok := val.(string); ok && cached == tz { + return true + } + } + tc.store.Store(userID, tz) + return false +} diff --git a/internal/models/base.go b/internal/models/base.go index ec85510..70b84a5 100644 --- a/internal/models/base.go +++ b/internal/models/base.go @@ -13,12 +13,6 @@ type BaseModel struct { UpdatedAt time.Time `json:"updated_at"` } -// SoftDeleteModel extends BaseModel with soft delete support -type SoftDeleteModel struct { - BaseModel - DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` -} - // BeforeCreate sets timestamps before creating a record func (b *BaseModel) BeforeCreate(tx *gorm.DB) error { now := time.Now().UTC() diff --git a/internal/models/notification.go b/internal/models/notification.go index 43f62ad..280144b 100644 --- a/internal/models/notification.go +++ b/internal/models/notification.go @@ -29,10 +29,10 @@ type NotificationPreference struct { // Custom notification times (nullable, stored as UTC hour 0-23) // When nil, system defaults from config are used - TaskDueSoonHour *int `gorm:"column:task_due_soon_hour" json:"task_due_soon_hour"` - TaskOverdueHour *int `gorm:"column:task_overdue_hour" json:"task_overdue_hour"` - WarrantyExpiringHour *int `gorm:"column:warranty_expiring_hour" json:"warranty_expiring_hour"` - DailyDigestHour *int `gorm:"column:daily_digest_hour" json:"daily_digest_hour"` + TaskDueSoonHour *int `gorm:"column:task_due_soon_hour" json:"task_due_soon_hour" validate:"omitempty,min=0,max=23"` + TaskOverdueHour *int `gorm:"column:task_overdue_hour" json:"task_overdue_hour" validate:"omitempty,min=0,max=23"` + WarrantyExpiringHour *int `gorm:"column:warranty_expiring_hour" json:"warranty_expiring_hour" validate:"omitempty,min=0,max=23"` + DailyDigestHour *int `gorm:"column:daily_digest_hour" json:"daily_digest_hour" validate:"omitempty,min=0,max=23"` // User timezone for background job calculations (IANA name, e.g., "America/Los_Angeles") // Auto-captured from X-Timezone header on API calls diff --git a/internal/models/task.go b/internal/models/task.go index 72f3241..66f8195 100644 --- a/internal/models/task.go +++ b/internal/models/task.go @@ -181,88 +181,6 @@ func (t *Task) IsDueSoon(days int) bool { return !effectiveDate.Before(now) && effectiveDate.Before(threshold) } -// GetKanbanColumn returns the kanban column name for this task using the -// Chain of Responsibility pattern from the categorization package. -// Uses UTC time for categorization. -// -// For timezone-aware categorization, use GetKanbanColumnWithTimezone. -func (t *Task) GetKanbanColumn(daysThreshold int) string { - // Import would cause circular dependency, so we inline the logic - // This delegates to the categorization package via internal/task re-export - return t.GetKanbanColumnWithTimezone(daysThreshold, time.Now().UTC()) -} - -// GetKanbanColumnWithTimezone returns the kanban column name using a specific -// time (in the user's timezone). The time is used to determine "today" for -// overdue/due-soon calculations. -// -// Example: For a user in Tokyo, pass time.Now().In(tokyoLocation) to get -// accurate categorization relative to their local date. -func (t *Task) GetKanbanColumnWithTimezone(daysThreshold int, now time.Time) string { - // Note: We can't import categorization directly due to circular dependency. - // Instead, this method implements the categorization logic inline. - // The logic MUST match categorization.Chain exactly. - - if daysThreshold <= 0 { - daysThreshold = 30 - } - - // Start of day normalization - startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) - threshold := startOfDay.AddDate(0, 0, daysThreshold) - - // Priority 1: Cancelled - if t.IsCancelled { - return "cancelled_tasks" - } - - // Priority 2: Archived (goes to cancelled column - both are "inactive" states) - if t.IsArchived { - return "cancelled_tasks" - } - - // Priority 3: Completed (NextDueDate nil with completions) - hasCompletions := len(t.Completions) > 0 || t.CompletionCount > 0 - if t.NextDueDate == nil && hasCompletions { - return "completed_tasks" - } - - // Priority 4: In Progress - if t.InProgress { - return "in_progress_tasks" - } - - // Get effective date: NextDueDate ?? DueDate - var effectiveDate *time.Time - if t.NextDueDate != nil { - effectiveDate = t.NextDueDate - } else { - effectiveDate = t.DueDate - } - - if effectiveDate != nil { - // Normalize effective date to same timezone for calendar date comparison - // Task dates are stored as UTC but represent calendar dates (YYYY-MM-DD) - normalizedEffective := time.Date( - effectiveDate.Year(), effectiveDate.Month(), effectiveDate.Day(), - 0, 0, 0, 0, now.Location(), - ) - - // Priority 5: Overdue (effective date before today) - if normalizedEffective.Before(startOfDay) { - return "overdue_tasks" - } - - // Priority 6: Due Soon (effective date before threshold) - if normalizedEffective.Before(threshold) { - return "due_soon_tasks" - } - } - - // Priority 7: Upcoming (default) - return "upcoming_tasks" -} - // TaskCompletion represents the task_taskcompletion table type TaskCompletion struct { BaseModel diff --git a/internal/models/task_test.go b/internal/models/task_test.go index 4b96e44..594d362 100644 --- a/internal/models/task_test.go +++ b/internal/models/task_test.go @@ -248,180 +248,10 @@ func TestDocument_JSONSerialization(t *testing.T) { assert.Equal(t, "5000", result["purchase_price"]) // Decimal serializes as string } -// ============================================================================ -// TASK KANBAN COLUMN TESTS -// These tests verify GetKanbanColumn and GetKanbanColumnWithTimezone methods -// ============================================================================ - func timePtr(t time.Time) *time.Time { return &t } -func TestTask_GetKanbanColumn_PriorityOrder(t *testing.T) { - now := time.Date(2025, 12, 16, 12, 0, 0, 0, time.UTC) - yesterday := time.Date(2025, 12, 15, 0, 0, 0, 0, time.UTC) - in5Days := time.Date(2025, 12, 21, 0, 0, 0, 0, time.UTC) - in60Days := time.Date(2026, 2, 14, 0, 0, 0, 0, time.UTC) - - tests := []struct { - name string - task *Task - expected string - }{ - // Priority 1: Cancelled - { - name: "cancelled takes highest priority", - task: &Task{ - IsCancelled: true, - NextDueDate: timePtr(yesterday), - InProgress: true, - }, - expected: "cancelled_tasks", - }, - - // Priority 2: Completed - { - name: "completed: NextDueDate nil with completions", - task: &Task{ - IsCancelled: false, - NextDueDate: nil, - DueDate: timePtr(yesterday), - Completions: []TaskCompletion{{BaseModel: BaseModel{ID: 1}}}, - }, - expected: "completed_tasks", - }, - - // Priority 3: In Progress - { - name: "in progress takes priority over overdue", - task: &Task{ - IsCancelled: false, - NextDueDate: timePtr(yesterday), - InProgress: true, - }, - expected: "in_progress_tasks", - }, - - // Priority 4: Overdue - { - name: "overdue: effective date in past", - task: &Task{ - IsCancelled: false, - NextDueDate: timePtr(yesterday), - }, - expected: "overdue_tasks", - }, - - // Priority 5: Due Soon - { - name: "due soon: within 30-day threshold", - task: &Task{ - IsCancelled: false, - NextDueDate: timePtr(in5Days), - }, - expected: "due_soon_tasks", - }, - - // Priority 6: Upcoming - { - name: "upcoming: beyond threshold", - task: &Task{ - IsCancelled: false, - NextDueDate: timePtr(in60Days), - }, - expected: "upcoming_tasks", - }, - { - name: "upcoming: no due date", - task: &Task{ - IsCancelled: false, - NextDueDate: nil, - DueDate: nil, - }, - expected: "upcoming_tasks", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.task.GetKanbanColumnWithTimezone(30, now) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestTask_GetKanbanColumnWithTimezone_TimezoneAware(t *testing.T) { - // Task due Dec 17, 2025 - taskDueDate := time.Date(2025, 12, 17, 0, 0, 0, 0, time.UTC) - - task := &Task{ - NextDueDate: timePtr(taskDueDate), - IsCancelled: false, - } - - // At 11 PM UTC on Dec 16 (UTC user) - task is tomorrow, due_soon - utcDec16Evening := time.Date(2025, 12, 16, 23, 0, 0, 0, time.UTC) - result := task.GetKanbanColumnWithTimezone(30, utcDec16Evening) - assert.Equal(t, "due_soon_tasks", result, "UTC Dec 16 evening") - - // At 8 AM UTC on Dec 17 (UTC user) - task is today, due_soon - utcDec17Morning := time.Date(2025, 12, 17, 8, 0, 0, 0, time.UTC) - result = task.GetKanbanColumnWithTimezone(30, utcDec17Morning) - assert.Equal(t, "due_soon_tasks", result, "UTC Dec 17 morning") - - // At 8 AM UTC on Dec 18 (UTC user) - task was yesterday, overdue - utcDec18Morning := time.Date(2025, 12, 18, 8, 0, 0, 0, time.UTC) - result = task.GetKanbanColumnWithTimezone(30, utcDec18Morning) - assert.Equal(t, "overdue_tasks", result, "UTC Dec 18 morning") - - // Tokyo user at 11 PM UTC Dec 16 = 8 AM Dec 17 Tokyo - // Task due Dec 17 is TODAY for Tokyo user - due_soon - tokyo, _ := time.LoadLocation("Asia/Tokyo") - tokyoDec17Morning := utcDec16Evening.In(tokyo) - result = task.GetKanbanColumnWithTimezone(30, tokyoDec17Morning) - assert.Equal(t, "due_soon_tasks", result, "Tokyo Dec 17 morning") - - // Tokyo at 8 AM Dec 18 UTC = 5 PM Dec 18 Tokyo - // Task due Dec 17 was YESTERDAY for Tokyo - overdue - tokyoDec18 := utcDec18Morning.In(tokyo) - result = task.GetKanbanColumnWithTimezone(30, tokyoDec18) - assert.Equal(t, "overdue_tasks", result, "Tokyo Dec 18") -} - -func TestTask_GetKanbanColumnWithTimezone_DueSoonThreshold(t *testing.T) { - now := time.Date(2025, 12, 16, 12, 0, 0, 0, time.UTC) - - // Task due in 29 days - within 30-day threshold - due29Days := time.Date(2026, 1, 14, 0, 0, 0, 0, time.UTC) - task29 := &Task{NextDueDate: timePtr(due29Days)} - result := task29.GetKanbanColumnWithTimezone(30, now) - assert.Equal(t, "due_soon_tasks", result, "29 days should be due_soon") - - // Task due in exactly 30 days - at threshold boundary (upcoming, not due_soon) - due30Days := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC) - task30 := &Task{NextDueDate: timePtr(due30Days)} - result = task30.GetKanbanColumnWithTimezone(30, now) - assert.Equal(t, "upcoming_tasks", result, "30 days should be upcoming (at boundary)") - - // Task due in 31 days - beyond threshold - due31Days := time.Date(2026, 1, 16, 0, 0, 0, 0, time.UTC) - task31 := &Task{NextDueDate: timePtr(due31Days)} - result = task31.GetKanbanColumnWithTimezone(30, now) - assert.Equal(t, "upcoming_tasks", result, "31 days should be upcoming") -} - -func TestTask_GetKanbanColumn_CompletionCount(t *testing.T) { - // Test that CompletionCount is also used for completion detection - task := &Task{ - NextDueDate: nil, - CompletionCount: 1, // Using CompletionCount instead of Completions slice - Completions: []TaskCompletion{}, - } - - result := task.GetKanbanColumn(30) - assert.Equal(t, "completed_tasks", result) -} - func TestTask_IsOverdueAt_DayBased(t *testing.T) { // Test that IsOverdueAt uses day-based comparison now := time.Date(2025, 12, 16, 15, 0, 0, 0, time.UTC) // 3 PM UTC diff --git a/internal/models/user.go b/internal/models/user.go index 91b293b..da066fe 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -18,7 +18,7 @@ type User struct { Username string `gorm:"column:username;uniqueIndex;size:150;not null" json:"username"` FirstName string `gorm:"column:first_name;size:150" json:"first_name"` LastName string `gorm:"column:last_name;size:150" json:"last_name"` - Email string `gorm:"column:email;size:254" json:"email"` + Email string `gorm:"column:email;uniqueIndex;size:254" json:"email"` IsStaff bool `gorm:"column:is_staff;default:false" json:"is_staff"` IsActive bool `gorm:"column:is_active;default:true" json:"is_active"` DateJoined time.Time `gorm:"column:date_joined;autoCreateTime" json:"date_joined"` @@ -142,7 +142,7 @@ func (UserProfile) TableName() string { type ConfirmationCode struct { BaseModel UserID uint `gorm:"column:user_id;index;not null" json:"user_id"` - Code string `gorm:"column:code;size:6;not null" json:"code"` + Code string `gorm:"column:code;size:6;not null" json:"-"` ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"` IsUsed bool `gorm:"column:is_used;default:false" json:"is_used"` diff --git a/internal/monitoring/collector.go b/internal/monitoring/collector.go index b8e750a..2fe3ee1 100644 --- a/internal/monitoring/collector.go +++ b/internal/monitoring/collector.go @@ -2,6 +2,7 @@ package monitoring import ( "runtime" + "sync" "time" "github.com/hibiken/asynq" @@ -20,6 +21,7 @@ type Collector struct { httpCollector *HTTPStatsCollector // nil for worker asynqClient *asynq.Inspector // nil for api stopChan chan struct{} + stopOnce sync.Once } // NewCollector creates a new stats collector @@ -193,7 +195,9 @@ func (c *Collector) publishStats() { } } -// Stop stops the stats publishing +// Stop stops the stats publishing. It is safe to call multiple times. func (c *Collector) Stop() { - close(c.stopChan) + c.stopOnce.Do(func() { + close(c.stopChan) + }) } diff --git a/internal/monitoring/handler.go b/internal/monitoring/handler.go index a028e1b..d90b349 100644 --- a/internal/monitoring/handler.go +++ b/internal/monitoring/handler.go @@ -150,7 +150,10 @@ func (h *Handler) WebSocket(c echo.Context) error { defer statsTicker.Stop() // Send initial stats - h.sendStats(conn, &wsMu) + if err := h.sendStats(conn, &wsMu); err != nil { + cancel() + return nil + } for { select { @@ -173,11 +176,16 @@ func (h *Handler) WebSocket(c echo.Context) error { if err != nil { log.Debug().Err(err).Msg("WebSocket write error") + cancel() + return nil } case <-statsTicker.C: // Send periodic stats update - h.sendStats(conn, &wsMu) + if err := h.sendStats(conn, &wsMu); err != nil { + cancel() + return nil + } case <-ctx.Done(): return nil @@ -185,9 +193,11 @@ func (h *Handler) WebSocket(c echo.Context) error { } } -func (h *Handler) sendStats(conn *websocket.Conn, mu *sync.Mutex) { +func (h *Handler) sendStats(conn *websocket.Conn, mu *sync.Mutex) error { allStats, err := h.statsStore.GetAllStats() if err != nil { + log.Error().Err(err).Msg("failed to send stats") + return err } wsMsg := WSMessage{ @@ -196,6 +206,12 @@ func (h *Handler) sendStats(conn *websocket.Conn, mu *sync.Mutex) { } mu.Lock() - conn.WriteJSON(wsMsg) + err = conn.WriteJSON(wsMsg) mu.Unlock() + + if err != nil { + log.Debug().Err(err).Msg("WebSocket write error sending stats") + } + + return err } diff --git a/internal/monitoring/middleware.go b/internal/monitoring/middleware.go index 93820d9..1800d59 100644 --- a/internal/monitoring/middleware.go +++ b/internal/monitoring/middleware.go @@ -10,39 +10,33 @@ import ( // HTTPStatsCollector collects HTTP request metrics type HTTPStatsCollector struct { - mu sync.RWMutex - requests map[string]int64 // endpoint -> count - totalLatency map[string]time.Duration // endpoint -> total latency - errors map[string]int64 // endpoint -> error count - byStatus map[int]int64 // status code -> count - latencies []latencySample // recent latency samples for P95 - startTime time.Time - lastReset time.Time -} - -type latencySample struct { - endpoint string - latency time.Duration - timestamp time.Time + mu sync.RWMutex + requests map[string]int64 // endpoint -> count + totalLatency map[string]time.Duration // endpoint -> total latency + errors map[string]int64 // endpoint -> error count + byStatus map[int]int64 // status code -> count + endpointLatencies map[string][]time.Duration // per-endpoint sorted latency buffers for P95 + startTime time.Time + lastReset time.Time } const ( - maxLatencySamples = 1000 - maxEndpoints = 200 // Cap unique endpoints tracked - statsResetPeriod = 1 * time.Hour // Reset stats periodically to prevent unbounded growth + maxLatencySamplesPerEndpoint = 200 // Max latency samples kept per endpoint + maxEndpoints = 200 // Cap unique endpoints tracked + statsResetPeriod = 1 * time.Hour // Reset stats periodically to prevent unbounded growth ) // NewHTTPStatsCollector creates a new HTTP stats collector func NewHTTPStatsCollector() *HTTPStatsCollector { now := time.Now() return &HTTPStatsCollector{ - requests: make(map[string]int64), - totalLatency: make(map[string]time.Duration), - errors: make(map[string]int64), - byStatus: make(map[int]int64), - latencies: make([]latencySample, 0, maxLatencySamples), - startTime: now, - lastReset: now, + requests: make(map[string]int64), + totalLatency: make(map[string]time.Duration), + errors: make(map[string]int64), + byStatus: make(map[int]int64), + endpointLatencies: make(map[string][]time.Duration), + startTime: now, + lastReset: now, } } @@ -70,17 +64,22 @@ func (c *HTTPStatsCollector) Record(endpoint string, latency time.Duration, stat c.errors[endpoint]++ } - // Store latency sample - c.latencies = append(c.latencies, latencySample{ - endpoint: endpoint, - latency: latency, - timestamp: time.Now(), + // Insert latency into per-endpoint sorted buffer using binary search + buf := c.endpointLatencies[endpoint] + idx := sort.Search(len(buf), func(i int) bool { + return buf[i] >= latency }) + buf = append(buf, 0) + copy(buf[idx+1:], buf[idx:]) + buf[idx] = latency - // Keep only recent samples - if len(c.latencies) > maxLatencySamples { - c.latencies = c.latencies[len(c.latencies)-maxLatencySamples:] + // Trim to max samples per endpoint by removing the median element + // to preserve distribution tails (important for P95 accuracy) + if len(buf) > maxLatencySamplesPerEndpoint { + mid := len(buf) / 2 + buf = append(buf[:mid], buf[mid+1:]...) } + c.endpointLatencies[endpoint] = buf } // resetLocked resets stats while holding the lock @@ -89,7 +88,7 @@ func (c *HTTPStatsCollector) resetLocked() { c.totalLatency = make(map[string]time.Duration) c.errors = make(map[string]int64) c.byStatus = make(map[int]int64) - c.latencies = make([]latencySample, 0, maxLatencySamples) + c.endpointLatencies = make(map[string][]time.Duration) c.lastReset = time.Now() // Keep startTime for uptime calculation } @@ -147,33 +146,23 @@ func (c *HTTPStatsCollector) GetStats() HTTPStats { return stats } -// calculateP95 calculates the 95th percentile latency for an endpoint -// Must be called with read lock held +// calculateP95 calculates the 95th percentile latency for an endpoint. +// The per-endpoint buffer is maintained in sorted order during insertion, +// so this is an O(1) index lookup. +// Must be called with read lock held. func (c *HTTPStatsCollector) calculateP95(endpoint string) float64 { - var endpointLatencies []time.Duration - - for _, sample := range c.latencies { - if sample.endpoint == endpoint { - endpointLatencies = append(endpointLatencies, sample.latency) - } - } - - if len(endpointLatencies) == 0 { + buf := c.endpointLatencies[endpoint] + if len(buf) == 0 { return 0 } - // Sort latencies - sort.Slice(endpointLatencies, func(i, j int) bool { - return endpointLatencies[i] < endpointLatencies[j] - }) - - // Calculate P95 index - p95Index := int(float64(len(endpointLatencies)) * 0.95) - if p95Index >= len(endpointLatencies) { - p95Index = len(endpointLatencies) - 1 + // Buffer is already sorted; direct index lookup + p95Index := int(float64(len(buf)) * 0.95) + if p95Index >= len(buf) { + p95Index = len(buf) - 1 } - return float64(endpointLatencies[p95Index].Milliseconds()) + return float64(buf[p95Index].Milliseconds()) } // Reset clears all collected stats @@ -185,7 +174,7 @@ func (c *HTTPStatsCollector) Reset() { c.totalLatency = make(map[string]time.Duration) c.errors = make(map[string]int64) c.byStatus = make(map[int]int64) - c.latencies = make([]latencySample, 0, maxLatencySamples) + c.endpointLatencies = make(map[string][]time.Duration) c.startTime = time.Now() } diff --git a/internal/monitoring/service.go b/internal/monitoring/service.go index 1a1d3e0..31dd2e0 100644 --- a/internal/monitoring/service.go +++ b/internal/monitoring/service.go @@ -2,6 +2,7 @@ package monitoring import ( "io" + "sync" "time" "github.com/hibiken/asynq" @@ -31,6 +32,8 @@ type Service struct { logWriter *RedisLogWriter db *gorm.DB settingsStopCh chan struct{} + stopOnce sync.Once + statsInterval time.Duration } // Config holds configuration for the monitoring service @@ -71,6 +74,7 @@ func NewService(cfg Config) *Service { logWriter: logWriter, db: cfg.DB, settingsStopCh: make(chan struct{}), + statsInterval: cfg.StatsInterval, } // Check initial setting from database @@ -90,11 +94,11 @@ func (s *Service) SetAsynqInspector(inspector *asynq.Inspector) { func (s *Service) Start() { log.Info(). Str("process", s.process). - Dur("interval", DefaultStatsInterval). + Dur("interval", s.statsInterval). Bool("enabled", s.logWriter.IsEnabled()). Msg("Starting monitoring service") - s.collector.StartPublishing(DefaultStatsInterval) + s.collector.StartPublishing(s.statsInterval) // Start settings sync if database is available if s.db != nil { @@ -102,17 +106,19 @@ func (s *Service) Start() { } } -// Stop stops the monitoring service +// Stop stops the monitoring service. It is safe to call multiple times. func (s *Service) Stop() { - // Stop settings sync - close(s.settingsStopCh) + s.stopOnce.Do(func() { + // Stop settings sync + close(s.settingsStopCh) - s.collector.Stop() + s.collector.Stop() - // Flush and close the log writer's background goroutine - s.logWriter.Close() + // Flush and close the log writer's background goroutine + s.logWriter.Close() - log.Info().Str("process", s.process).Msg("Monitoring service stopped") + log.Info().Str("process", s.process).Msg("Monitoring service stopped") + }) } // syncSettingsFromDB checks the database for the enable_monitoring setting diff --git a/internal/monitoring/writer.go b/internal/monitoring/writer.go index 30e8f50..907edf7 100644 --- a/internal/monitoring/writer.go +++ b/internal/monitoring/writer.go @@ -2,6 +2,7 @@ package monitoring import ( "encoding/json" + "sync" "sync/atomic" "time" @@ -18,11 +19,12 @@ const ( // It uses a single background goroutine with a buffered channel instead of // spawning a new goroutine per log line, preventing unbounded goroutine growth. type RedisLogWriter struct { - buffer *LogBuffer - process string - enabled atomic.Bool - ch chan LogEntry - done chan struct{} + buffer *LogBuffer + process string + enabled atomic.Bool + ch chan LogEntry + done chan struct{} + closeOnce sync.Once } // NewRedisLogWriter creates a new writer that captures logs to Redis. @@ -53,9 +55,12 @@ func (w *RedisLogWriter) drainLoop() { // Close shuts down the background goroutine. It should be called during // graceful shutdown to ensure all buffered entries are flushed. +// It is safe to call multiple times. func (w *RedisLogWriter) Close() { - close(w.ch) - <-w.done // Wait for drain to finish + w.closeOnce.Do(func() { + close(w.ch) + <-w.done // Wait for drain to finish + }) } // SetEnabled enables or disables log capture to Redis diff --git a/internal/push/apns.go b/internal/push/apns.go index 4a714d7..4544ea3 100644 --- a/internal/push/apns.go +++ b/internal/push/apns.go @@ -38,16 +38,15 @@ func NewAPNsClient(cfg *config.PushConfig) (*APNsClient, error) { TeamID: cfg.APNSTeamID, } - // Create client - production or sandbox - // Use APNSProduction if set, otherwise fall back to inverse of APNSSandbox + // Create client - sandbox if APNSSandbox is true, production otherwise. + // APNSSandbox is the single source of truth (defaults to true for safety). var client *apns2.Client - useProduction := cfg.APNSProduction || !cfg.APNSSandbox - if useProduction { - client = apns2.NewTokenClient(authToken).Production() - log.Info().Msg("APNs client configured for PRODUCTION") - } else { + if cfg.APNSSandbox { client = apns2.NewTokenClient(authToken).Development() log.Info().Msg("APNs client configured for DEVELOPMENT/SANDBOX") + } else { + client = apns2.NewTokenClient(authToken).Production() + log.Info().Msg("APNs client configured for PRODUCTION") } return &APNsClient{ diff --git a/internal/push/client.go b/internal/push/client.go index 7820466..a6aabbe 100644 --- a/internal/push/client.go +++ b/internal/push/client.go @@ -38,17 +38,17 @@ func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) { log.Warn().Msg("APNs not configured - iOS push disabled") } - // Initialize FCM client (Android) - if cfg.FCMServerKey != "" { + // Initialize FCM client (Android) - requires project ID + service account credentials + if cfg.FCMProjectID != "" && (cfg.FCMServiceAccountPath != "" || cfg.FCMServiceAccountJSON != "") { fcmClient, err := NewFCMClient(cfg) if err != nil { - log.Warn().Err(err).Msg("Failed to initialize FCM client - Android push disabled") + log.Warn().Err(err).Msg("Failed to initialize FCM v1 client - Android push disabled") } else { client.fcm = fcmClient - log.Info().Msg("FCM client initialized successfully") + log.Info().Msg("FCM v1 client initialized successfully") } } else { - log.Warn().Msg("FCM not configured - Android push disabled") + log.Warn().Msg("FCM not configured (need FCM_PROJECT_ID + FCM_SERVICE_ACCOUNT_PATH or FCM_SERVICE_ACCOUNT_JSON) - Android push disabled") } return client, nil diff --git a/internal/push/fcm.go b/internal/push/fcm.go index 48c959c..c2fcfeb 100644 --- a/internal/push/fcm.go +++ b/internal/push/fcm.go @@ -5,138 +5,304 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" + "os" "time" "github.com/rs/zerolog/log" + "golang.org/x/oauth2/google" "github.com/treytartt/honeydue-api/internal/config" ) -const fcmEndpoint = "https://fcm.googleapis.com/fcm/send" +const ( + // fcmV1EndpointFmt is the FCM HTTP v1 API endpoint template. + fcmV1EndpointFmt = "https://fcm.googleapis.com/v1/projects/%s/messages:send" -// FCMClient handles direct communication with Firebase Cloud Messaging + // fcmScope is the OAuth 2.0 scope required for FCM HTTP v1 API. + fcmScope = "https://www.googleapis.com/auth/firebase.messaging" +) + +// FCMClient handles communication with Firebase Cloud Messaging +// using the HTTP v1 API and OAuth 2.0 service account authentication. type FCMClient struct { - serverKey string + projectID string + endpoint string httpClient *http.Client } -// FCMMessage represents an FCM message payload -type FCMMessage struct { - To string `json:"to,omitempty"` - RegistrationIDs []string `json:"registration_ids,omitempty"` - Notification *FCMNotification `json:"notification,omitempty"` - Data map[string]string `json:"data,omitempty"` - Priority string `json:"priority,omitempty"` - ContentAvailable bool `json:"content_available,omitempty"` +// --- Request types (FCM v1 API) --- + +// fcmV1Request is the top-level request body for the FCM v1 API. +type fcmV1Request struct { + Message *fcmV1Message `json:"message"` } -// FCMNotification represents the notification payload +// fcmV1Message represents a single FCM v1 message. +type fcmV1Message struct { + Token string `json:"token"` + Notification *FCMNotification `json:"notification,omitempty"` + Data map[string]string `json:"data,omitempty"` + Android *fcmAndroidConfig `json:"android,omitempty"` +} + +// FCMNotification represents the notification payload. type FCMNotification struct { Title string `json:"title,omitempty"` Body string `json:"body,omitempty"` - Sound string `json:"sound,omitempty"` - Badge string `json:"badge,omitempty"` - Icon string `json:"icon,omitempty"` } -// FCMResponse represents the FCM API response -type FCMResponse struct { - MulticastID int64 `json:"multicast_id"` - Success int `json:"success"` - Failure int `json:"failure"` - CanonicalIDs int `json:"canonical_ids"` - Results []FCMResult `json:"results"` +// fcmAndroidConfig provides Android-specific message configuration. +type fcmAndroidConfig struct { + Priority string `json:"priority,omitempty"` } -// FCMResult represents a single result in the FCM response -type FCMResult struct { - MessageID string `json:"message_id,omitempty"` - RegistrationID string `json:"registration_id,omitempty"` - Error string `json:"error,omitempty"` +// --- Response types (FCM v1 API) --- + +// fcmV1Response is the successful response from the FCM v1 API. +type fcmV1Response struct { + Name string `json:"name"` // e.g. "projects/myproject/messages/0:1234567890" } -// NewFCMClient creates a new FCM client +// fcmV1ErrorResponse is the error response from the FCM v1 API. +type fcmV1ErrorResponse struct { + Error fcmV1Error `json:"error"` +} + +// fcmV1Error contains the structured error details. +type fcmV1Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + Details json.RawMessage `json:"details,omitempty"` +} + +// --- Error types --- + +// FCMErrorCode represents well-known FCM v1 error codes for programmatic handling. +type FCMErrorCode string + +const ( + FCMErrUnregistered FCMErrorCode = "UNREGISTERED" + FCMErrQuotaExceeded FCMErrorCode = "QUOTA_EXCEEDED" + FCMErrUnavailable FCMErrorCode = "UNAVAILABLE" + FCMErrInternal FCMErrorCode = "INTERNAL" + FCMErrInvalidArgument FCMErrorCode = "INVALID_ARGUMENT" + FCMErrSenderIDMismatch FCMErrorCode = "SENDER_ID_MISMATCH" + FCMErrThirdPartyAuth FCMErrorCode = "THIRD_PARTY_AUTH_ERROR" +) + +// FCMSendError is a structured error returned when an individual FCM send fails. +type FCMSendError struct { + Token string + StatusCode int + ErrorCode FCMErrorCode + Message string +} + +func (e *FCMSendError) Error() string { + return fmt.Sprintf("FCM send failed for token %s: %s (status %d, code %s)", + truncateToken(e.Token), e.Message, e.StatusCode, e.ErrorCode) +} + +// IsUnregistered returns true if the device token is no longer valid and should be removed. +func (e *FCMSendError) IsUnregistered() bool { + return e.ErrorCode == FCMErrUnregistered +} + +// --- OAuth 2.0 transport --- + +// oauth2BearerTransport is an http.RoundTripper that attaches an OAuth 2.0 Bearer +// token to every outgoing request. The token source handles refresh automatically. +type oauth2BearerTransport struct { + base http.RoundTripper + getToken func() (string, error) +} + +func (t *oauth2BearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + accessToken, err := t.getToken() + if err != nil { + return nil, fmt.Errorf("failed to obtain OAuth 2.0 token for FCM: %w", err) + } + r := req.Clone(req.Context()) + r.Header.Set("Authorization", "Bearer "+accessToken) + return t.base.RoundTrip(r) +} + +// --- Client construction --- + +// NewFCMClient creates a new FCM client using the HTTP v1 API with OAuth 2.0 +// service account authentication. It accepts either a path to a service account +// JSON file or the raw JSON content directly via config. func NewFCMClient(cfg *config.PushConfig) (*FCMClient, error) { - if cfg.FCMServerKey == "" { - return nil, fmt.Errorf("FCM server key not configured") + if cfg.FCMProjectID == "" { + return nil, fmt.Errorf("FCM project ID not configured (set FCM_PROJECT_ID)") } - return &FCMClient{ - serverKey: cfg.FCMServerKey, - httpClient: &http.Client{ - Timeout: 30 * time.Second, + credJSON, err := resolveServiceAccountJSON(cfg) + if err != nil { + return nil, err + } + + // Create OAuth 2.0 credentials with the FCM messaging scope. + // The google library handles automatic token refresh. + creds, err := google.CredentialsFromJSON(context.Background(), credJSON, fcmScope) + if err != nil { + return nil, fmt.Errorf("failed to parse FCM service account credentials: %w", err) + } + + // Build an HTTP client that automatically attaches and refreshes OAuth tokens. + transport := &oauth2BearerTransport{ + base: http.DefaultTransport, + getToken: func() (string, error) { + tok, err := creds.TokenSource.Token() + if err != nil { + return "", err + } + return tok.AccessToken, nil }, + } + + httpClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: transport, + } + + endpoint := fmt.Sprintf(fcmV1EndpointFmt, cfg.FCMProjectID) + + log.Info(). + Str("project_id", cfg.FCMProjectID). + Str("endpoint", endpoint). + Msg("FCM v1 client initialized with OAuth 2.0") + + return &FCMClient{ + projectID: cfg.FCMProjectID, + endpoint: endpoint, + httpClient: httpClient, }, nil } -// Send sends a push notification to Android devices +// resolveServiceAccountJSON returns the service account JSON bytes from config. +func resolveServiceAccountJSON(cfg *config.PushConfig) ([]byte, error) { + if cfg.FCMServiceAccountJSON != "" { + return []byte(cfg.FCMServiceAccountJSON), nil + } + if cfg.FCMServiceAccountPath != "" { + data, err := os.ReadFile(cfg.FCMServiceAccountPath) + if err != nil { + return nil, fmt.Errorf("failed to read FCM service account file %s: %w", cfg.FCMServiceAccountPath, err) + } + return data, nil + } + return nil, fmt.Errorf("FCM service account not configured (set FCM_SERVICE_ACCOUNT_PATH or FCM_SERVICE_ACCOUNT_JSON)") +} + +// --- Sending --- + +// Send sends a push notification to Android devices via the FCM HTTP v1 API. +// The v1 API requires one request per device token, so this iterates over all tokens. +// The method signature is kept identical to the previous legacy implementation +// so callers do not need to change. func (c *FCMClient) Send(ctx context.Context, tokens []string, title, message string, data map[string]string) error { if len(tokens) == 0 { return nil } - msg := FCMMessage{ - RegistrationIDs: tokens, - Notification: &FCMNotification{ - Title: title, - Body: message, - Sound: "default", - }, - Data: data, - Priority: "high", - } + var sendErrors []error + successCount := 0 - body, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("failed to marshal FCM message: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", fcmEndpoint, bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("failed to create FCM request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "key="+c.serverKey) - - resp, err := c.httpClient.Do(req) - if err != nil { - return fmt.Errorf("failed to send FCM request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("FCM returned status %d", resp.StatusCode) - } - - var fcmResp FCMResponse - if err := json.NewDecoder(resp.Body).Decode(&fcmResp); err != nil { - return fmt.Errorf("failed to decode FCM response: %w", err) - } - - // Log individual results - for i, result := range fcmResp.Results { - if i >= len(tokens) { - break - } - if result.Error != "" { + for _, token := range tokens { + err := c.sendOne(ctx, token, title, message, data) + if err != nil { log.Error(). - Str("token", truncateToken(tokens[i])). - Str("error", result.Error). - Msg("FCM notification failed") + Err(err). + Str("token", truncateToken(token)). + Msg("FCM v1 notification failed") + sendErrors = append(sendErrors, err) + continue } + + successCount++ + log.Debug(). + Str("token", truncateToken(token)). + Msg("FCM v1 notification sent successfully") } log.Info(). Int("total", len(tokens)). - Int("success", fcmResp.Success). - Int("failure", fcmResp.Failure). - Msg("FCM batch send complete") + Int("success", successCount). + Int("failed", len(sendErrors)). + Msg("FCM v1 batch send complete") - if fcmResp.Success == 0 && fcmResp.Failure > 0 { - return fmt.Errorf("all FCM notifications failed") + if len(sendErrors) > 0 && successCount == 0 { + return fmt.Errorf("all FCM notifications failed: first error: %w", sendErrors[0]) } return nil } + +// sendOne sends a single FCM v1 message to one device token. +func (c *FCMClient) sendOne(ctx context.Context, token, title, message string, data map[string]string) error { + reqBody := fcmV1Request{ + Message: &fcmV1Message{ + Token: token, + Notification: &FCMNotification{ + Title: title, + Body: message, + }, + Data: data, + Android: &fcmAndroidConfig{ + Priority: "HIGH", + }, + }, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal FCM v1 message: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("failed to create FCM v1 request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send FCM v1 request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read FCM v1 response body: %w", err) + } + + if resp.StatusCode == http.StatusOK { + return nil + } + + // Parse the error response for structured error information. + return parseFCMV1Error(token, resp.StatusCode, respBody) +} + +// parseFCMV1Error extracts a structured FCMSendError from the v1 API error response. +func parseFCMV1Error(token string, statusCode int, body []byte) *FCMSendError { + var errResp fcmV1ErrorResponse + if err := json.Unmarshal(body, &errResp); err != nil { + return &FCMSendError{ + Token: token, + StatusCode: statusCode, + Message: fmt.Sprintf("unparseable error response: %s", string(body)), + } + } + + return &FCMSendError{ + Token: token, + StatusCode: statusCode, + ErrorCode: FCMErrorCode(errResp.Error.Status), + Message: errResp.Error.Message, + } +} diff --git a/internal/push/fcm_test.go b/internal/push/fcm_test.go index ff4bf95..de8d895 100644 --- a/internal/push/fcm_test.go +++ b/internal/push/fcm_test.go @@ -5,182 +5,266 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "sync" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// newTestFCMClient creates an FCMClient pointing at the given test server URL. +// newTestFCMClient creates an FCMClient whose endpoint points at the given test +// server URL. The HTTP client uses no OAuth transport so tests can run without +// real Google credentials. func newTestFCMClient(serverURL string) *FCMClient { return &FCMClient{ - serverKey: "test-server-key", - httpClient: http.DefaultClient, + projectID: "test-project", + endpoint: serverURL, + httpClient: &http.Client{}, } } -// serveFCMResponse creates an httptest.Server that returns the given FCMResponse as JSON. -func serveFCMResponse(t *testing.T, resp FCMResponse) *httptest.Server { +// serveFCMV1Success creates a test server that returns a successful v1 response +// for every request. +func serveFCMV1Success(t *testing.T) *httptest.Server { t.Helper() return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify it looks like a v1 request. + var req fcmV1Request + err := json.NewDecoder(r.Body).Decode(&req) + require.NoError(t, err) + require.NotNil(t, req.Message) + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(resp) - require.NoError(t, err) + resp := fcmV1Response{Name: "projects/test-project/messages/0:12345"} + _ = json.NewEncoder(w).Encode(resp) })) } -// sendWithEndpoint is a helper that sends an FCM notification using a custom endpoint -// (the test server) instead of the real FCM endpoint. This avoids modifying the -// production code to be testable and instead temporarily overrides the client's HTTP -// transport to redirect requests to our test server. -func sendWithEndpoint(client *FCMClient, server *httptest.Server, ctx context.Context, tokens []string, title, message string, data map[string]string) error { - // Override the HTTP client to redirect all requests to the test server - client.httpClient = server.Client() - - // We need to intercept the request and redirect it to our test server. - // Use a custom RoundTripper that rewrites the URL. - originalTransport := server.Client().Transport - client.httpClient.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) { - // Rewrite the URL to point to the test server - req.URL.Scheme = "http" - req.URL.Host = server.Listener.Addr().String() - if originalTransport != nil { - return originalTransport.RoundTrip(req) +// serveFCMV1Error creates a test server that returns the given status code and +// a structured v1 error response for every request. +func serveFCMV1Error(t *testing.T, statusCode int, errStatus string, errMessage string) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + resp := fcmV1ErrorResponse{ + Error: fcmV1Error{ + Code: statusCode, + Message: errMessage, + Status: errStatus, + }, } - return http.DefaultTransport.RoundTrip(req) - }) - - return client.Send(ctx, tokens, title, message, data) + _ = json.NewEncoder(w).Encode(resp) + })) } -// roundTripFunc is a function that implements http.RoundTripper. -type roundTripFunc func(*http.Request) (*http.Response, error) - -func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { - return f(req) -} - -func TestFCMSend_MoreResultsThanTokens_NoPanic(t *testing.T) { - // FCM returns 5 results but we only sent 2 tokens. - // Before the bounds check fix, this would panic with index out of range. - fcmResp := FCMResponse{ - MulticastID: 12345, - Success: 2, - Failure: 3, - Results: []FCMResult{ - {MessageID: "msg1"}, - {MessageID: "msg2"}, - {Error: "InvalidRegistration"}, - {Error: "NotRegistered"}, - {Error: "InvalidRegistration"}, - }, - } - - server := serveFCMResponse(t, fcmResp) +func TestFCMV1Send_Success_SingleToken(t *testing.T) { + server := serveFCMV1Success(t) defer server.Close() client := newTestFCMClient(server.URL) - tokens := []string{"token-aaa-111", "token-bbb-222"} - - // This must not panic - err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil) + err := client.Send(context.Background(), []string{"token-aaa-111"}, "Title", "Body", nil) assert.NoError(t, err) } -func TestFCMSend_FewerResultsThanTokens_NoPanic(t *testing.T) { - // FCM returns fewer results than tokens we sent. - // This is also a malformed response but should not panic. - fcmResp := FCMResponse{ - MulticastID: 12345, - Success: 1, - Failure: 0, - Results: []FCMResult{ - {MessageID: "msg1"}, - }, - } +func TestFCMV1Send_Success_MultipleTokens(t *testing.T) { + var mu sync.Mutex + receivedTokens := make([]string, 0) - server := serveFCMResponse(t, fcmResp) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req fcmV1Request + _ = json.NewDecoder(r.Body).Decode(&req) + + mu.Lock() + receivedTokens = append(receivedTokens, req.Message.Token) + mu.Unlock() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + resp := fcmV1Response{Name: "projects/test-project/messages/0:12345"} + _ = json.NewEncoder(w).Encode(resp) + })) defer server.Close() client := newTestFCMClient(server.URL) tokens := []string{"token-aaa-111", "token-bbb-222", "token-ccc-333"} - err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil) + err := client.Send(context.Background(), tokens, "Title", "Body", map[string]string{"key": "value"}) assert.NoError(t, err) + assert.ElementsMatch(t, tokens, receivedTokens) } -func TestFCMSend_EmptyResponse_NoPanic(t *testing.T) { - // FCM returns an empty Results slice. - fcmResp := FCMResponse{ - MulticastID: 12345, - Success: 0, - Failure: 0, - Results: []FCMResult{}, - } - - server := serveFCMResponse(t, fcmResp) - defer server.Close() - - client := newTestFCMClient(server.URL) - tokens := []string{"token-aaa-111"} - - err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil) - // No panic expected. The function returns nil because fcmResp.Success == 0 - // and fcmResp.Failure == 0 (the "all failed" check requires Failure > 0). - assert.NoError(t, err) -} - -func TestFCMSend_NilResultsSlice_NoPanic(t *testing.T) { - // FCM returns a response with nil Results (e.g., malformed JSON). - fcmResp := FCMResponse{ - MulticastID: 12345, - Success: 0, - Failure: 1, - } - - server := serveFCMResponse(t, fcmResp) - defer server.Close() - - client := newTestFCMClient(server.URL) - tokens := []string{"token-aaa-111"} - - err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil) - // Should return error because Success == 0 and Failure > 0 - assert.Error(t, err) - assert.Contains(t, err.Error(), "all FCM notifications failed") -} - -func TestFCMSend_EmptyTokens_ReturnsNil(t *testing.T) { - // Verify the early return for empty tokens. +func TestFCMV1Send_EmptyTokens_ReturnsNil(t *testing.T) { client := &FCMClient{ - serverKey: "test-key", - httpClient: http.DefaultClient, + projectID: "test-project", + endpoint: "http://unused", + httpClient: &http.Client{}, } - err := client.Send(context.Background(), []string{}, "Test", "Body", nil) + err := client.Send(context.Background(), []string{}, "Title", "Body", nil) assert.NoError(t, err) } -func TestFCMSend_ResultsWithErrorsMatchTokens(t *testing.T) { - // Normal case: results count matches tokens count, all with errors. - fcmResp := FCMResponse{ - MulticastID: 12345, - Success: 0, - Failure: 2, - Results: []FCMResult{ - {Error: "InvalidRegistration"}, - {Error: "NotRegistered"}, - }, - } - - server := serveFCMResponse(t, fcmResp) +func TestFCMV1Send_AllFail_ReturnsError(t *testing.T) { + server := serveFCMV1Error(t, http.StatusNotFound, "UNREGISTERED", "The registration token is not registered") defer server.Close() client := newTestFCMClient(server.URL) - tokens := []string{"token-aaa-111", "token-bbb-222"} + tokens := []string{"bad-token-1", "bad-token-2"} - err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil) + err := client.Send(context.Background(), tokens, "Title", "Body", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "all FCM notifications failed") } + +func TestFCMV1Send_PartialFailure_ReturnsNil(t *testing.T) { + var mu sync.Mutex + callCount := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + callCount++ + n := callCount + mu.Unlock() + + w.Header().Set("Content-Type", "application/json") + if n == 1 { + // First token succeeds + w.WriteHeader(http.StatusOK) + resp := fcmV1Response{Name: "projects/test-project/messages/0:12345"} + _ = json.NewEncoder(w).Encode(resp) + } else { + // Second token fails + w.WriteHeader(http.StatusNotFound) + resp := fcmV1ErrorResponse{ + Error: fcmV1Error{Code: 404, Message: "not registered", Status: "UNREGISTERED"}, + } + _ = json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + client := newTestFCMClient(server.URL) + tokens := []string{"good-token", "bad-token"} + + // Partial failure: at least one succeeded, so no error returned. + err := client.Send(context.Background(), tokens, "Title", "Body", nil) + assert.NoError(t, err) +} + +func TestFCMV1Send_UnregisteredError(t *testing.T) { + server := serveFCMV1Error(t, http.StatusNotFound, "UNREGISTERED", "The registration token is not registered") + defer server.Close() + + client := newTestFCMClient(server.URL) + + err := client.sendOne(context.Background(), "stale-token", "Title", "Body", nil) + require.Error(t, err) + + var sendErr *FCMSendError + require.ErrorAs(t, err, &sendErr) + assert.True(t, sendErr.IsUnregistered()) + assert.Equal(t, FCMErrUnregistered, sendErr.ErrorCode) + assert.Equal(t, http.StatusNotFound, sendErr.StatusCode) +} + +func TestFCMV1Send_QuotaExceededError(t *testing.T) { + server := serveFCMV1Error(t, http.StatusTooManyRequests, "QUOTA_EXCEEDED", "Sending quota exceeded") + defer server.Close() + + client := newTestFCMClient(server.URL) + + err := client.sendOne(context.Background(), "some-token", "Title", "Body", nil) + require.Error(t, err) + + var sendErr *FCMSendError + require.ErrorAs(t, err, &sendErr) + assert.Equal(t, FCMErrQuotaExceeded, sendErr.ErrorCode) + assert.Equal(t, http.StatusTooManyRequests, sendErr.StatusCode) +} + +func TestFCMV1Send_UnparseableErrorResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("not json at all")) + })) + defer server.Close() + + client := newTestFCMClient(server.URL) + + err := client.sendOne(context.Background(), "some-token", "Title", "Body", nil) + require.Error(t, err) + + var sendErr *FCMSendError + require.ErrorAs(t, err, &sendErr) + assert.Equal(t, http.StatusInternalServerError, sendErr.StatusCode) + assert.Contains(t, sendErr.Message, "unparseable error response") +} + +func TestFCMV1Send_RequestPayloadFormat(t *testing.T) { + var receivedReq fcmV1Request + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify Content-Type header + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + assert.Equal(t, http.MethodPost, r.Method) + + err := json.NewDecoder(r.Body).Decode(&receivedReq) + require.NoError(t, err) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(fcmV1Response{Name: "projects/test-project/messages/0:12345"}) + })) + defer server.Close() + + client := newTestFCMClient(server.URL) + data := map[string]string{"task_id": "42", "action": "complete"} + err := client.Send(context.Background(), []string{"device-token-xyz"}, "Task Due", "Your task is due today", data) + require.NoError(t, err) + + // Verify the v1 message structure + require.NotNil(t, receivedReq.Message) + assert.Equal(t, "device-token-xyz", receivedReq.Message.Token) + assert.Equal(t, "Task Due", receivedReq.Message.Notification.Title) + assert.Equal(t, "Your task is due today", receivedReq.Message.Notification.Body) + assert.Equal(t, "42", receivedReq.Message.Data["task_id"]) + assert.Equal(t, "complete", receivedReq.Message.Data["action"]) + assert.Equal(t, "HIGH", receivedReq.Message.Android.Priority) +} + +func TestFCMSendError_Error(t *testing.T) { + sendErr := &FCMSendError{ + Token: "abcdef1234567890", + StatusCode: 404, + ErrorCode: FCMErrUnregistered, + Message: "token not registered", + } + + errStr := sendErr.Error() + assert.Contains(t, errStr, "abcdef12...") + assert.Contains(t, errStr, "token not registered") + assert.Contains(t, errStr, "404") + assert.Contains(t, errStr, "UNREGISTERED") +} + +func TestFCMSendError_IsUnregistered(t *testing.T) { + tests := []struct { + name string + code FCMErrorCode + expected bool + }{ + {"unregistered", FCMErrUnregistered, true}, + {"quota_exceeded", FCMErrQuotaExceeded, false}, + {"internal", FCMErrInternal, false}, + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := &FCMSendError{ErrorCode: tt.code} + assert.Equal(t, tt.expected, err.IsUnregistered()) + }) + } +} diff --git a/internal/repositories/contractor_repo.go b/internal/repositories/contractor_repo.go index b941722..2d6240f 100644 --- a/internal/repositories/contractor_repo.go +++ b/internal/repositories/contractor_repo.go @@ -124,29 +124,33 @@ func (r *ContractorRepository) GetTasksForContractor(contractorID uint) ([]model return tasks, err } -// SetSpecialties sets the specialties for a contractor +// SetSpecialties sets the specialties for a contractor. +// Wrapped in a transaction so that clearing existing specialties and +// appending new ones are atomic -- a failure in either step rolls back both. func (r *ContractorRepository) SetSpecialties(contractorID uint, specialtyIDs []uint) error { - var contractor models.Contractor - if err := r.db.First(&contractor, contractorID).Error; err != nil { - return err - } + return r.db.Transaction(func(tx *gorm.DB) error { + var contractor models.Contractor + if err := tx.First(&contractor, contractorID).Error; err != nil { + return err + } - // Clear existing specialties - if err := r.db.Model(&contractor).Association("Specialties").Clear(); err != nil { - return err - } + // Clear existing specialties + if err := tx.Model(&contractor).Association("Specialties").Clear(); err != nil { + return err + } - if len(specialtyIDs) == 0 { - return nil - } + if len(specialtyIDs) == 0 { + return nil + } - // Add new specialties - var specialties []models.ContractorSpecialty - if err := r.db.Where("id IN ?", specialtyIDs).Find(&specialties).Error; err != nil { - return err - } + // Add new specialties + var specialties []models.ContractorSpecialty + if err := tx.Where("id IN ?", specialtyIDs).Find(&specialties).Error; err != nil { + return err + } - return r.db.Model(&contractor).Association("Specialties").Append(specialties) + return tx.Model(&contractor).Association("Specialties").Append(specialties) + }) } // CountByResidence counts contractors in a residence diff --git a/internal/repositories/document_repo.go b/internal/repositories/document_repo.go index 935f7f0..41b9e7f 100644 --- a/internal/repositories/document_repo.go +++ b/internal/repositories/document_repo.go @@ -98,7 +98,7 @@ func (r *DocumentRepository) FindByUserFiltered(residenceIDs []uint, filter *Doc } var documents []models.Document - err := query.Order("created_at DESC").Find(&documents).Error + err := query.Order("created_at DESC").Limit(500).Find(&documents).Error return documents, err } diff --git a/internal/repositories/reminder_repo.go b/internal/repositories/reminder_repo.go index 908fe90..0fef467 100644 --- a/internal/repositories/reminder_repo.go +++ b/internal/repositories/reminder_repo.go @@ -88,10 +88,33 @@ func (r *ReminderRepository) HasSentReminderBatch(keys []ReminderKey) (map[int]b userIDs = append(userIDs, id) } - // Query all matching reminder logs in one query + // Collect unique stages and due dates for tighter SQL filtering + stageSet := make(map[models.ReminderStage]bool) + dueDateSet := make(map[string]bool) + var minDueDate, maxDueDate time.Time + for _, k := range keys { + stageSet[k.Stage] = true + dueDateOnly := time.Date(k.DueDate.Year(), k.DueDate.Month(), k.DueDate.Day(), 0, 0, 0, 0, time.UTC) + dueDateSet[dueDateOnly.Format("2006-01-02")] = true + if minDueDate.IsZero() || dueDateOnly.Before(minDueDate) { + minDueDate = dueDateOnly + } + if maxDueDate.IsZero() || dueDateOnly.After(maxDueDate) { + maxDueDate = dueDateOnly + } + } + stages := make([]models.ReminderStage, 0, len(stageSet)) + for s := range stageSet { + stages = append(stages, s) + } + + // Query matching reminder logs with tighter filters to reduce result set. + // Filter on reminder_stage and due_date range in addition to task_id/user_id. var logs []models.TaskReminderLog - err := r.db.Where("task_id IN ? AND user_id IN ?", taskIDs, userIDs). - Find(&logs).Error + err := r.db.Where( + "task_id IN ? AND user_id IN ? AND reminder_stage IN ? AND due_date >= ? AND due_date <= ?", + taskIDs, userIDs, stages, minDueDate, maxDueDate, + ).Find(&logs).Error if err != nil { return nil, err } diff --git a/internal/repositories/residence_repo.go b/internal/repositories/residence_repo.go index ac2ef28..108c952 100644 --- a/internal/repositories/residence_repo.go +++ b/internal/repositories/residence_repo.go @@ -196,6 +196,20 @@ func (r *ResidenceRepository) CountByOwner(userID uint) (int64, error) { return count, err } +// FindResidenceIDsByOwner returns just the IDs of residences a user owns. +// This is a lightweight alternative to FindOwnedByUser() when only IDs are needed +// for batch queries against related tables (tasks, contractors, documents). +func (r *ResidenceRepository) FindResidenceIDsByOwner(userID uint) ([]uint, error) { + var ids []uint + err := r.db.Model(&models.Residence{}). + Where("owner_id = ? AND is_active = ?", userID, true). + Pluck("id", &ids).Error + if err != nil { + return nil, err + } + return ids, nil +} + // === Share Code Operations === // CreateShareCode creates a new share code for a residence diff --git a/internal/repositories/subscription_repo.go b/internal/repositories/subscription_repo.go index 3c2dd05..f3b4a4f 100644 --- a/internal/repositories/subscription_repo.go +++ b/internal/repositories/subscription_repo.go @@ -129,12 +129,21 @@ func (r *SubscriptionRepository) UpdatePurchaseToken(userID uint, token string) Update("google_purchase_token", token).Error } -// FindByAppleReceiptContains finds a subscription by Apple transaction ID -// Used by webhooks to find the user associated with a transaction +// FindByAppleReceiptContains finds a subscription by Apple transaction ID. +// Used by webhooks to find the user associated with a transaction. +// +// PERFORMANCE NOTE: This uses a LIKE '%...%' scan on apple_receipt_data which +// cannot use a B-tree index and results in a full table scan. For better +// performance at scale, add a dedicated indexed column: +// +// AppleTransactionID *string `gorm:"column:apple_transaction_id;size:255;index"` +// +// Then look up by exact match: WHERE apple_transaction_id = ? func (r *SubscriptionRepository) FindByAppleReceiptContains(transactionID string) (*models.UserSubscription, error) { var sub models.UserSubscription - // Search for transaction ID in the stored receipt data - err := r.db.Where("apple_receipt_data LIKE ?", "%"+transactionID+"%").First(&sub).Error + // Escape LIKE wildcards in the transaction ID to prevent wildcard injection + escaped := escapeLikeWildcards(transactionID) + err := r.db.Where("apple_receipt_data LIKE ?", "%"+escaped+"%").First(&sub).Error if err != nil { return nil, err } diff --git a/internal/repositories/task_repo.go b/internal/repositories/task_repo.go index f78795e..91be2b1 100644 --- a/internal/repositories/task_repo.go +++ b/internal/repositories/task_repo.go @@ -38,29 +38,40 @@ func (r *TaskRepository) CreateCompletionTx(tx *gorm.DB, completion *models.Task return tx.Create(completion).Error } +// taskUpdateFields returns the canonical field map used by both Update and UpdateTx. +// Centralised here so the two methods never drift out of sync. +func taskUpdateFields(t *models.Task) map[string]interface{} { + return map[string]interface{}{ + "title": t.Title, + "description": t.Description, + "category_id": t.CategoryID, + "priority_id": t.PriorityID, + "frequency_id": t.FrequencyID, + "custom_interval_days": t.CustomIntervalDays, + "in_progress": t.InProgress, + "assigned_to_id": t.AssignedToID, + "due_date": t.DueDate, + "next_due_date": t.NextDueDate, + "estimated_cost": t.EstimatedCost, + "actual_cost": t.ActualCost, + "contractor_id": t.ContractorID, + "is_cancelled": t.IsCancelled, + "is_archived": t.IsArchived, + "version": gorm.Expr("version + 1"), + } +} + +// taskUpdateOmitAssociations lists the association fields to omit during task updates. +var taskUpdateOmitAssociations = []string{ + "Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions", +} + // UpdateTx updates a task with optimistic locking within an existing transaction. func (r *TaskRepository) UpdateTx(tx *gorm.DB, task *models.Task) error { result := tx.Model(task). Where("id = ? AND version = ?", task.ID, task.Version). - Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions"). - Updates(map[string]interface{}{ - "title": task.Title, - "description": task.Description, - "category_id": task.CategoryID, - "priority_id": task.PriorityID, - "frequency_id": task.FrequencyID, - "custom_interval_days": task.CustomIntervalDays, - "in_progress": task.InProgress, - "assigned_to_id": task.AssignedToID, - "due_date": task.DueDate, - "next_due_date": task.NextDueDate, - "estimated_cost": task.EstimatedCost, - "actual_cost": task.ActualCost, - "contractor_id": task.ContractorID, - "is_cancelled": task.IsCancelled, - "is_archived": task.IsArchived, - "version": gorm.Expr("version + 1"), - }) + Omit(taskUpdateOmitAssociations...). + Updates(taskUpdateFields(task)) if result.Error != nil { return result.Error } @@ -350,25 +361,8 @@ func (r *TaskRepository) Create(task *models.Task) error { func (r *TaskRepository) Update(task *models.Task) error { result := r.db.Model(task). Where("id = ? AND version = ?", task.ID, task.Version). - Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions"). - Updates(map[string]interface{}{ - "title": task.Title, - "description": task.Description, - "category_id": task.CategoryID, - "priority_id": task.PriorityID, - "frequency_id": task.FrequencyID, - "custom_interval_days": task.CustomIntervalDays, - "in_progress": task.InProgress, - "assigned_to_id": task.AssignedToID, - "due_date": task.DueDate, - "next_due_date": task.NextDueDate, - "estimated_cost": task.EstimatedCost, - "actual_cost": task.ActualCost, - "contractor_id": task.ContractorID, - "is_cancelled": task.IsCancelled, - "is_archived": task.IsArchived, - "version": gorm.Expr("version + 1"), - }) + Omit(taskUpdateOmitAssociations...). + Updates(taskUpdateFields(task)) if result.Error != nil { return result.Error } @@ -728,13 +722,18 @@ func (r *TaskRepository) UpdateCompletion(completion *models.TaskCompletion) err return r.db.Omit("Task", "CompletedBy", "Images").Save(completion).Error } -// DeleteCompletion deletes a task completion +// DeleteCompletion deletes a task completion and its associated images atomically. +// Wrapped in a transaction so that if the completion delete fails, image +// deletions are rolled back as well. func (r *TaskRepository) DeleteCompletion(id uint) error { - // Delete images first - if err := r.db.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{}).Error; err != nil { - log.Error().Err(err).Uint("completion_id", id).Msg("Failed to delete completion images") - } - return r.db.Delete(&models.TaskCompletion{}, id).Error + return r.db.Transaction(func(tx *gorm.DB) error { + // Delete images first + if err := tx.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{}).Error; err != nil { + log.Error().Err(err).Uint("completion_id", id).Msg("Failed to delete completion images") + return err + } + return tx.Delete(&models.TaskCompletion{}, id).Error + }) } // CreateCompletionImage creates a new completion image @@ -912,3 +911,128 @@ func (r *TaskRepository) GetCompletionSummary(residenceID uint, now time.Time, m Months: months, }, nil } + +// GetBatchCompletionSummaries returns completion summaries for multiple residences +// in two queries total (one for all-time counts, one for monthly breakdowns), +// instead of 2*N queries when calling GetCompletionSummary per residence. +func (r *TaskRepository) GetBatchCompletionSummaries(residenceIDs []uint, now time.Time, maxPerMonth int) (map[uint]*responses.CompletionSummary, error) { + result := make(map[uint]*responses.CompletionSummary, len(residenceIDs)) + if len(residenceIDs) == 0 { + return result, nil + } + + // 1. Total all-time completions per residence (single query) + type allTimeRow struct { + ResidenceID uint + Count int64 + } + var allTimeRows []allTimeRow + err := r.db.Model(&models.TaskCompletion{}). + Select("task_task.residence_id, COUNT(*) as count"). + Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id"). + Where("task_task.residence_id IN ?", residenceIDs). + Group("task_task.residence_id"). + Scan(&allTimeRows).Error + if err != nil { + return nil, err + } + + allTimeMap := make(map[uint]int64, len(allTimeRows)) + for _, row := range allTimeRows { + allTimeMap[row.ResidenceID] = row.Count + } + + // 2. Monthly breakdown for last 12 months across all residences (single query) + startDate := time.Date(now.Year()-1, now.Month(), 1, 0, 0, 0, 0, now.Location()) + + dateExpr := "TO_CHAR(task_taskcompletion.completed_at, 'YYYY-MM')" + if r.db.Dialector.Name() == "sqlite" { + dateExpr = "strftime('%Y-%m', task_taskcompletion.completed_at)" + } + + var rows []completionAggRow + err = r.db.Model(&models.TaskCompletion{}). + Select(fmt.Sprintf("task_task.residence_id, task_taskcompletion.completed_from_column, %s as completed_month, COUNT(*) as count", dateExpr)). + Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id"). + Where("task_task.residence_id IN ? AND task_taskcompletion.completed_at >= ?", residenceIDs, startDate). + Group(fmt.Sprintf("task_task.residence_id, task_taskcompletion.completed_from_column, %s", dateExpr)). + Order("completed_month ASC"). + Scan(&rows).Error + if err != nil { + return nil, err + } + + // 3. Build per-residence summaries + type monthData struct { + columns map[string]int + total int + } + + // Initialize all residences with empty month maps + residenceMonths := make(map[uint]map[string]*monthData, len(residenceIDs)) + for _, rid := range residenceIDs { + mm := make(map[string]*monthData, 12) + for i := 0; i < 12; i++ { + m := startDate.AddDate(0, i, 0) + key := m.Format("2006-01") + mm[key] = &monthData{columns: make(map[string]int)} + } + residenceMonths[rid] = mm + } + + // Populate from query results + residenceLast12 := make(map[uint]int, len(residenceIDs)) + for _, row := range rows { + mm, ok := residenceMonths[row.ResidenceID] + if !ok { + continue + } + md, ok := mm[row.CompletedMonth] + if !ok { + continue + } + md.columns[row.CompletedFromColumn] = int(row.Count) + md.total += int(row.Count) + residenceLast12[row.ResidenceID] += int(row.Count) + } + + // Convert to response DTOs per residence + for _, rid := range residenceIDs { + mm := residenceMonths[rid] + months := make([]responses.MonthlyCompletionSummary, 0, 12) + for i := 0; i < 12; i++ { + m := startDate.AddDate(0, i, 0) + key := m.Format("2006-01") + md := mm[key] + + completions := make([]responses.ColumnCompletionCount, 0) + for col, count := range md.columns { + completions = append(completions, responses.ColumnCompletionCount{ + Column: col, + Color: KanbanColumnColor(col), + Count: count, + }) + } + + overflow := 0 + if md.total > maxPerMonth { + overflow = md.total - maxPerMonth + } + + months = append(months, responses.MonthlyCompletionSummary{ + Month: key, + Completions: completions, + Total: md.total, + Overflow: overflow, + }) + } + + result[rid] = &responses.CompletionSummary{ + TotalAllTime: int(allTimeMap[rid]), + TotalLast12Months: residenceLast12[rid], + Months: months, + } + } + + return result, nil +} diff --git a/internal/repositories/user_repo.go b/internal/repositories/user_repo.go index aed993d..251cc17 100644 --- a/internal/repositories/user_repo.go +++ b/internal/repositories/user_repo.go @@ -34,6 +34,16 @@ func NewUserRepository(db *gorm.DB) *UserRepository { return &UserRepository{db: db} } +// Transaction runs fn inside a database transaction. The callback receives a +// new UserRepository backed by the transaction so all operations within fn +// share the same transactional connection. +func (r *UserRepository) Transaction(fn func(txRepo *UserRepository) error) error { + return r.db.Transaction(func(tx *gorm.DB) error { + txRepo := &UserRepository{db: tx} + return fn(txRepo) + }) +} + // FindByID finds a user by ID func (r *UserRepository) FindByID(id uint) (*models.User, error) { var user models.User @@ -130,18 +140,28 @@ func (r *UserRepository) ExistsByEmail(email string) (bool, error) { // --- Auth Token Methods --- -// GetOrCreateToken gets or creates an auth token for a user +// GetOrCreateToken gets or creates an auth token for a user. +// Wrapped in a transaction to prevent race conditions where two +// concurrent requests could create duplicate tokens for the same user. func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error) { var token models.AuthToken - result := r.db.Where("user_id = ?", userID).First(&token) - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - token = models.AuthToken{UserID: userID} - if err := r.db.Create(&token).Error; err != nil { - return nil, err + err := r.db.Transaction(func(tx *gorm.DB) error { + result := tx.Where("user_id = ?", userID).First(&token) + + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + token = models.AuthToken{UserID: userID} + if err := tx.Create(&token).Error; err != nil { + return err + } + } else if result.Error != nil { + return result.Error } - } else if result.Error != nil { - return nil, result.Error + + return nil + }) + if err != nil { + return nil, err } return &token, nil @@ -341,7 +361,7 @@ func (r *UserRepository) SearchUsers(query string, limit, offset int) ([]models. var users []models.User var total int64 - searchQuery := "%" + strings.ToLower(query) + "%" + searchQuery := "%" + escapeLikeWildcards(strings.ToLower(query)) + "%" baseQuery := r.db.Model(&models.User{}). Where("LOWER(username) LIKE ? OR LOWER(email) LIKE ? OR LOWER(first_name) LIKE ? OR LOWER(last_name) LIKE ?", @@ -384,7 +404,7 @@ func (r *UserRepository) FindUsersInSharedResidences(userID uint) ([]models.User // 2. Members of residences owned by current user // 3. Members of residences where current user is also a member err := r.db.Raw(` - SELECT DISTINCT u.* FROM user_customuser u + SELECT DISTINCT u.* FROM auth_user u WHERE u.id != ? AND u.is_active = true AND ( -- Users who own residences where current user is a shared user u.id IN ( @@ -417,7 +437,7 @@ func (r *UserRepository) FindUserIfSharedResidence(targetUserID, requestingUserI var user models.User err := r.db.Raw(` - SELECT u.* FROM user_customuser u + SELECT u.* FROM auth_user u WHERE u.id = ? AND u.is_active = true AND ( u.id = ? OR -- Target owns a residence where requester is a member @@ -460,7 +480,7 @@ func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.U err := r.db.Raw(` SELECT p.* FROM user_userprofile p - INNER JOIN user_customuser u ON p.user_id = u.id + INNER JOIN auth_user u ON p.user_id = u.id WHERE u.is_active = true AND ( u.id = ? OR -- Users who own residences where current user is a shared user diff --git a/internal/router/router.go b/internal/router/router.go index ce15833..30bc9d2 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -59,6 +59,15 @@ func SetupRouter(deps *Dependencies) *echo.Echo { e.Use(custommiddleware.RequestIDMiddleware()) e.Use(utils.EchoRecovery()) e.Use(custommiddleware.StructuredLogger()) + + // Security headers (X-Frame-Options, X-Content-Type-Options, X-XSS-Protection, etc.) + e.Use(middleware.SecureWithConfig(middleware.SecureConfig{ + XSSProtection: "1; mode=block", + ContentTypeNosniff: "nosniff", + XFrameOptions: "SAMEORIGIN", + HSTSMaxAge: 31536000, // 1 year in seconds + ReferrerPolicy: "strict-origin-when-cross-origin", + })) e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{ Limit: "1M", // 1MB default for JSON payloads Skipper: func(c echo.Context) bool { @@ -187,7 +196,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo { var uploadHandler *handlers.UploadHandler var mediaHandler *handlers.MediaHandler if deps.StorageService != nil { - uploadHandler = handlers.NewUploadHandler(deps.StorageService) + uploadHandler = handlers.NewUploadHandler(deps.StorageService, services.NewFileOwnershipService(deps.DB)) mediaHandler = handlers.NewMediaHandler(documentRepo, taskRepo, residenceRepo, deps.StorageService) } @@ -247,13 +256,22 @@ func SetupRouter(deps *Dependencies) *echo.Echo { } // corsMiddleware configures CORS with restricted origins in production. -// In debug mode, all origins are allowed for development convenience. +// In debug mode, explicit localhost origins are allowed for development. // In production, origins are read from the CORS_ALLOWED_ORIGINS environment variable // (comma-separated), falling back to a restrictive default set. func corsMiddleware(cfg *config.Config) echo.MiddlewareFunc { var origins []string if cfg.Server.Debug { - origins = []string{"*"} + origins = []string{ + "http://localhost:3000", + "http://localhost:3001", + "http://localhost:8080", + "http://localhost:8000", + "http://127.0.0.1:3000", + "http://127.0.0.1:3001", + "http://127.0.0.1:8080", + "http://127.0.0.1:8000", + } } else { origins = cfg.Server.CorsAllowedOrigins if len(origins) == 0 { @@ -286,17 +304,24 @@ func healthCheck(c echo.Context) error { }) } -// setupPublicAuthRoutes configures public authentication routes +// setupPublicAuthRoutes configures public authentication routes with +// per-endpoint rate limiters to mitigate brute-force and credential-stuffing. func setupPublicAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler) { auth := api.Group("/auth") + + // Rate limiters — created once, shared across requests. + loginRL := custommiddleware.LoginRateLimiter() // 10 req/min + registerRL := custommiddleware.RegistrationRateLimiter() // 5 req/min + passwordRL := custommiddleware.PasswordResetRateLimiter() // 3 req/min + { - auth.POST("/login/", authHandler.Login) - auth.POST("/register/", authHandler.Register) - auth.POST("/forgot-password/", authHandler.ForgotPassword) - auth.POST("/verify-reset-code/", authHandler.VerifyResetCode) - auth.POST("/reset-password/", authHandler.ResetPassword) - auth.POST("/apple-sign-in/", authHandler.AppleSignIn) - auth.POST("/google-sign-in/", authHandler.GoogleSignIn) + auth.POST("/login/", authHandler.Login, loginRL) + auth.POST("/register/", authHandler.Register, registerRL) + auth.POST("/forgot-password/", authHandler.ForgotPassword, passwordRL) + auth.POST("/verify-reset-code/", authHandler.VerifyResetCode, passwordRL) + auth.POST("/reset-password/", authHandler.ResetPassword, passwordRL) + auth.POST("/apple-sign-in/", authHandler.AppleSignIn, loginRL) + auth.POST("/google-sign-in/", authHandler.GoogleSignIn, loginRL) } } diff --git a/internal/services/auth_service.go b/internal/services/auth_service.go index a9952af..bda3d8f 100644 --- a/internal/services/auth_service.go +++ b/internal/services/auth_service.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/rs/zerolog/log" "golang.org/x/crypto/bcrypt" "github.com/treytartt/honeydue-api/internal/apperrors" @@ -90,7 +91,7 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons // Update last login if err := s.userRepo.UpdateLastLogin(user.ID); err != nil { // Log error but don't fail the login - fmt.Printf("Failed to update last login: %v\n", err) + log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login") } return &responses.LoginResponse{ @@ -99,7 +100,9 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons }, nil } -// Register creates a new user account +// Register creates a new user account. +// F-10: User creation, profile creation, notification preferences, and confirmation code +// are wrapped in a transaction for atomicity. func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) { // Check if username exists exists, err := s.userRepo.ExistsByUsername(req.Username) @@ -133,43 +136,49 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist return nil, "", apperrors.Internal(err) } - // Save user - if err := s.userRepo.Create(user); err != nil { - return nil, "", apperrors.Internal(err) - } - - // Create user profile - if _, err := s.userRepo.GetOrCreateProfile(user.ID); err != nil { - // Log error but don't fail registration - fmt.Printf("Failed to create user profile: %v\n", err) - } - - // Create notification preferences with all options enabled - if s.notificationRepo != nil { - if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil { - // Log error but don't fail registration - fmt.Printf("Failed to create notification preferences: %v\n", err) - } - } - - // Create auth token - token, err := s.userRepo.GetOrCreateToken(user.ID) - if err != nil { - return nil, "", apperrors.Internal(err) - } - - // Generate confirmation code - use fixed code in debug mode for easier local testing + // Generate confirmation code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing var code string - if s.cfg.Server.Debug { + if s.cfg.Server.DebugFixedCodes { code = "123456" } else { code = generateSixDigitCode() } expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry) - if _, err := s.userRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil { - // Log error but don't fail registration - fmt.Printf("Failed to create confirmation code: %v\n", err) + // Wrap user creation + profile + notification preferences + confirmation code in a transaction + txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error { + // Save user + if err := txRepo.Create(user); err != nil { + return err + } + + // Create user profile + if _, err := txRepo.GetOrCreateProfile(user.ID); err != nil { + log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create user profile during registration") + } + + // Create notification preferences with all options enabled + if s.notificationRepo != nil { + if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil { + log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration") + } + } + + // Create confirmation code + if _, err := txRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil { + log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create confirmation code during registration") + } + + return nil + }) + if txErr != nil { + return nil, "", apperrors.Internal(txErr) + } + + // Create auth token (outside transaction since token generation is idempotent) + token, err := s.userRepo.GetOrCreateToken(user.ID) + if err != nil { + return nil, "", apperrors.Internal(err) } return &responses.RegisterResponse{ @@ -248,8 +257,8 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error { return apperrors.BadRequest("error.email_already_verified") } - // Check for test code in debug mode - if s.cfg.Server.Debug && code == "123456" { + // Check for test code when DEBUG_FIXED_CODES is enabled + if s.cfg.Server.DebugFixedCodes && code == "123456" { if err := s.userRepo.SetProfileVerified(userID, true); err != nil { return apperrors.Internal(err) } @@ -294,9 +303,9 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) { return "", apperrors.BadRequest("error.email_already_verified") } - // Generate new code - use fixed code in debug mode for easier local testing + // Generate new code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing var code string - if s.cfg.Server.Debug { + if s.cfg.Server.DebugFixedCodes { code = "123456" } else { code = generateSixDigitCode() @@ -331,9 +340,9 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error) return "", nil, apperrors.TooManyRequests("error.rate_limit_exceeded") } - // Generate code and reset token - use fixed code in debug mode for easier local testing + // Generate code and reset token - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing var code string - if s.cfg.Server.Debug { + if s.cfg.Server.DebugFixedCodes { code = "123456" } else { code = generateSixDigitCode() @@ -365,8 +374,8 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) { return "", apperrors.Internal(err) } - // Check for test code in debug mode - if s.cfg.Server.Debug && code == "123456" { + // Check for test code when DEBUG_FIXED_CODES is enabled + if s.cfg.Server.DebugFixedCodes && code == "123456" { return resetCode.ResetToken, nil } @@ -422,13 +431,13 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error { // Mark reset code as used if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil { // Log error but don't fail - fmt.Printf("Failed to mark reset code as used: %v\n", err) + log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used") } // Invalidate all existing tokens for this user (security measure) if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil { // Log error but don't fail - fmt.Printf("Failed to delete user tokens: %v\n", err) + log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset") } return nil @@ -482,6 +491,13 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi if email != "" { existingUser, err := s.userRepo.FindByEmail(email) if err == nil && existingUser != nil { + // S-06: Log auto-linking of social account to existing user + log.Warn(). + Str("email", email). + Str("provider", "apple"). + Uint("user_id", existingUser.ID). + Msg("Auto-linking social account to existing user by email match") + // Link Apple ID to existing account appleAuthRecord := &models.AppleSocialAuth{ UserID: existingUser.ID, @@ -505,8 +521,11 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi // Update last login _ = s.userRepo.UpdateLastLogin(existingUser.ID) - // Reload user with profile - existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID) + // B-08: Check error from FindByIDWithProfile + existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID) + if err != nil { + return nil, apperrors.Internal(err) + } return &responses.AppleSignInResponse{ Token: token.Key, @@ -544,8 +563,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi // Create notification preferences with all options enabled if s.notificationRepo != nil { if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil { - // Log error but don't fail registration - fmt.Printf("Failed to create notification preferences: %v\n", err) + log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user") } } @@ -566,8 +584,11 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi return nil, apperrors.Internal(err) } - // Reload user with profile - user, _ = s.userRepo.FindByIDWithProfile(user.ID) + // B-08: Check error from FindByIDWithProfile + user, err = s.userRepo.FindByIDWithProfile(user.ID) + if err != nil { + return nil, apperrors.Internal(err) + } return &responses.AppleSignInResponse{ Token: token.Key, @@ -623,6 +644,13 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe if email != "" { existingUser, err := s.userRepo.FindByEmail(email) if err == nil && existingUser != nil { + // S-06: Log auto-linking of social account to existing user + log.Warn(). + Str("email", email). + Str("provider", "google"). + Uint("user_id", existingUser.ID). + Msg("Auto-linking social account to existing user by email match") + // Link Google ID to existing account googleAuthRecord := &models.GoogleSocialAuth{ UserID: existingUser.ID, @@ -649,8 +677,11 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe // Update last login _ = s.userRepo.UpdateLastLogin(existingUser.ID) - // Reload user with profile - existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID) + // B-08: Check error from FindByIDWithProfile + existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID) + if err != nil { + return nil, apperrors.Internal(err) + } return &responses.GoogleSignInResponse{ Token: token.Key, @@ -688,8 +719,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe // Create notification preferences with all options enabled if s.notificationRepo != nil { if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil { - // Log error but don't fail registration - fmt.Printf("Failed to create notification preferences: %v\n", err) + log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user") } } @@ -711,8 +741,11 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe return nil, apperrors.Internal(err) } - // Reload user with profile - user, _ = s.userRepo.FindByIDWithProfile(user.ID) + // B-08: Check error from FindByIDWithProfile + user, err = s.userRepo.FindByIDWithProfile(user.ID) + if err != nil { + return nil, apperrors.Internal(err) + } return &responses.GoogleSignInResponse{ Token: token.Key, diff --git a/internal/services/cache_service.go b/internal/services/cache_service.go index 84641b3..800ebc8 100644 --- a/internal/services/cache_service.go +++ b/internal/services/cache_service.go @@ -2,9 +2,10 @@ package services import ( "context" - "crypto/md5" "encoding/json" "fmt" + "hash/fnv" + "sync" "time" "github.com/redis/go-redis/v9" @@ -18,38 +19,55 @@ type CacheService struct { client *redis.Client } -var cacheInstance *CacheService +var ( + cacheInstance *CacheService + cacheOnce sync.Once +) -// NewCacheService creates a new cache service +// NewCacheService creates a new cache service (thread-safe via sync.Once) func NewCacheService(cfg *config.RedisConfig) (*CacheService, error) { - opt, err := redis.ParseURL(cfg.URL) - if err != nil { - return nil, fmt.Errorf("failed to parse Redis URL: %w", err) + var initErr error + + cacheOnce.Do(func() { + opt, err := redis.ParseURL(cfg.URL) + if err != nil { + initErr = fmt.Errorf("failed to parse Redis URL: %w", err) + return + } + + if cfg.Password != "" { + opt.Password = cfg.Password + } + if cfg.DB != 0 { + opt.DB = cfg.DB + } + + client := redis.NewClient(opt) + + // Test connection + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + initErr = fmt.Errorf("failed to connect to Redis: %w", err) + // Reset Once so a retry is possible after transient failures + cacheOnce = sync.Once{} + return + } + + // S-14: Mask credentials in Redis URL before logging + log.Info(). + Str("url", config.MaskURLCredentials(cfg.URL)). + Int("db", opt.DB). + Msg("Connected to Redis") + + cacheInstance = &CacheService{client: client} + }) + + if initErr != nil { + return nil, initErr } - if cfg.Password != "" { - opt.Password = cfg.Password - } - if cfg.DB != 0 { - opt.DB = cfg.DB - } - - client := redis.NewClient(opt) - - // Test connection - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := client.Ping(ctx).Err(); err != nil { - return nil, fmt.Errorf("failed to connect to Redis: %w", err) - } - - log.Info(). - Str("url", cfg.URL). - Int("db", opt.DB). - Msg("Connected to Redis") - - cacheInstance = &CacheService{client: client} return cacheInstance, nil } @@ -311,9 +329,10 @@ func (c *CacheService) CacheSeededData(ctx context.Context, data interface{}) (s return "", fmt.Errorf("failed to marshal seeded data: %w", err) } - // Generate MD5 ETag from the JSON data - hash := md5.Sum(jsonData) - etag := fmt.Sprintf("\"%x\"", hash) + // Generate FNV-64a ETag from the JSON data (faster than MD5, non-cryptographic) + h := fnv.New64a() + h.Write(jsonData) + etag := fmt.Sprintf("\"%x\"", h.Sum64()) // Store both the data and the ETag if err := c.client.Set(ctx, SeededDataKey, jsonData, SeededDataTTL).Err(); err != nil { diff --git a/internal/services/email_service.go b/internal/services/email_service.go index b5d7817..b3656a4 100644 --- a/internal/services/email_service.go +++ b/internal/services/email_service.go @@ -4,11 +4,11 @@ import ( "bytes" "fmt" "html/template" - "io" "time" + mail "github.com/wneessen/go-mail" + "github.com/rs/zerolog/log" - "gopkg.in/gomail.v2" "github.com/treytartt/honeydue-api/internal/config" ) @@ -16,17 +16,31 @@ import ( // EmailService handles sending emails type EmailService struct { cfg *config.EmailConfig - dialer *gomail.Dialer + client *mail.Client enabled bool } // NewEmailService creates a new email service func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService { - dialer := gomail.NewDialer(cfg.Host, cfg.Port, cfg.User, cfg.Password) + client, err := mail.NewClient(cfg.Host, + mail.WithPort(cfg.Port), + mail.WithSMTPAuth(mail.SMTPAuthPlain), + mail.WithUsername(cfg.User), + mail.WithPassword(cfg.Password), + mail.WithTLSPortPolicy(mail.TLSOpportunistic), + ) + if err != nil { + log.Error().Err(err).Msg("Failed to create mail client - emails will not be sent") + return &EmailService{ + cfg: cfg, + client: nil, + enabled: false, + } + } return &EmailService{ cfg: cfg, - dialer: dialer, + client: client, enabled: enabled, } } @@ -37,14 +51,18 @@ func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error { log.Debug().Msg("Email sending disabled by feature flag") return nil } - m := gomail.NewMessage() - m.SetHeader("From", s.cfg.From) - m.SetHeader("To", to) - m.SetHeader("Subject", subject) - m.SetBody("text/plain", textBody) - m.AddAlternative("text/html", htmlBody) + m := mail.NewMsg() + if err := m.FromFormat("honeyDue", s.cfg.From); err != nil { + return fmt.Errorf("failed to set from address: %w", err) + } + if err := m.AddTo(to); err != nil { + return fmt.Errorf("failed to set to address: %w", err) + } + m.Subject(subject) + m.SetBodyString(mail.TypeTextPlain, textBody) + m.AddAlternativeString(mail.TypeTextHTML, htmlBody) - if err := s.dialer.DialAndSend(m); err != nil { + if err := s.client.DialAndSend(m); err != nil { log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email") return fmt.Errorf("failed to send email: %w", err) } @@ -74,26 +92,25 @@ func (s *EmailService) SendEmailWithAttachment(to, subject, htmlBody, textBody s log.Debug().Msg("Email sending disabled by feature flag") return nil } - m := gomail.NewMessage() - m.SetHeader("From", s.cfg.From) - m.SetHeader("To", to) - m.SetHeader("Subject", subject) - m.SetBody("text/plain", textBody) - m.AddAlternative("text/html", htmlBody) + m := mail.NewMsg() + if err := m.FromFormat("honeyDue", s.cfg.From); err != nil { + return fmt.Errorf("failed to set from address: %w", err) + } + if err := m.AddTo(to); err != nil { + return fmt.Errorf("failed to set to address: %w", err) + } + m.Subject(subject) + m.SetBodyString(mail.TypeTextPlain, textBody) + m.AddAlternativeString(mail.TypeTextHTML, htmlBody) if attachment != nil { - m.Attach(attachment.Filename, - gomail.SetCopyFunc(func(w io.Writer) error { - _, err := w.Write(attachment.Data) - return err - }), - gomail.SetHeader(map[string][]string{ - "Content-Type": {attachment.ContentType}, - }), + m.AttachReader(attachment.Filename, + bytes.NewReader(attachment.Data), + mail.WithFileContentType(mail.ContentType(attachment.ContentType)), ) } - if err := s.dialer.DialAndSend(m); err != nil { + if err := s.client.DialAndSend(m); err != nil { log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email with attachment") return fmt.Errorf("failed to send email: %w", err) } @@ -108,29 +125,28 @@ func (s *EmailService) SendEmailWithEmbeddedImages(to, subject, htmlBody, textBo log.Debug().Msg("Email sending disabled by feature flag") return nil } - m := gomail.NewMessage() - m.SetHeader("From", s.cfg.From) - m.SetHeader("To", to) - m.SetHeader("Subject", subject) - m.SetBody("text/plain", textBody) - m.AddAlternative("text/html", htmlBody) + m := mail.NewMsg() + if err := m.FromFormat("honeyDue", s.cfg.From); err != nil { + return fmt.Errorf("failed to set from address: %w", err) + } + if err := m.AddTo(to); err != nil { + return fmt.Errorf("failed to set to address: %w", err) + } + m.Subject(subject) + m.SetBodyString(mail.TypeTextPlain, textBody) + m.AddAlternativeString(mail.TypeTextHTML, htmlBody) // Embed each image with Content-ID for inline display for _, img := range images { - m.Embed(img.Filename, - gomail.SetCopyFunc(func(w io.Writer) error { - _, err := w.Write(img.Data) - return err - }), - gomail.SetHeader(map[string][]string{ - "Content-Type": {img.ContentType}, - "Content-ID": {"<" + img.ContentID + ">"}, - "Content-Disposition": {"inline; filename=\"" + img.Filename + "\""}, - }), + img := img // capture range variable for closure + m.EmbedReader(img.Filename, + bytes.NewReader(img.Data), + mail.WithFileContentType(mail.ContentType(img.ContentType)), + mail.WithFileContentID(img.ContentID), ) } - if err := s.dialer.DialAndSend(m); err != nil { + if err := s.client.DialAndSend(m); err != nil { log.Error().Err(err).Str("to", to).Str("subject", subject).Int("images", len(images)).Msg("Failed to send email with embedded images") return fmt.Errorf("failed to send email: %w", err) } diff --git a/internal/services/file_ownership_service.go b/internal/services/file_ownership_service.go new file mode 100644 index 0000000..fc1d084 --- /dev/null +++ b/internal/services/file_ownership_service.go @@ -0,0 +1,66 @@ +package services + +import ( + "github.com/treytartt/honeydue-api/internal/models" + "gorm.io/gorm" +) + +// FileOwnershipService checks whether a user owns a file referenced by URL. +// It queries task completion images, document files, and document images +// to determine ownership through residence access. +type FileOwnershipService struct { + db *gorm.DB +} + +// NewFileOwnershipService creates a new FileOwnershipService +func NewFileOwnershipService(db *gorm.DB) *FileOwnershipService { + return &FileOwnershipService{db: db} +} + +// IsFileOwnedByUser checks if the given file URL belongs to a record +// that the user has access to (via residence membership). +func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (bool, error) { + // Check task completion images: image_url -> completion -> task -> residence -> user access + var completionImageCount int64 + err := s.db.Model(&models.TaskCompletionImage{}). + Joins("JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id"). + Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id"). + Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_task.residence_id"). + Where("task_taskcompletionimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID). + Count(&completionImageCount).Error + if err != nil { + return false, err + } + if completionImageCount > 0 { + return true, nil + } + + // Check document files: file_url -> document -> residence -> user access + var documentCount int64 + err = s.db.Model(&models.Document{}). + Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id"). + Where("task_document.file_url = ? AND residence_residence_users.user_id = ?", fileURL, userID). + Count(&documentCount).Error + if err != nil { + return false, err + } + if documentCount > 0 { + return true, nil + } + + // Check document images: image_url -> document_image -> document -> residence -> user access + var documentImageCount int64 + err = s.db.Model(&models.DocumentImage{}). + Joins("JOIN task_document ON task_document.id = task_documentimage.document_id"). + Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id"). + Where("task_documentimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID). + Count(&documentImageCount).Error + if err != nil { + return false, err + } + if documentImageCount > 0 { + return true, nil + } + + return false, nil +} diff --git a/internal/services/iap_validation.go b/internal/services/iap_validation.go index 8395008..6b43930 100644 --- a/internal/services/iap_validation.go +++ b/internal/services/iap_validation.go @@ -36,11 +36,12 @@ var ( // AppleIAPClient handles Apple App Store Server API validation type AppleIAPClient struct { - keyID string - issuerID string - bundleID string + keyID string + issuerID string + bundleID string privateKey *ecdsa.PrivateKey - sandbox bool + sandbox bool + httpClient *http.Client // P-07: Reused across requests } // GoogleIAPClient handles Google Play Developer API validation @@ -122,6 +123,7 @@ func NewAppleIAPClient(cfg config.AppleIAPConfig) (*AppleIAPClient, error) { bundleID: cfg.BundleID, privateKey: ecdsaKey, sandbox: cfg.Sandbox, + httpClient: &http.Client{Timeout: 30 * time.Second}, // P-07: Single client reused across requests }, nil } @@ -168,8 +170,8 @@ func (c *AppleIAPClient) ValidateTransaction(ctx context.Context, transactionID req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + // P-07: Reuse the single http.Client instead of creating one per request + resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to call Apple API: %w", err) } @@ -276,8 +278,8 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + // P-07: Reuse the single http.Client + resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to call Apple API: %w", err) } @@ -357,8 +359,8 @@ func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, r req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + // P-07: Reuse the single http.Client + resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to call Apple verifyReceipt: %w", err) } diff --git a/internal/services/notification_service.go b/internal/services/notification_service.go index 684c287..1ce4858 100644 --- a/internal/services/notification_service.go +++ b/internal/services/notification_service.go @@ -4,9 +4,11 @@ import ( "context" "encoding/json" "errors" + "fmt" "strconv" "time" + "github.com/rs/zerolog/log" "gorm.io/gorm" "github.com/treytartt/honeydue-api/internal/apperrors" @@ -184,8 +186,33 @@ func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferen return NewNotificationPreferencesResponse(prefs), nil } +// validateHourField checks that an optional hour value is in the valid range 0-23. +func validateHourField(val *int, fieldName string) error { + if val != nil && (*val < 0 || *val > 23) { + return apperrors.BadRequest("error.invalid_hour"). + WithMessage(fmt.Sprintf("%s must be between 0 and 23", fieldName)) + } + return nil +} + // UpdatePreferences updates notification preferences func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) { + // B-12: Validate hour fields are in range 0-23 + hourFields := []struct { + value *int + name string + }{ + {req.TaskDueSoonHour, "task_due_soon_hour"}, + {req.TaskOverdueHour, "task_overdue_hour"}, + {req.WarrantyExpiringHour, "warranty_expiring_hour"}, + {req.DailyDigestHour, "daily_digest_hour"}, + } + for _, hf := range hourFields { + if err := validateHourField(hf.value, hf.name); err != nil { + return nil, err + } + } + prefs, err := s.notificationRepo.GetOrCreatePreferences(userID) if err != nil { return nil, apperrors.Internal(err) @@ -256,7 +283,10 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) { // Only update if timezone changed (avoid unnecessary DB writes) if prefs.Timezone == nil || *prefs.Timezone != timezone { prefs.Timezone = &timezone - _ = s.notificationRepo.UpdatePreferences(prefs) + if err := s.notificationRepo.UpdatePreferences(prefs); err != nil { + log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone). + Msg("Failed to update user timezone in notification preferences") + } } } @@ -430,6 +460,7 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string, // === Response/Request Types === +// TODO(hardening): Move to internal/dto/responses/notification.go // NotificationResponse represents a notification in API response type NotificationResponse struct { ID uint `json:"id"` @@ -473,6 +504,7 @@ func NewNotificationResponse(n *models.Notification) NotificationResponse { return resp } +// TODO(hardening): Move to internal/dto/responses/notification.go // NotificationPreferencesResponse represents notification preferences type NotificationPreferencesResponse struct { TaskDueSoon bool `json:"task_due_soon"` @@ -511,6 +543,7 @@ func NewNotificationPreferencesResponse(p *models.NotificationPreference) *Notif } } +// TODO(hardening): Move to internal/dto/requests/notification.go // UpdatePreferencesRequest represents preferences update request type UpdatePreferencesRequest struct { TaskDueSoon *bool `json:"task_due_soon"` @@ -532,6 +565,7 @@ type UpdatePreferencesRequest struct { DailyDigestHour *int `json:"daily_digest_hour"` } +// TODO(hardening): Move to internal/dto/responses/notification.go // DeviceResponse represents a device in API response type DeviceResponse struct { ID uint `json:"id"` @@ -569,6 +603,7 @@ func NewGCMDeviceResponse(d *models.GCMDevice) *DeviceResponse { } } +// TODO(hardening): Move to internal/dto/requests/notification.go // RegisterDeviceRequest represents device registration request type RegisterDeviceRequest struct { Name string `json:"name"` diff --git a/internal/services/onboarding_email_service.go b/internal/services/onboarding_email_service.go index 1dc36b0..4f18616 100644 --- a/internal/services/onboarding_email_service.go +++ b/internal/services/onboarding_email_service.go @@ -39,6 +39,7 @@ func generateTrackingID() string { // HasSentEmail checks if a specific email type has already been sent to a user func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.OnboardingEmailType) bool { + // TODO(hardening): Replace with OnboardingEmailRepository.HasSentEmail() var count int64 if err := s.db.Model(&models.OnboardingEmail{}). Where("user_id = ? AND email_type = ?", userID, emailType). @@ -51,6 +52,7 @@ func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.Onbo // RecordEmailSent records that an email was sent to a user func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.OnboardingEmailType, trackingID string) error { + // TODO(hardening): Replace with OnboardingEmailRepository.Create() email := &models.OnboardingEmail{ UserID: userID, EmailType: emailType, @@ -66,6 +68,7 @@ func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.O // RecordEmailOpened records that an email was opened based on tracking ID func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error { + // TODO(hardening): Replace with OnboardingEmailRepository.MarkOpened() now := time.Now().UTC() result := s.db.Model(&models.OnboardingEmail{}). Where("tracking_id = ? AND opened_at IS NULL", trackingID). @@ -84,6 +87,7 @@ func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error { // GetEmailHistory gets all onboarding emails for a specific user func (s *OnboardingEmailService) GetEmailHistory(userID uint) ([]models.OnboardingEmail, error) { + // TODO(hardening): Replace with OnboardingEmailRepository.FindByUserID() var emails []models.OnboardingEmail if err := s.db.Where("user_id = ?", userID).Order("sent_at DESC").Find(&emails).Error; err != nil { return nil, err @@ -105,11 +109,13 @@ func (s *OnboardingEmailService) GetAllEmailHistory(page, pageSize int) ([]model offset := (page - 1) * pageSize + // TODO(hardening): Replace with OnboardingEmailRepository.CountAll() // Count total if err := s.db.Model(&models.OnboardingEmail{}).Count(&total).Error; err != nil { return nil, 0, err } + // TODO(hardening): Replace with OnboardingEmailRepository.FindAllPaginated() // Get paginated results with user info if err := s.db.Preload("User"). Order("sent_at DESC"). @@ -126,6 +132,7 @@ func (s *OnboardingEmailService) GetAllEmailHistory(page, pageSize int) ([]model func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error) { stats := &OnboardingEmailStats{} + // TODO(hardening): Replace with OnboardingEmailRepository.GetStats() // No residence email stats var noResTotal, noResOpened int64 if err := s.db.Model(&models.OnboardingEmail{}). @@ -159,6 +166,7 @@ func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error) return stats, nil } +// TODO(hardening): Move to internal/dto/responses/onboarding_email.go // OnboardingEmailStats represents statistics about onboarding emails type OnboardingEmailStats struct { NoResidenceTotal int64 `json:"no_residence_total"` @@ -173,6 +181,7 @@ func (s *OnboardingEmailService) UsersNeedingNoResidenceEmail() ([]models.User, twoDaysAgo := time.Now().UTC().AddDate(0, 0, -2) + // TODO(hardening): Replace with OnboardingEmailRepository.FindUsersWithoutResidence() // Find users who: // 1. Are verified // 2. Registered 2+ days ago @@ -201,6 +210,7 @@ func (s *OnboardingEmailService) UsersNeedingNoTasksEmail() ([]models.User, erro fiveDaysAgo := time.Now().UTC().AddDate(0, 0, -5) + // TODO(hardening): Replace with OnboardingEmailRepository.FindUsersWithoutTasks() // Find users who: // 1. Are verified // 2. Have at least one residence @@ -325,6 +335,7 @@ func (s *OnboardingEmailService) sendNoTasksEmail(user models.User) error { // SendOnboardingEmailToUser manually sends an onboarding email to a specific user // This is used by admin to force-send emails regardless of eligibility criteria func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailType models.OnboardingEmailType) error { + // TODO(hardening): Replace with UserRepository.FindByID() (inject UserRepository) // Load the user var user models.User if err := s.db.First(&user, userID).Error; err != nil { @@ -362,6 +373,7 @@ func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailTyp // If already sent before, delete the old record first to allow re-recording // This allows admins to "resend" emails while still tracking them if alreadySent { + // TODO(hardening): Replace with OnboardingEmailRepository.DeleteByUserAndType() if err := s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{}).Error; err != nil { log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to delete old onboarding email record before resend") } diff --git a/internal/services/pdf_service.go b/internal/services/pdf_service.go index 3d0abc3..19eb783 100644 --- a/internal/services/pdf_service.go +++ b/internal/services/pdf_service.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - "github.com/jung-kurt/gofpdf" + "github.com/go-pdf/fpdf" ) // PDFService handles PDF generation @@ -18,7 +18,7 @@ func NewPDFService() *PDFService { // GenerateTasksReportPDF generates a PDF report from task report data func (s *PDFService) GenerateTasksReportPDF(report *TasksReportResponse) ([]byte, error) { - pdf := gofpdf.New("P", "mm", "A4", "") + pdf := fpdf.New("P", "mm", "A4", "") pdf.SetMargins(15, 15, 15) pdf.AddPage() diff --git a/internal/services/residence_service.go b/internal/services/residence_service.go index 369fe67..c6674af 100644 --- a/internal/services/residence_service.go +++ b/internal/services/residence_service.go @@ -133,14 +133,16 @@ func (s *ResidenceService) GetMyResidences(userID uint, now time.Time) (*respons } } - // Attach completion summaries (honeycomb grid data) - for i := range residenceResponses { - summary, err := s.taskRepo.GetCompletionSummary(residenceResponses[i].ID, now, 10) - if err != nil { - log.Warn().Err(err).Uint("residence_id", residenceResponses[i].ID).Msg("Failed to fetch completion summary") - continue + // P-01: Batch fetch completion summaries in 2 queries total instead of 2*N + summaries, err := s.taskRepo.GetBatchCompletionSummaries(residenceIDs, now, 10) + if err != nil { + log.Warn().Err(err).Msg("Failed to fetch batch completion summaries") + } else { + for i := range residenceResponses { + if summary, ok := summaries[residenceResponses[i].ID]; ok { + residenceResponses[i].CompletionSummary = summary + } } - residenceResponses[i].CompletionSummary = summary } } diff --git a/internal/services/storage_service.go b/internal/services/storage_service.go index aea834d..2d370d7 100644 --- a/internal/services/storage_service.go +++ b/internal/services/storage_service.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "mime/multipart" + "net/http" "os" "path/filepath" "strings" @@ -17,7 +18,8 @@ import ( // StorageService handles file uploads to local filesystem type StorageService struct { - cfg *config.StorageConfig + cfg *config.StorageConfig + allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups } // UploadResult contains information about an uploaded file @@ -44,9 +46,18 @@ func NewStorageService(cfg *config.StorageConfig) (*StorageService, error) { } } - log.Info().Str("upload_dir", cfg.UploadDir).Msg("Storage service initialized") + // P-12: Parse AllowedTypes once at initialization for O(1) lookups + allowedTypes := make(map[string]struct{}) + for _, t := range strings.Split(cfg.AllowedTypes, ",") { + trimmed := strings.TrimSpace(t) + if trimmed != "" { + allowedTypes[trimmed] = struct{}{} + } + } - return &StorageService{cfg: cfg}, nil + log.Info().Str("upload_dir", cfg.UploadDir).Int("allowed_types", len(allowedTypes)).Msg("Storage service initialized") + + return &StorageService{cfg: cfg, allowedTypes: allowedTypes}, nil } // Upload saves a file to the local filesystem @@ -56,17 +67,47 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U return nil, fmt.Errorf("file size %d exceeds maximum allowed %d bytes", file.Size, s.cfg.MaxFileSize) } - // Get MIME type - mimeType := file.Header.Get("Content-Type") - if mimeType == "" { - mimeType = "application/octet-stream" + // Get claimed MIME type from header + claimedMimeType := file.Header.Get("Content-Type") + if claimedMimeType == "" { + claimedMimeType = "application/octet-stream" } - // Validate MIME type + // S-09: Detect actual content type from file bytes to prevent disguised uploads + src, err := file.Open() + if err != nil { + return nil, fmt.Errorf("failed to open uploaded file: %w", err) + } + defer src.Close() + + // Read the first 512 bytes for content type detection + sniffBuf := make([]byte, 512) + n, err := src.Read(sniffBuf) + if err != nil && n == 0 { + return nil, fmt.Errorf("failed to read file for content type detection: %w", err) + } + detectedMimeType := http.DetectContentType(sniffBuf[:n]) + + // Validate that the detected type matches the claimed type (at the category level) + // Allow application/octet-stream from detection since DetectContentType may not + // recognize all valid types, but the claimed type must still be in our allowed list + if detectedMimeType != "application/octet-stream" && !s.mimeTypesCompatible(claimedMimeType, detectedMimeType) { + return nil, fmt.Errorf("file content type mismatch: claimed %s but detected %s", claimedMimeType, detectedMimeType) + } + + // Use the claimed MIME type (which is more specific) if it's allowed + mimeType := claimedMimeType + + // Validate MIME type against allowed list if !s.isAllowedType(mimeType) { return nil, fmt.Errorf("file type %s is not allowed", mimeType) } + // Seek back to beginning after sniffing + if _, err := src.Seek(0, io.SeekStart); err != nil { + return nil, fmt.Errorf("failed to seek file: %w", err) + } + // Generate unique filename ext := filepath.Ext(file.Filename) if ext == "" { @@ -83,15 +124,11 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U subdir = "completions" } - // Full path - destPath := filepath.Join(s.cfg.UploadDir, subdir, newFilename) - - // Open source file - src, err := file.Open() + // S-18: Sanitize path to prevent traversal attacks + destPath, err := SafeResolvePath(s.cfg.UploadDir, filepath.Join(subdir, newFilename)) if err != nil { - return nil, fmt.Errorf("failed to open uploaded file: %w", err) + return nil, fmt.Errorf("invalid upload path: %w", err) } - defer src.Close() // Create destination file dst, err := os.Create(destPath) @@ -131,19 +168,11 @@ func (s *StorageService) Delete(fileURL string) error { // Convert URL to file path relativePath := strings.TrimPrefix(fileURL, s.cfg.BaseURL) relativePath = strings.TrimPrefix(relativePath, "/") - fullPath := filepath.Join(s.cfg.UploadDir, relativePath) - // Security check: ensure path is within upload directory - absUploadDir, err := filepath.Abs(s.cfg.UploadDir) + // S-18: Use SafeResolvePath to prevent path traversal + fullPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath) if err != nil { - return fmt.Errorf("failed to resolve upload directory: %w", err) - } - absFilePath, err := filepath.Abs(fullPath) - if err != nil { - return fmt.Errorf("failed to resolve file path: %w", err) - } - if !strings.HasPrefix(absFilePath, absUploadDir+string(filepath.Separator)) && absFilePath != absUploadDir { - return fmt.Errorf("invalid file path") + return fmt.Errorf("invalid file path: %w", err) } if err := os.Remove(fullPath); err != nil { @@ -157,15 +186,23 @@ func (s *StorageService) Delete(fileURL string) error { return nil } -// isAllowedType checks if the MIME type is in the allowed list +// isAllowedType checks if the MIME type is in the allowed list. +// P-12: Uses the pre-parsed allowedTypes map for O(1) lookups instead of +// splitting the config string on every call. func (s *StorageService) isAllowedType(mimeType string) bool { - allowed := strings.Split(s.cfg.AllowedTypes, ",") - for _, t := range allowed { - if strings.TrimSpace(t) == mimeType { - return true - } + _, ok := s.allowedTypes[mimeType] + return ok +} + +// mimeTypesCompatible checks if the claimed and detected MIME types are compatible. +// Two MIME types are compatible if they share the same primary type (e.g., both "image/*"). +func (s *StorageService) mimeTypesCompatible(claimed, detected string) bool { + claimedParts := strings.SplitN(claimed, "/", 2) + detectedParts := strings.SplitN(detected, "/", 2) + if len(claimedParts) < 1 || len(detectedParts) < 1 { + return false } - return false + return claimedParts[0] == detectedParts[0] } // getExtensionFromMimeType returns a file extension for common MIME types @@ -191,5 +228,12 @@ func (s *StorageService) GetUploadDir() string { // NewStorageServiceForTest creates a StorageService without creating directories. // This is intended only for unit tests that need a StorageService with a known config. func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService { - return &StorageService{cfg: cfg} + allowedTypes := make(map[string]struct{}) + for _, t := range strings.Split(cfg.AllowedTypes, ",") { + trimmed := strings.TrimSpace(t) + if trimmed != "" { + allowedTypes[trimmed] = struct{}{} + } + } + return &StorageService{cfg: cfg, allowedTypes: allowedTypes} } diff --git a/internal/services/stripe_service.go b/internal/services/stripe_service.go index 2c40d5d..709208c 100644 --- a/internal/services/stripe_service.go +++ b/internal/services/stripe_service.go @@ -3,10 +3,10 @@ package services import ( "encoding/json" "fmt" - "os" "time" "github.com/rs/zerolog/log" + "github.com/spf13/viper" "github.com/stripe/stripe-go/v81" portalsession "github.com/stripe/stripe-go/v81/billingportal/session" checkoutsession "github.com/stripe/stripe-go/v81/checkout/session" @@ -34,7 +34,8 @@ func NewStripeService( subscriptionRepo *repositories.SubscriptionRepository, userRepo *repositories.UserRepository, ) *StripeService { - key := os.Getenv("STRIPE_SECRET_KEY") + // S-21: Use Viper config instead of os.Getenv for consistent configuration management + key := viper.GetString("STRIPE_SECRET_KEY") if key == "" { log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work") } else { @@ -42,7 +43,7 @@ func NewStripeService( log.Info().Msg("Stripe API key configured") } - webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET") + webhookSecret := viper.GetString("STRIPE_WEBHOOK_SECRET") if webhookSecret == "" { log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail") } diff --git a/internal/services/subscription_service.go b/internal/services/subscription_service.go index 52cc323..52e8db8 100644 --- a/internal/services/subscription_service.go +++ b/internal/services/subscription_service.go @@ -202,18 +202,19 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS } // getUserUsage calculates current usage for a user. +// P-10: Uses CountByOwner for properties count instead of loading all owned residences. // Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)). func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) { - residences, err := s.residenceRepo.FindOwnedByUser(userID) + // P-10: Use CountByOwner for an efficient COUNT query instead of loading all records + propertiesCount, err := s.residenceRepo.CountByOwner(userID) if err != nil { return nil, apperrors.Internal(err) } - propertiesCount := int64(len(residences)) - // Collect residence IDs for batch queries - residenceIDs := make([]uint, len(residences)) - for i, r := range residences { - residenceIDs[i] = r.ID + // Still need residence IDs for batch counting tasks/contractors/documents + residenceIDs, err := s.residenceRepo.FindResidenceIDsByOwner(userID) + if err != nil { + return nil, apperrors.Internal(err) } // Count tasks, contractors, and documents across all residences with single queries each diff --git a/internal/services/task_service.go b/internal/services/task_service.go index a5e0fae..ac44c31 100644 --- a/internal/services/task_service.go +++ b/internal/services/task_service.go @@ -130,7 +130,7 @@ func (s *TaskService) ListTasks(userID uint, daysThreshold int, now time.Time) ( return nil, apperrors.Internal(err) } - resp := responses.NewKanbanBoardResponseForAll(board) + resp := responses.NewKanbanBoardResponseForAll(board, now) // NOTE: Summary statistics are calculated client-side from kanban data return &resp, nil } @@ -157,7 +157,7 @@ func (s *TaskService) GetTasksByResidence(residenceID, userID uint, daysThreshol return nil, apperrors.Internal(err) } - resp := responses.NewKanbanBoardResponse(board, residenceID) + resp := responses.NewKanbanBoardResponse(board, residenceID, now) // NOTE: Summary statistics are calculated client-side from kanban data return &resp, nil } @@ -601,8 +601,8 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest task.InProgress = false } - // P1-5: Wrap completion creation and task update in a transaction. - // If either operation fails, both are rolled back to prevent orphaned completions. + // P1-5 + B-07: Wrap completion creation, task update, and image creation + // in a single transaction for atomicity. If any operation fails, all are rolled back. txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error { if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil { return err @@ -610,6 +610,18 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest if err := s.taskRepo.UpdateTx(tx, task); err != nil { return err } + // B-07: Create images inside the same transaction as completion + for _, imageURL := range req.ImageURLs { + if imageURL != "" { + img := &models.TaskCompletionImage{ + CompletionID: completion.ID, + ImageURL: imageURL, + } + if err := tx.Create(img).Error; err != nil { + return fmt.Errorf("failed to create completion image: %w", err) + } + } + } return nil }) if txErr != nil { @@ -621,19 +633,6 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest return nil, apperrors.Internal(txErr) } - // Create images if provided - for _, imageURL := range req.ImageURLs { - if imageURL != "" { - img := &models.TaskCompletionImage{ - CompletionID: completion.ID, - ImageURL: imageURL, - } - if err := s.taskRepo.CreateCompletionImage(img); err != nil { - log.Error().Err(err).Uint("completion_id", completion.ID).Msg("Failed to create completion image") - } - } - } - // Reload completion with user info and images completion, err = s.taskRepo.FindCompletionByID(completion.ID) if err != nil { @@ -663,8 +662,10 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest }, nil } -// QuickComplete creates a minimal task completion (for widget use) -// Returns only success/error, no response body +// QuickComplete creates a minimal task completion (for widget use). +// LE-01: The entire operation (completion creation + task update) is wrapped in a +// transaction for atomicity. +// Returns only success/error, no response body. func (s *TaskService) QuickComplete(taskID uint, userID uint) error { // Get the task task, err := s.taskRepo.FindByID(taskID) @@ -697,10 +698,6 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error { CompletedFromColumn: completedFromColumn, } - if err := s.taskRepo.CreateCompletion(completion); err != nil { - return apperrors.Internal(err) - } - // Update next_due_date and in_progress based on frequency // Determine interval days: Custom frequency uses task.CustomIntervalDays, otherwise use frequency.Days // Note: Frequency is no longer preloaded for performance, so we load it separately if needed @@ -729,7 +726,6 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error { } else { // Recurring task - calculate next due date from completion date + interval nextDue := completedAt.AddDate(0, 0, *quickIntervalDays) - // frequencyName was already set when loading frequency above log.Info(). Uint("task_id", task.ID). Str("frequency_name", frequencyName). @@ -742,12 +738,23 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error { // Reset in_progress to false task.InProgress = false } - if err := s.taskRepo.Update(task); err != nil { - if errors.Is(err, repositories.ErrVersionConflict) { + + // LE-01: Wrap completion creation and task update in a transaction for atomicity + txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error { + if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil { + return err + } + if err := s.taskRepo.UpdateTx(tx, task); err != nil { + return err + } + return nil + }) + if txErr != nil { + if errors.Is(txErr, repositories.ErrVersionConflict) { return apperrors.Conflict("error.version_conflict") } - log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after quick completion") - return apperrors.Internal(err) // Return error so caller knows the update failed + log.Error().Err(txErr).Uint("task_id", task.ID).Msg("Failed to create completion and update task in QuickComplete") + return apperrors.Internal(txErr) } log.Info().Uint("task_id", task.ID).Msg("QuickComplete: Task updated successfully") @@ -813,8 +820,16 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio // Send email notification (to everyone INCLUDING the person who completed it) // Check user's email notification preferences first if s.emailService != nil && user.Email != "" && s.notificationService != nil { - prefs, err := s.notificationService.GetPreferences(user.ID) - if err != nil || (prefs != nil && prefs.EmailTaskCompleted) { + prefs, prefsErr := s.notificationService.GetPreferences(user.ID) + // LE-06: Log fail-open behavior when preferences cannot be loaded + if prefsErr != nil { + log.Warn(). + Err(prefsErr). + Uint("user_id", user.ID). + Uint("task_id", task.ID). + Msg("Failed to load notification preferences, falling back to sending email (fail-open)") + } + if prefsErr != nil || (prefs != nil && prefs.EmailTaskCompleted) { // Send email if we couldn't get prefs (fail-open) or if email notifications are enabled if err := s.emailService.SendTaskCompletedEmail( user.Email, diff --git a/internal/services/user_service.go b/internal/services/user_service.go index d91903f..abddc3c 100644 --- a/internal/services/user_service.go +++ b/internal/services/user_service.go @@ -32,7 +32,8 @@ func (s *UserService) ListUsersInSharedResidences(userID uint) ([]responses.User return nil, apperrors.Internal(err) } - var result []responses.UserSummary + // F-23: Initialize as empty slice so JSON serialization produces [] instead of null + result := make([]responses.UserSummary, 0, len(users)) for _, u := range users { result = append(result, responses.UserSummary{ ID: u.ID, @@ -72,7 +73,8 @@ func (s *UserService) ListProfilesInSharedResidences(userID uint) ([]responses.U return nil, apperrors.Internal(err) } - var result []responses.UserProfileSummary + // F-23: Initialize as empty slice so JSON serialization produces [] instead of null + result := make([]responses.UserProfileSummary, 0, len(profiles)) for _, p := range profiles { result = append(result, responses.UserProfileSummary{ ID: p.ID, diff --git a/internal/task/consistency_test.go b/internal/task/consistency_test.go index 91ad449..ebf8062 100644 --- a/internal/task/consistency_test.go +++ b/internal/task/consistency_test.go @@ -27,7 +27,10 @@ var testDB *gorm.DB // testUserID is a user ID that exists in the database for foreign key constraints var testUserID uint = 1 -// TestMain sets up the database connection for all tests in this package +// TestMain sets up the database connection for all tests in this package. +// If the database is not available, testDB remains nil and individual tests +// will call t.Skip() instead of using os.Exit(0), which preserves proper +// test reporting and coverage output. func TestMain(m *testing.M) { dsn := os.Getenv("TEST_DATABASE_URL") if dsn == "" { @@ -39,15 +42,23 @@ func TestMain(m *testing.M) { Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { - println("Skipping consistency integration tests: database not available") + // Explicitly nil out testDB; individual tests will t.Skip("Database not available") + testDB = nil + println("Consistency integration tests will be skipped: database not available") println("Set TEST_DATABASE_URL to run these tests") - os.Exit(0) + os.Exit(m.Run()) } sqlDB, err := testDB.DB() - if err != nil || sqlDB.Ping() != nil { - println("Failed to connect to database") - os.Exit(0) + if err != nil { + println("Failed to get underlying DB:", err.Error()) + testDB = nil + os.Exit(m.Run()) + } + if pingErr := sqlDB.Ping(); pingErr != nil { + println("Failed to ping database:", pingErr.Error()) + testDB = nil + os.Exit(m.Run()) } println("Database connected, running consistency tests...") diff --git a/internal/task/scopes/scopes_test.go b/internal/task/scopes/scopes_test.go index 40596eb..2dcdcab 100644 --- a/internal/task/scopes/scopes_test.go +++ b/internal/task/scopes/scopes_test.go @@ -17,7 +17,10 @@ import ( // testDB holds the database connection for integration tests var testDB *gorm.DB -// TestMain sets up the database connection for all tests in this package +// TestMain sets up the database connection for all tests in this package. +// If the database is not available, testDB remains nil and individual tests +// will call t.Skip() instead of using os.Exit(0), which preserves proper +// test reporting and coverage output. func TestMain(m *testing.M) { // Get database URL from environment or use default dsn := os.Getenv("TEST_DATABASE_URL") @@ -30,22 +33,25 @@ func TestMain(m *testing.M) { Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { - // Print message and skip tests if database is not available - println("Skipping scope integration tests: database not available") + // Explicitly nil out testDB; individual tests will t.Skip("Database not available") + testDB = nil + println("Scope integration tests will be skipped: database not available") println("Set TEST_DATABASE_URL to run these tests") println("Error:", err.Error()) - os.Exit(0) + os.Exit(m.Run()) } // Verify connection works sqlDB, err := testDB.DB() if err != nil { println("Failed to get underlying DB:", err.Error()) - os.Exit(0) + testDB = nil + os.Exit(m.Run()) } if err := sqlDB.Ping(); err != nil { println("Failed to ping database:", err.Error()) - os.Exit(0) + testDB = nil + os.Exit(m.Run()) } println("Database connected successfully, running integration tests...") @@ -57,7 +63,9 @@ func TestMain(m *testing.M) { &models.Residence{}, ) if err != nil { - os.Exit(1) + println("Failed to run migrations:", err.Error()) + testDB = nil + os.Exit(m.Run()) } // Run tests diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index dfca7f4..a9165cb 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -3,6 +3,7 @@ package testutil import ( "bytes" "encoding/json" + "fmt" "net/http" "net/http/httptest" "sync" @@ -71,7 +72,12 @@ func SetupTestDB(t *testing.T) *gorm.DB { return db } -// SetupTestRouter creates a test Echo router with the custom error handler +// SetupTestRouter creates a test Echo router with the custom error handler. +// Uses apperrors.HTTPErrorHandler which is the base error handler shared with +// production (router.customHTTPErrorHandler). Both handle AppError, ValidationErrors, +// and echo.HTTPError identically. Production additionally maps legacy service sentinel +// errors (e.g., services.ErrTaskNotFound) which are being migrated to AppError types. +// Tests exercise handlers that return AppError, so this handler covers all test scenarios. func SetupTestRouter() *echo.Echo { e := echo.New() e.Validator = validator.NewCustomValidator() @@ -79,17 +85,52 @@ func SetupTestRouter() *echo.Echo { return e } -// MakeRequest makes a test HTTP request and returns the response +// MakeRequest makes a test HTTP request and returns the response. +// Errors from JSON marshaling and HTTP request construction are checked and +// will panic if they occur, since these indicate programming errors in tests. +// Prefer MakeRequestT for new tests, which uses t.Fatal for better reporting. func MakeRequest(router *echo.Echo, method, path string, body interface{}, token string) *httptest.ResponseRecorder { var reqBody *bytes.Buffer if body != nil { - jsonBody, _ := json.Marshal(body) + jsonBody, err := json.Marshal(body) + if err != nil { + panic(fmt.Sprintf("testutil.MakeRequest: failed to marshal request body: %v", err)) + } reqBody = bytes.NewBuffer(jsonBody) } else { reqBody = bytes.NewBuffer(nil) } - req, _ := http.NewRequest(method, path, reqBody) + req, err := http.NewRequest(method, path, reqBody) + if err != nil { + panic(fmt.Sprintf("testutil.MakeRequest: failed to create HTTP request: %v", err)) + } + req.Header.Set("Content-Type", "application/json") + if token != "" { + req.Header.Set("Authorization", "Token "+token) + } + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec +} + +// MakeRequestT is like MakeRequest but accepts a *testing.T for proper test +// failure reporting. Prefer this over MakeRequest in new tests. +func MakeRequestT(t *testing.T, router *echo.Echo, method, path string, body interface{}, token string) *httptest.ResponseRecorder { + t.Helper() + + var reqBody *bytes.Buffer + if body != nil { + jsonBody, err := json.Marshal(body) + require.NoError(t, err, "failed to marshal request body") + reqBody = bytes.NewBuffer(jsonBody) + } else { + reqBody = bytes.NewBuffer(nil) + } + + req, err := http.NewRequest(method, path, reqBody) + require.NoError(t, err, "failed to create HTTP request") req.Header.Set("Content-Type", "application/json") if token != "" { req.Header.Set("Authorization", "Token "+token) @@ -215,8 +256,12 @@ func CreateTestTask(t *testing.T, db *gorm.DB, residenceID, createdByID uint, ti return task } -// SeedLookupData seeds all lookup tables with test data +// SeedLookupData seeds all lookup tables with test data. +// All GORM create operations are checked for errors to prevent silent failures +// that could cause misleading test results. func SeedLookupData(t *testing.T, db *gorm.DB) { + t.Helper() + // Residence types residenceTypes := []models.ResidenceType{ {Name: "House"}, @@ -224,8 +269,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) { {Name: "Condo"}, {Name: "Townhouse"}, } - for _, rt := range residenceTypes { - db.Create(&rt) + for i := range residenceTypes { + err := db.Create(&residenceTypes[i]).Error + require.NoError(t, err, "failed to seed residence type: %s", residenceTypes[i].Name) } // Task categories @@ -235,8 +281,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) { {Name: "HVAC", DisplayOrder: 3}, {Name: "General", DisplayOrder: 99}, } - for _, c := range categories { - db.Create(&c) + for i := range categories { + err := db.Create(&categories[i]).Error + require.NoError(t, err, "failed to seed task category: %s", categories[i].Name) } // Task priorities @@ -246,8 +293,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) { {Name: "High", Level: 3, DisplayOrder: 3}, {Name: "Urgent", Level: 4, DisplayOrder: 4}, } - for _, p := range priorities { - db.Create(&p) + for i := range priorities { + err := db.Create(&priorities[i]).Error + require.NoError(t, err, "failed to seed task priority: %s", priorities[i].Name) } // Task frequencies @@ -258,8 +306,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) { {Name: "Weekly", Days: &days7, DisplayOrder: 2}, {Name: "Monthly", Days: &days30, DisplayOrder: 3}, } - for _, f := range frequencies { - db.Create(&f) + for i := range frequencies { + err := db.Create(&frequencies[i]).Error + require.NoError(t, err, "failed to seed task frequency: %s", frequencies[i].Name) } // Contractor specialties @@ -269,8 +318,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) { {Name: "HVAC Technician"}, {Name: "Handyman"}, } - for _, s := range specialties { - db.Create(&s) + for i := range specialties { + err := db.Create(&specialties[i]).Error + require.NoError(t, err, "failed to seed contractor specialty: %s", specialties[i].Name) } } diff --git a/internal/worker/jobs/handler.go b/internal/worker/jobs/handler.go index dd0cc2e..f43cbfb 100644 --- a/internal/worker/jobs/handler.go +++ b/internal/worker/jobs/handler.go @@ -20,8 +20,6 @@ import ( // Task types const ( - TypeTaskReminder = "notification:task_reminder" - TypeOverdueReminder = "notification:overdue_reminder" TypeSmartReminder = "notification:smart_reminder" // Frequency-aware reminders TypeDailyDigest = "notification:daily_digest" TypeSendEmail = "email:send" @@ -36,6 +34,7 @@ type Handler struct { taskRepo *repositories.TaskRepository residenceRepo *repositories.ResidenceRepository reminderRepo *repositories.ReminderRepository + notificationRepo *repositories.NotificationRepository pushClient *push.Client emailService *services.EmailService notificationService *services.NotificationService @@ -56,6 +55,7 @@ func NewHandler(db *gorm.DB, pushClient *push.Client, emailService *services.Ema taskRepo: repositories.NewTaskRepository(db), residenceRepo: repositories.NewResidenceRepository(db), reminderRepo: repositories.NewReminderRepository(db), + notificationRepo: repositories.NewNotificationRepository(db), pushClient: pushClient, emailService: emailService, notificationService: notificationService, @@ -64,218 +64,6 @@ func NewHandler(db *gorm.DB, pushClient *push.Client, emailService *services.Ema } } -// TaskReminderData represents a task due soon for reminder notifications -type TaskReminderData struct { - TaskID uint - TaskTitle string - DueDate time.Time - UserID uint - UserEmail string - UserName string - ResidenceName string -} - -// HandleTaskReminder processes task reminder notifications for tasks due today or tomorrow with actionable buttons -func (h *Handler) HandleTaskReminder(ctx context.Context, task *asynq.Task) error { - log.Info().Msg("Processing task reminder notifications...") - - now := time.Now().UTC() - currentHour := now.Hour() - systemDefaultHour := h.config.Worker.TaskReminderHour - - log.Info().Int("current_hour", currentHour).Int("system_default_hour", systemDefaultHour).Msg("Task reminder check") - - // Step 1: Find users who should receive notifications THIS hour - // Logic: Each user gets notified ONCE per day at exactly ONE hour: - // - If user has custom hour set: notify ONLY at that custom hour - // - If user has NO custom hour (NULL): notify ONLY at system default hour - // This prevents duplicates: a user with custom hour is NEVER notified at default hour - var eligibleUserIDs []uint - - query := h.db.Model(&models.NotificationPreference{}). - Select("user_id"). - Where("task_due_soon = true") - - if currentHour == systemDefaultHour { - // At system default hour: notify users who have NO custom hour (NULL) OR whose custom hour equals default - query = query.Where("task_due_soon_hour IS NULL OR task_due_soon_hour = ?", currentHour) - } else { - // At non-default hour: only notify users who have this specific custom hour set - // Exclude users with NULL (they get notified at default hour only) - query = query.Where("task_due_soon_hour = ?", currentHour) - } - - err := query.Pluck("user_id", &eligibleUserIDs).Error - - if err != nil { - log.Error().Err(err).Msg("Failed to query eligible users for task reminders") - return err - } - - // Early exit if no users need notifications this hour - if len(eligibleUserIDs) == 0 { - log.Debug().Int("hour", currentHour).Msg("No users scheduled for task reminder notifications this hour") - return nil - } - - log.Info().Int("eligible_users", len(eligibleUserIDs)).Msg("Found users eligible for task reminders this hour") - - // Step 2: Query tasks due today or tomorrow using the single-purpose repository function - // Uses the same scopes as kanban for consistency, with IncludeInProgress=true - // so users still get notified about in-progress tasks that are due soon. - opts := repositories.TaskFilterOptions{ - UserIDs: eligibleUserIDs, - IncludeInProgress: true, // Notifications should include in-progress tasks - PreloadResidence: true, - PreloadCompletions: true, - } - - // Due soon = due within 2 days (today and tomorrow) - dueSoonTasks, err := h.taskRepo.GetDueSoonTasks(now, 2, opts) - if err != nil { - log.Error().Err(err).Msg("Failed to query tasks due soon") - return err - } - - log.Info().Int("count", len(dueSoonTasks)).Msg("Found tasks due today/tomorrow for eligible users") - - // Build set for O(1) eligibility lookups instead of O(N) linear scan - eligibleSet := make(map[uint]bool, len(eligibleUserIDs)) - for _, id := range eligibleUserIDs { - eligibleSet[id] = true - } - - // Group tasks by user (assigned_to or residence owner) - userTasks := make(map[uint][]models.Task) - for _, t := range dueSoonTasks { - var userID uint - if t.AssignedToID != nil { - userID = *t.AssignedToID - } else if t.Residence.ID != 0 { - userID = t.Residence.OwnerID - } else { - continue - } - // Only include if user is in eligible set (O(1) lookup) - if eligibleSet[userID] { - userTasks[userID] = append(userTasks[userID], t) - } - } - - // Step 3: Send notifications (no need to check preferences again - already filtered) - // Send individual task-specific notification for each task (all tasks, no limit) - for userID, taskList := range userTasks { - for _, t := range taskList { - if err := h.notificationService.CreateAndSendTaskNotification(ctx, userID, models.NotificationTaskDueSoon, &t); err != nil { - log.Error().Err(err).Uint("user_id", userID).Uint("task_id", t.ID).Msg("Failed to send task reminder notification") - } - } - } - - log.Info().Int("users_notified", len(userTasks)).Msg("Task reminder notifications completed") - return nil -} - -// HandleOverdueReminder processes overdue task notifications with actionable buttons -func (h *Handler) HandleOverdueReminder(ctx context.Context, task *asynq.Task) error { - log.Info().Msg("Processing overdue task notifications...") - - now := time.Now().UTC() - currentHour := now.Hour() - systemDefaultHour := h.config.Worker.OverdueReminderHour - - log.Info().Int("current_hour", currentHour).Int("system_default_hour", systemDefaultHour).Msg("Overdue reminder check") - - // Step 1: Find users who should receive notifications THIS hour - // Logic: Each user gets notified ONCE per day at exactly ONE hour: - // - If user has custom hour set: notify ONLY at that custom hour - // - If user has NO custom hour (NULL): notify ONLY at system default hour - // This prevents duplicates: a user with custom hour is NEVER notified at default hour - var eligibleUserIDs []uint - - query := h.db.Model(&models.NotificationPreference{}). - Select("user_id"). - Where("task_overdue = true") - - if currentHour == systemDefaultHour { - // At system default hour: notify users who have NO custom hour (NULL) OR whose custom hour equals default - query = query.Where("task_overdue_hour IS NULL OR task_overdue_hour = ?", currentHour) - } else { - // At non-default hour: only notify users who have this specific custom hour set - // Exclude users with NULL (they get notified at default hour only) - query = query.Where("task_overdue_hour = ?", currentHour) - } - - err := query.Pluck("user_id", &eligibleUserIDs).Error - - if err != nil { - log.Error().Err(err).Msg("Failed to query eligible users for overdue reminders") - return err - } - - // Early exit if no users need notifications this hour - if len(eligibleUserIDs) == 0 { - log.Debug().Int("hour", currentHour).Msg("No users scheduled for overdue notifications this hour") - return nil - } - - log.Info().Int("eligible_users", len(eligibleUserIDs)).Msg("Found users eligible for overdue reminders this hour") - - // Step 2: Query overdue tasks using the single-purpose repository function - // Uses the same scopes as kanban for consistency, with IncludeInProgress=true - // so users still get notified about in-progress tasks that are overdue. - opts := repositories.TaskFilterOptions{ - UserIDs: eligibleUserIDs, - IncludeInProgress: true, // Notifications should include in-progress tasks - PreloadResidence: true, - PreloadCompletions: true, - } - - overdueTasks, err := h.taskRepo.GetOverdueTasks(now, opts) - if err != nil { - log.Error().Err(err).Msg("Failed to query overdue tasks") - return err - } - - log.Info().Int("count", len(overdueTasks)).Msg("Found overdue tasks for eligible users") - - // Build set for O(1) eligibility lookups instead of O(N) linear scan - eligibleSet := make(map[uint]bool, len(eligibleUserIDs)) - for _, id := range eligibleUserIDs { - eligibleSet[id] = true - } - - // Group tasks by user (assigned_to or residence owner) - userTasks := make(map[uint][]models.Task) - for _, t := range overdueTasks { - var userID uint - if t.AssignedToID != nil { - userID = *t.AssignedToID - } else if t.Residence.ID != 0 { - userID = t.Residence.OwnerID - } else { - continue - } - // Only include if user is in eligible set (O(1) lookup) - if eligibleSet[userID] { - userTasks[userID] = append(userTasks[userID], t) - } - } - - // Step 3: Send notifications (no need to check preferences again - already filtered) - // Send individual task-specific notification for each task (all tasks, no limit) - for userID, taskList := range userTasks { - for _, t := range taskList { - if err := h.notificationService.CreateAndSendTaskNotification(ctx, userID, models.NotificationTaskOverdue, &t); err != nil { - log.Error().Err(err).Uint("user_id", userID).Uint("task_id", t.ID).Msg("Failed to send overdue notification") - } - } - } - - log.Info().Int("users_notified", len(userTasks)).Msg("Overdue task notifications completed") - return nil -} - // HandleDailyDigest processes daily digest notifications with task statistics func (h *Handler) HandleDailyDigest(ctx context.Context, task *asynq.Task) error { log.Info().Msg("Processing daily digest notifications...") @@ -328,8 +116,7 @@ func (h *Handler) HandleDailyDigest(ctx context.Context, task *asynq.Task) error // Get user's timezone from notification preferences for accurate overdue calculation // This ensures the daily digest matches what the user sees in the kanban UI var userNow time.Time - var prefs models.NotificationPreference - if err := h.db.Where("user_id = ?", userID).First(&prefs).Error; err == nil && prefs.Timezone != nil { + if prefs, err := h.notificationRepo.FindPreferencesByUser(userID); err == nil && prefs.Timezone != nil { if loc, err := time.LoadLocation(*prefs.Timezone); err == nil { // Use start of day in user's timezone (matches kanban behavior) userNowInTz := time.Now().In(loc) @@ -481,22 +268,11 @@ func (h *Handler) sendPushToUser(ctx context.Context, userID uint, title, messag return nil } - // Get iOS device tokens - var iosTokens []string - err := h.db.Model(&models.APNSDevice{}). - Where("user_id = ? AND active = ?", userID, true). - Pluck("registration_id", &iosTokens).Error + // Get active device tokens via repository + iosTokens, androidTokens, err := h.notificationRepo.GetActiveTokensForUser(userID) if err != nil { - log.Error().Err(err).Uint("user_id", userID).Msg("Failed to get iOS tokens") - } - - // Get Android device tokens - var androidTokens []string - err = h.db.Model(&models.GCMDevice{}). - Where("user_id = ? AND active = ?", userID, true). - Pluck("registration_id", &androidTokens).Error - if err != nil { - log.Error().Err(err).Uint("user_id", userID).Msg("Failed to get Android tokens") + log.Error().Err(err).Uint("user_id", userID).Msg("Failed to get device tokens") + return err } if len(iosTokens) == 0 && len(androidTokens) == 0 { @@ -561,18 +337,17 @@ func (h *Handler) HandleOnboardingEmails(ctx context.Context, task *asynq.Task) } // Send no-residence emails (users without any residences after 2 days) - noResCount, err := h.onboardingService.CheckAndSendNoResidenceEmails() - if err != nil { - log.Error().Err(err).Msg("Failed to process no-residence onboarding emails") - // Continue to next type, don't return error + noResCount, noResErr := h.onboardingService.CheckAndSendNoResidenceEmails() + if noResErr != nil { + log.Error().Err(noResErr).Msg("Failed to process no-residence onboarding emails") } else { log.Info().Int("count", noResCount).Msg("Sent no-residence onboarding emails") } // Send no-tasks emails (users with residence but no tasks after 5 days) - noTasksCount, err := h.onboardingService.CheckAndSendNoTasksEmails() - if err != nil { - log.Error().Err(err).Msg("Failed to process no-tasks onboarding emails") + noTasksCount, noTasksErr := h.onboardingService.CheckAndSendNoTasksEmails() + if noTasksErr != nil { + log.Error().Err(noTasksErr).Msg("Failed to process no-tasks onboarding emails") } else { log.Info().Int("count", noTasksCount).Msg("Sent no-tasks onboarding emails") } @@ -582,6 +357,11 @@ func (h *Handler) HandleOnboardingEmails(ctx context.Context, task *asynq.Task) Int("no_tasks_sent", noTasksCount). Msg("Onboarding email processing completed") + // If all sub-tasks failed, return an error so Asynq retries + if noResErr != nil && noTasksErr != nil { + return fmt.Errorf("all onboarding email sub-tasks failed: no-residence: %w, no-tasks: %v", noResErr, noTasksErr) + } + return nil } @@ -603,7 +383,6 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err log.Info().Msg("Processing smart task reminders...") now := time.Now().UTC() - today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) currentHour := now.Hour() dueSoonDefault := h.config.Worker.TaskReminderHour @@ -673,6 +452,22 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err Int("want_overdue", len(userWantsOverdue)). Msg("Found users eligible for reminders") + // Build per-user "today" using their timezone preference + // This ensures reminder stage calculations (overdue, due-soon) match the user's local date + userToday := make(map[uint]time.Time, len(allUserIDs)) + utcToday := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + for _, uid := range allUserIDs { + prefs, err := h.notificationRepo.FindPreferencesByUser(uid) + if err == nil && prefs.Timezone != nil { + if loc, locErr := time.LoadLocation(*prefs.Timezone); locErr == nil { + userNowInTz := time.Now().In(loc) + userToday[uid] = time.Date(userNowInTz.Year(), userNowInTz.Month(), userNowInTz.Day(), 0, 0, 0, 0, loc) + continue + } + } + userToday[uid] = utcToday // Fallback to UTC + } + // Step 2: Single query to get ALL active tasks (both due-soon and overdue) for these users opts := repositories.TaskFilterOptions{ UserIDs: allUserIDs, @@ -734,7 +529,8 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err frequencyDays = &days } - // Determine which reminder stage applies today + // Determine which reminder stage applies today using the user's local date + today := userToday[userID] stage := notifications.GetReminderStageForToday(effectiveDate, frequencyDays, today) if stage == "" { continue @@ -800,6 +596,11 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err notificationType = models.NotificationTaskDueSoon } + // Log the reminder BEFORE sending so we have a record even if the send crashes + if _, err := h.reminderRepo.LogReminder(t.ID, c.userID, c.effectiveDate, c.reminderStage, nil); err != nil { + log.Error().Err(err).Uint("task_id", t.ID).Str("stage", c.stage).Msg("Failed to log reminder") + } + // Send notification if err := h.notificationService.CreateAndSendTaskNotification(ctx, c.userID, notificationType, &t); err != nil { log.Error().Err(err). @@ -810,11 +611,6 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err continue } - // Log the reminder - if _, err := h.reminderRepo.LogReminder(t.ID, c.userID, c.effectiveDate, c.reminderStage, nil); err != nil { - log.Error().Err(err).Uint("task_id", t.ID).Str("stage", c.stage).Msg("Failed to log reminder") - } - if c.isOverdue { overdueSent++ } else { diff --git a/internal/worker/scheduler.go b/internal/worker/scheduler.go index 9bb29d7..9694370 100644 --- a/internal/worker/scheduler.go +++ b/internal/worker/scheduler.go @@ -1,27 +1,18 @@ package worker import ( - "context" "encoding/json" - "fmt" - "time" "github.com/hibiken/asynq" "github.com/rs/zerolog/log" ) -// Task types +// Task types for email jobs const ( - TypeWelcomeEmail = "email:welcome" - TypeVerificationEmail = "email:verification" - TypePasswordResetEmail = "email:password_reset" - TypePasswordChangedEmail = "email:password_changed" - TypeTaskCompletionEmail = "email:task_completion" - TypeGeneratePDFReport = "pdf:generate_report" - TypeUpdateContractorRating = "contractor:update_rating" - TypeDailyNotifications = "notifications:daily" - TypeTaskReminders = "notifications:task_reminders" - TypeOverdueReminders = "notifications:overdue_reminders" + TypeWelcomeEmail = "email:welcome" + TypeVerificationEmail = "email:verification" + TypePasswordResetEmail = "email:password_reset" + TypePasswordChangedEmail = "email:password_changed" ) // EmailPayload is the base payload for email tasks @@ -146,94 +137,3 @@ func (c *TaskClient) EnqueuePasswordChangedEmail(to, firstName string) error { log.Debug().Str("to", to).Msg("Password changed email task enqueued") return nil } - -// WorkerServer manages the asynq worker server -type WorkerServer struct { - server *asynq.Server - scheduler *asynq.Scheduler -} - -// NewWorkerServer creates a new worker server -func NewWorkerServer(redisAddr string, concurrency int) *WorkerServer { - srv := asynq.NewServer( - asynq.RedisClientOpt{Addr: redisAddr}, - asynq.Config{ - Concurrency: concurrency, - Queues: map[string]int{ - "critical": 6, - "default": 3, - "low": 1, - }, - ErrorHandler: asynq.ErrorHandlerFunc(func(ctx context.Context, task *asynq.Task, err error) { - log.Error(). - Err(err). - Str("type", task.Type()). - Bytes("payload", task.Payload()). - Msg("Task failed") - }), - }, - ) - - // Create scheduler for periodic tasks - loc, _ := time.LoadLocation("UTC") - scheduler := asynq.NewScheduler( - asynq.RedisClientOpt{Addr: redisAddr}, - &asynq.SchedulerOpts{ - Location: loc, - }, - ) - - return &WorkerServer{ - server: srv, - scheduler: scheduler, - } -} - -// RegisterHandlers registers task handlers -func (w *WorkerServer) RegisterHandlers(mux *asynq.ServeMux) { - // Handlers will be registered by the main worker process -} - -// RegisterScheduledTasks registers periodic tasks -func (w *WorkerServer) RegisterScheduledTasks() error { - // Task reminders - 8:00 PM UTC daily - _, err := w.scheduler.Register("0 20 * * *", asynq.NewTask(TypeTaskReminders, nil)) - if err != nil { - return fmt.Errorf("failed to register task reminders: %w", err) - } - - // Overdue reminders - 9:00 AM UTC daily - _, err = w.scheduler.Register("0 9 * * *", asynq.NewTask(TypeOverdueReminders, nil)) - if err != nil { - return fmt.Errorf("failed to register overdue reminders: %w", err) - } - - // Daily notifications - 11:00 AM UTC daily - _, err = w.scheduler.Register("0 11 * * *", asynq.NewTask(TypeDailyNotifications, nil)) - if err != nil { - return fmt.Errorf("failed to register daily notifications: %w", err) - } - - return nil -} - -// Start starts the worker server and scheduler -func (w *WorkerServer) Start(mux *asynq.ServeMux) error { - // Start scheduler - if err := w.scheduler.Start(); err != nil { - return fmt.Errorf("failed to start scheduler: %w", err) - } - - // Start server - if err := w.server.Start(mux); err != nil { - return fmt.Errorf("failed to start worker server: %w", err) - } - - return nil -} - -// Shutdown gracefully shuts down the worker server -func (w *WorkerServer) Shutdown() { - w.scheduler.Shutdown() - w.server.Shutdown() -} diff --git a/pkg/utils/logger.go b/pkg/utils/logger.go index a13b7f2..6a7921d 100644 --- a/pkg/utils/logger.go +++ b/pkg/utils/logger.go @@ -2,12 +2,16 @@ package utils import ( "io" + "net/http" "os" + "runtime/debug" "time" "github.com/labstack/echo/v4" "github.com/rs/zerolog" "github.com/rs/zerolog/log" + + "github.com/treytartt/honeydue-api/internal/dto/responses" ) // InitLogger initializes the zerolog logger @@ -113,14 +117,17 @@ func EchoRecovery() echo.MiddlewareFunc { return func(c echo.Context) error { defer func() { if err := recover(); err != nil { + // F-14: Include full stack trace for debugging log.Error(). Interface("error", err). Str("path", c.Request().URL.Path). Str("method", c.Request().Method). + Str("stack", string(debug.Stack())). Msg("Panic recovered") - c.JSON(500, map[string]interface{}{ - "error": "Internal server error", + // F-15: Use the project's standard ErrorResponse struct + c.JSON(http.StatusInternalServerError, responses.ErrorResponse{ + Error: "Internal server error", }) } }()